diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5dbb42189a562c5a78c6c67a71cf7371412745a4
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,31 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CI
+
+# Controls when the workflow will run
+on:
+  # Triggers the workflow on push or pull request events, but only for the main branch
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+  # Allows you to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  # This workflow contains a single job called "build"
+  build:
+    # The type of runner that the job will run on
+    runs-on: ubuntu-latest
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      # Runs a single command using the runner's shell
+      - name: docker compose
+        run: docker compose up -d
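The `docker compose` step assumes a Compose file (`docker-compose.yml` or `compose.yaml`) at the repository root, which is not part of this diff. A minimal sketch of the kind of file the step expects follows; the `web` service name, the `nginx:alpine` image, and the port mapping are illustrative placeholders, not taken from this project:

```yaml
# Hypothetical Compose file for the CI step above.
# Service name, image, and port mapping are placeholders.
services:
  web:
    image: nginx:alpine   # stand-in for the project's actual image
    ports:
      - "8080:80"
```

Because the job only brings the stack up in detached mode (`-d`) and then ends, the job passes as soon as the containers start; a real pipeline would typically follow this step with tests or a health check (e.g. `docker compose ps`).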
diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75d5e37c9139ec47b98db86b71032d2a2cc97a4 Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/codec.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/codec.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f06f04e9b94df3e8348a6eebb02b3a1d453ad844 Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/codec.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/compat.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/compat.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e30b66a5e641ea55699d2c5a0f4533b029e0dde6 Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/compat.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/core.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/core.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7084c01a909ee44b9e73c3a3636f50e8b660ad50 Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/core.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/intranges.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/intranges.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eac04d2daf7dd48cbe7a9730eab4fd9788b0ead7 Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/intranges.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/idna/__pycache__/package_data.cpython-313.pyc b/venv/lib/python3.13/site-packages/idna/__pycache__/package_data.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..526adb365cd8f175fd2dbde3eb3015dd705dfc7a Binary files /dev/null and b/venv/lib/python3.13/site-packages/idna/__pycache__/package_data.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db791c30f4badb4e3f58724f13e5d09ec6dfd362 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_convertions.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_convertions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a1c95da6963da2b651d248cd3fb586cabf832c3 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_convertions.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_inspect.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_inspect.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ecb94caceb41f2693f9f1898437f23ed012fc0 Binary files /dev/null and 
b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_inspect.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_pep440.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_pep440.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b39b12c0566c5f5eeabab0922a315ff386cb4b Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/_utils/__pycache__/_pep440.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..235a8b57846852db43001fbc44e7d9876e6b8d8b Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/extras.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/extras.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67522d09278a0da5b8ab8345567662170f3ee4b Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/extras.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/mrecords.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/mrecords.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800b10685dfee406345e5420536ef858a01ac3c8 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/mrecords.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/testutils.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/testutils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed99717a712e6944fdd2806fa757d1f4ae19c119 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/__pycache__/testutils.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__init__.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd6225eea6f702206ddc120a2e5e140d5aaff3c5 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_arrayobject.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_arrayobject.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..206467309390605aca09a04a91b4e588d79aac23 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_arrayobject.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_deprecations.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_deprecations.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d24cb64c7844c5773aabf34c78b6bf3a87ec50d8 Binary files /dev/null and 
b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_deprecations.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_mrecords.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_mrecords.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b24c14770ba8a06fd656102cde00515c0897d4d Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_mrecords.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_old_ma.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_old_ma.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048205cba132c2a49e42950bd3c7daea5ac00615 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_old_ma.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_regression.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_regression.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e570e91d1555f349cd1894afa2e1e4d18981387 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_regression.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_subclassing.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_subclassing.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb677c4abb6159828329f0b91524780bf458e3dd Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/ma/tests/__pycache__/test_subclassing.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_arrayobject.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_arrayobject.py new file mode 100644 index 0000000000000000000000000000000000000000..2000cea226fce9dea890553601594d7c1255d77c --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_arrayobject.py @@ -0,0 +1,40 @@ +import pytest + +import numpy as np +from numpy.ma import masked_array +from numpy.testing import assert_array_equal + + +def test_matrix_transpose_raises_error_for_1d(): + msg = "matrix transpose with ndim < 2 is undefined" + ma_arr = masked_array(data=[1, 2, 3, 4, 5, 6], + mask=[1, 0, 1, 1, 1, 0]) + with pytest.raises(ValueError, match=msg): + ma_arr.mT + + +def test_matrix_transpose_equals_transpose_2d(): + ma_arr = masked_array(data=[[1, 2, 3], [4, 5, 6]], + mask=[[1, 0, 1], [1, 1, 0]]) + assert_array_equal(ma_arr.T, ma_arr.mT) + + +ARRAY_SHAPES_TO_TEST = ( + (5, 2), + (5, 2, 3), + (5, 2, 3, 4), +) + + +@pytest.mark.parametrize("shape", ARRAY_SHAPES_TO_TEST) +def test_matrix_transpose_equals_swapaxes(shape): + num_of_axes = len(shape) + vec = np.arange(shape[-1]) + arr = np.broadcast_to(vec, shape) + + rng = np.random.default_rng(42) + mask = rng.choice([0, 1], size=shape) + ma_arr = masked_array(data=arr, mask=mask) + + tgt = np.swapaxes(arr, num_of_axes - 2, num_of_axes - 1) + assert_array_equal(tgt, ma_arr.mT) diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_core.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..091ba6c99fff1fe52fa1a67c753bec9d62bfb509 --- /dev/null +++ 
b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_core.py @@ -0,0 +1,5886 @@ +"""Tests suite for MaskedArray & subclassing. + +:author: Pierre Gerard-Marchant +:contact: pierregm_at_uga_dot_edu +""" +__author__ = "Pierre GF Gerard-Marchant" + +import copy +import itertools +import operator +import pickle +import sys +import textwrap +import warnings +from functools import reduce + +import pytest + +import numpy as np +import numpy._core.fromnumeric as fromnumeric +import numpy._core.umath as umath +import numpy.ma.core +from numpy import ndarray +from numpy._utils import asbytes +from numpy.exceptions import AxisError +from numpy.ma.core import ( + MAError, + MaskedArray, + MaskError, + MaskType, + abs, + absolute, + add, + all, + allclose, + allequal, + alltrue, + angle, + anom, + arange, + arccos, + arccosh, + arcsin, + arctan, + arctan2, + argsort, + array, + asarray, + choose, + concatenate, + conjugate, + cos, + cosh, + count, + default_fill_value, + diag, + divide, + empty, + empty_like, + equal, + exp, + filled, + fix_invalid, + flatten_mask, + flatten_structured_array, + fromflex, + getmask, + getmaskarray, + greater, + greater_equal, + identity, + inner, + isMaskedArray, + less, + less_equal, + log, + log10, + make_mask, + make_mask_descr, + mask_or, + masked, + masked_array, + masked_equal, + masked_greater, + masked_greater_equal, + masked_inside, + masked_less, + masked_less_equal, + masked_not_equal, + masked_outside, + masked_print_option, + masked_values, + masked_where, + max, + maximum, + maximum_fill_value, + min, + minimum, + minimum_fill_value, + mod, + multiply, + mvoid, + nomask, + not_equal, + ones, + ones_like, + outer, + power, + product, + put, + putmask, + ravel, + repeat, + reshape, + resize, + shape, + sin, + sinh, + sometrue, + sort, + sqrt, + subtract, + sum, + take, + tan, + tanh, + transpose, + where, + zeros, + zeros_like, +) +from numpy.ma.testutils import ( + assert_, + assert_almost_equal, + assert_array_equal, + assert_equal, + assert_equal_records, + assert_mask_equal, + assert_not_equal, + fail_if_equal, +) +from numpy.testing import ( + IS_WASM, + assert_raises, + assert_warns, + suppress_warnings, + temppath, +) +from numpy.testing._private.utils import requires_memory + +pi = np.pi + + +suppress_copy_mask_on_assignment = suppress_warnings() +suppress_copy_mask_on_assignment.filter( + numpy.ma.core.MaskedArrayFutureWarning, + "setting an item on a masked array which has a shared mask will not copy") + + +# For parametrized numeric testing +num_dts = [np.dtype(dt_) for dt_ in '?bhilqBHILQefdgFD'] +num_ids = [dt_.char for dt_ in num_dts] + + +class TestMaskedArray: + # Base test class for MaskedArrays. + + def setup_method(self): + # Base data definition. + x = np.array([1., 1., 1., -2., pi / 2.0, 4., 5., -10., 10., 1., 2., 3.]) + y = np.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.]) + a10 = 10. + m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] + m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1] + xm = masked_array(x, mask=m1) + ym = masked_array(y, mask=m2) + z = np.array([-.5, 0., .5, .8]) + zm = masked_array(z, mask=[0, 1, 0, 0]) + xf = np.where(m1, 1e+20, x) + xm.set_fill_value(1e+20) + self.d = (x, y, a10, m1, m2, xm, ym, z, zm, xf) + + def test_basicattributes(self): + # Tests some basic array attributes. 
+ a = array([1, 3, 2]) + b = array([1, 3, 2], mask=[1, 0, 1]) + assert_equal(a.ndim, 1) + assert_equal(b.ndim, 1) + assert_equal(a.size, 3) + assert_equal(b.size, 3) + assert_equal(a.shape, (3,)) + assert_equal(b.shape, (3,)) + + def test_basic0d(self): + # Checks masking a scalar + x = masked_array(0) + assert_equal(str(x), '0') + x = masked_array(0, mask=True) + assert_equal(str(x), str(masked_print_option)) + x = masked_array(0, mask=False) + assert_equal(str(x), '0') + x = array(0, mask=1) + assert_(x.filled().dtype is x._data.dtype) + + def test_basic1d(self): + # Test of basic array creation and properties in 1 dimension. + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + assert_(not isMaskedArray(x)) + assert_(isMaskedArray(xm)) + assert_((xm - ym).filled(0).any()) + fail_if_equal(xm.mask.astype(int), ym.mask.astype(int)) + s = x.shape + assert_equal(np.shape(xm), s) + assert_equal(xm.shape, s) + assert_equal(xm.dtype, x.dtype) + assert_equal(zm.dtype, z.dtype) + assert_equal(xm.size, reduce(lambda x, y: x * y, s)) + assert_equal(count(xm), len(m1) - reduce(lambda x, y: x + y, m1)) + assert_array_equal(xm, xf) + assert_array_equal(filled(xm, 1.e20), xf) + assert_array_equal(x, xm) + + def test_basic2d(self): + # Test of basic array creation and properties in 2 dimensions. + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + for s in [(4, 3), (6, 2)]: + x.shape = s + y.shape = s + xm.shape = s + ym.shape = s + xf.shape = s + + assert_(not isMaskedArray(x)) + assert_(isMaskedArray(xm)) + assert_equal(shape(xm), s) + assert_equal(xm.shape, s) + assert_equal(xm.size, reduce(lambda x, y: x * y, s)) + assert_equal(count(xm), len(m1) - reduce(lambda x, y: x + y, m1)) + assert_equal(xm, xf) + assert_equal(filled(xm, 1.e20), xf) + assert_equal(x, xm) + + def test_concatenate_basic(self): + # Tests concatenations. + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + # basic concatenation + assert_equal(np.concatenate((x, y)), concatenate((xm, ym))) + assert_equal(np.concatenate((x, y)), concatenate((x, y))) + assert_equal(np.concatenate((x, y)), concatenate((xm, y))) + assert_equal(np.concatenate((x, y, x)), concatenate((x, ym, x))) + + def test_concatenate_alongaxis(self): + # Tests concatenations. + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + # Concatenation along an axis + s = (3, 4) + x.shape = y.shape = xm.shape = ym.shape = s + assert_equal(xm.mask, np.reshape(m1, s)) + assert_equal(ym.mask, np.reshape(m2, s)) + xmym = concatenate((xm, ym), 1) + assert_equal(np.concatenate((x, y), 1), xmym) + assert_equal(np.concatenate((xm.mask, ym.mask), 1), xmym._mask) + + x = zeros(2) + y = array(ones(2), mask=[False, True]) + z = concatenate((x, y)) + assert_array_equal(z, [0, 0, 1, 1]) + assert_array_equal(z.mask, [False, False, False, True]) + z = concatenate((y, x)) + assert_array_equal(z, [1, 1, 0, 0]) + assert_array_equal(z.mask, [False, True, False, False]) + + def test_concatenate_flexible(self): + # Tests the concatenation on flexible arrays. 
+ data = masked_array(list(zip(np.random.rand(10), + np.arange(10))), + dtype=[('a', float), ('b', int)]) + + test = concatenate([data[:5], data[5:]]) + assert_equal_records(test, data) + + def test_creation_ndmin(self): + # Check the use of ndmin + x = array([1, 2, 3], mask=[1, 0, 0], ndmin=2) + assert_equal(x.shape, (1, 3)) + assert_equal(x._data, [[1, 2, 3]]) + assert_equal(x._mask, [[1, 0, 0]]) + + def test_creation_ndmin_from_maskedarray(self): + # Make sure we're not losing the original mask w/ ndmin + x = array([1, 2, 3]) + x[-1] = masked + xx = array(x, ndmin=2, dtype=float) + assert_equal(x.shape, x._mask.shape) + assert_equal(xx.shape, xx._mask.shape) + + def test_creation_maskcreation(self): + # Tests how masks are initialized at the creation of Maskedarrays. + data = arange(24, dtype=float) + data[[3, 6, 15]] = masked + dma_1 = MaskedArray(data) + assert_equal(dma_1.mask, data.mask) + dma_2 = MaskedArray(dma_1) + assert_equal(dma_2.mask, dma_1.mask) + dma_3 = MaskedArray(dma_1, mask=[1, 0, 0, 0] * 6) + fail_if_equal(dma_3.mask, dma_1.mask) + + x = array([1, 2, 3], mask=True) + assert_equal(x._mask, [True, True, True]) + x = array([1, 2, 3], mask=False) + assert_equal(x._mask, [False, False, False]) + y = array([1, 2, 3], mask=x._mask, copy=False) + assert_(np.may_share_memory(x.mask, y.mask)) + y = array([1, 2, 3], mask=x._mask, copy=True) + assert_(not np.may_share_memory(x.mask, y.mask)) + x = array([1, 2, 3], mask=None) + assert_equal(x._mask, [False, False, False]) + + def test_masked_singleton_array_creation_warns(self): + # The first works, but should not (ideally), there may be no way + # to solve this, however, as long as `np.ma.masked` is an ndarray. + np.array(np.ma.masked) + with pytest.warns(UserWarning): + # Tries to create a float array, using `float(np.ma.masked)`. + # We may want to define this is invalid behaviour in the future! + # (requiring np.ma.masked to be a known NumPy scalar probably + # with a DType.) + np.array([3., np.ma.masked]) + + def test_creation_with_list_of_maskedarrays(self): + # Tests creating a masked array from a list of masked arrays. 
+ x = array(np.arange(5), mask=[1, 0, 0, 0, 0]) + data = array((x, x[::-1])) + assert_equal(data, [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]) + assert_equal(data._mask, [[1, 0, 0, 0, 0], [0, 0, 0, 0, 1]]) + + x.mask = nomask + data = array((x, x[::-1])) + assert_equal(data, [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]) + assert_(data.mask is nomask) + + def test_creation_with_list_of_maskedarrays_no_bool_cast(self): + # Tests the regression in gh-18551 + masked_str = np.ma.masked_array(['a', 'b'], mask=[True, False]) + normal_int = np.arange(2) + res = np.ma.asarray([masked_str, normal_int], dtype="U21") + assert_array_equal(res.mask, [[True, False], [False, False]]) + + # The above only failed due a long chain of oddity, try also with + # an object array that cannot be converted to bool always: + class NotBool: + def __bool__(self): + raise ValueError("not a bool!") + masked_obj = np.ma.masked_array([NotBool(), 'b'], mask=[True, False]) + # Check that the NotBool actually fails like we would expect: + with pytest.raises(ValueError, match="not a bool!"): + np.asarray([masked_obj], dtype=bool) + + res = np.ma.asarray([masked_obj, normal_int]) + assert_array_equal(res.mask, [[True, False], [False, False]]) + + def test_creation_from_ndarray_with_padding(self): + x = np.array([('A', 0)], dtype={'names': ['f0', 'f1'], + 'formats': ['S4', 'i8'], + 'offsets': [0, 8]}) + array(x) # used to fail due to 'V' padding field in x.dtype.descr + + def test_unknown_keyword_parameter(self): + with pytest.raises(TypeError, match="unexpected keyword argument"): + MaskedArray([1, 2, 3], maks=[0, 1, 0]) # `mask` is misspelled. + + def test_asarray(self): + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + xm.fill_value = -9999 + xm._hardmask = True + xmm = asarray(xm) + assert_equal(xmm._data, xm._data) + assert_equal(xmm._mask, xm._mask) + assert_equal(xmm.fill_value, xm.fill_value) + assert_equal(xmm._hardmask, xm._hardmask) + + def test_asarray_default_order(self): + # See Issue #6646 + m = np.eye(3).T + assert_(not m.flags.c_contiguous) + + new_m = asarray(m) + assert_(new_m.flags.c_contiguous) + + def test_asarray_enforce_order(self): + # See Issue #6646 + m = np.eye(3).T + assert_(not m.flags.c_contiguous) + + new_m = asarray(m, order='C') + assert_(new_m.flags.c_contiguous) + + def test_fix_invalid(self): + # Checks fix_invalid. + with np.errstate(invalid='ignore'): + data = masked_array([np.nan, 0., 1.], mask=[0, 0, 1]) + data_fixed = fix_invalid(data) + assert_equal(data_fixed._data, [data.fill_value, 0., 1.]) + assert_equal(data_fixed._mask, [1., 0., 1.]) + + def test_maskedelement(self): + # Test of masked element + x = arange(6) + x[1] = masked + assert_(str(masked) == '--') + assert_(x[1] is masked) + assert_equal(filled(x[1], 0), 0) + + def test_set_element_as_object(self): + # Tests setting elements with object + a = empty(1, dtype=object) + x = (1, 2, 3, 4, 5) + a[0] = x + assert_equal(a[0], x) + assert_(a[0] is x) + + import datetime + dt = datetime.datetime.now() + a[0] = dt + assert_(a[0] is dt) + + def test_indexing(self): + # Tests conversions and indexing + x1 = np.array([1, 2, 4, 3]) + x2 = array(x1, mask=[1, 0, 0, 0]) + x3 = array(x1, mask=[0, 1, 0, 1]) + x4 = array(x1) + # test conversion to strings + str(x2) # raises? + repr(x2) # raises? 
+ assert_equal(np.sort(x1), sort(x2, endwith=False)) + # tests of indexing + assert_(type(x2[1]) is type(x1[1])) + assert_(x1[1] == x2[1]) + assert_(x2[0] is masked) + assert_equal(x1[2], x2[2]) + assert_equal(x1[2:5], x2[2:5]) + assert_equal(x1[:], x2[:]) + assert_equal(x1[1:], x3[1:]) + x1[2] = 9 + x2[2] = 9 + assert_equal(x1, x2) + x1[1:3] = 99 + x2[1:3] = 99 + assert_equal(x1, x2) + x2[1] = masked + assert_equal(x1, x2) + x2[1:3] = masked + assert_equal(x1, x2) + x2[:] = x1 + x2[1] = masked + assert_(allequal(getmask(x2), array([0, 1, 0, 0]))) + x3[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0]) + assert_(allequal(getmask(x3), array([0, 1, 1, 0]))) + x4[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0]) + assert_(allequal(getmask(x4), array([0, 1, 1, 0]))) + assert_(allequal(x4, array([1, 2, 3, 4]))) + x1 = np.arange(5) * 1.0 + x2 = masked_values(x1, 3.0) + assert_equal(x1, x2) + assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask)) + assert_equal(3.0, x2.fill_value) + x1 = array([1, 'hello', 2, 3], object) + x2 = np.array([1, 'hello', 2, 3], object) + s1 = x1[1] + s2 = x2[1] + assert_equal(type(s2), str) + assert_equal(type(s1), str) + assert_equal(s1, s2) + assert_(x1[1:1].shape == (0,)) + + def test_setitem_no_warning(self): + # Setitem shouldn't warn, because the assignment might be masked + # and warning for a masked assignment is weird (see gh-23000) + # (When the value is masked, otherwise a warning would be acceptable + # but is not given currently.) + x = np.ma.arange(60).reshape((6, 10)) + index = (slice(1, 5, 2), [7, 5]) + value = np.ma.masked_all((2, 2)) + value._data[...] = np.inf # not a valid integer... + x[index] = value + # The masked scalar is special cased, but test anyway (it's NaN): + x[...] = np.ma.masked + # Finally, a large value that cannot be cast to the float32 `x` + x = np.ma.arange(3., dtype=np.float32) + value = np.ma.array([2e234, 1, 1], mask=[True, False, False]) + x[...] = value + x[[0, 1, 2]] = value + + @suppress_copy_mask_on_assignment + def test_copy(self): + # Tests of some subtle points of copying and sizing. + n = [0, 0, 1, 0, 0] + m = make_mask(n) + m2 = make_mask(m) + assert_(m is m2) + m3 = make_mask(m, copy=True) + assert_(m is not m3) + + x1 = np.arange(5) + y1 = array(x1, mask=m) + assert_equal(y1._data.__array_interface__, x1.__array_interface__) + assert_(allequal(x1, y1.data)) + assert_equal(y1._mask.__array_interface__, m.__array_interface__) + + y1a = array(y1) + # Default for masked array is not to copy; see gh-10318. 
+ assert_(y1a._data.__array_interface__ == + y1._data.__array_interface__) + assert_(y1a._mask.__array_interface__ == + y1._mask.__array_interface__) + + y2 = array(x1, mask=m3) + assert_(y2._data.__array_interface__ == x1.__array_interface__) + assert_(y2._mask.__array_interface__ == m3.__array_interface__) + assert_(y2[2] is masked) + y2[2] = 9 + assert_(y2[2] is not masked) + assert_(y2._mask.__array_interface__ == m3.__array_interface__) + assert_(allequal(y2.mask, 0)) + + y2a = array(x1, mask=m, copy=1) + assert_(y2a._data.__array_interface__ != x1.__array_interface__) + #assert_( y2a._mask is not m) + assert_(y2a._mask.__array_interface__ != m.__array_interface__) + assert_(y2a[2] is masked) + y2a[2] = 9 + assert_(y2a[2] is not masked) + #assert_( y2a._mask is not m) + assert_(y2a._mask.__array_interface__ != m.__array_interface__) + assert_(allequal(y2a.mask, 0)) + + y3 = array(x1 * 1.0, mask=m) + assert_(filled(y3).dtype is (x1 * 1.0).dtype) + + x4 = arange(4) + x4[2] = masked + y4 = resize(x4, (8,)) + assert_equal(concatenate([x4, x4]), y4) + assert_equal(getmask(y4), [0, 0, 1, 0, 0, 0, 1, 0]) + y5 = repeat(x4, (2, 2, 2, 2), axis=0) + assert_equal(y5, [0, 0, 1, 1, 2, 2, 3, 3]) + y6 = repeat(x4, 2, axis=0) + assert_equal(y5, y6) + y7 = x4.repeat((2, 2, 2, 2), axis=0) + assert_equal(y5, y7) + y8 = x4.repeat(2, 0) + assert_equal(y5, y8) + + y9 = x4.copy() + assert_equal(y9._data, x4._data) + assert_equal(y9._mask, x4._mask) + + x = masked_array([1, 2, 3], mask=[0, 1, 0]) + # Copy is False by default + y = masked_array(x) + assert_equal(y._data.ctypes.data, x._data.ctypes.data) + assert_equal(y._mask.ctypes.data, x._mask.ctypes.data) + y = masked_array(x, copy=True) + assert_not_equal(y._data.ctypes.data, x._data.ctypes.data) + assert_not_equal(y._mask.ctypes.data, x._mask.ctypes.data) + + def test_copy_0d(self): + # gh-9430 + x = np.ma.array(43, mask=True) + xc = x.copy() + assert_equal(xc.mask, True) + + def test_copy_on_python_builtins(self): + # Tests copy works on python builtins (issue#8019) + assert_(isMaskedArray(np.ma.copy([1, 2, 3]))) + assert_(isMaskedArray(np.ma.copy((1, 2, 3)))) + + def test_copy_immutable(self): + # Tests that the copy method is immutable, GitHub issue #5247 + a = np.ma.array([1, 2, 3]) + b = np.ma.array([4, 5, 6]) + a_copy_method = a.copy + b.copy + assert_equal(a_copy_method(), [1, 2, 3]) + + def test_deepcopy(self): + from copy import deepcopy + a = array([0, 1, 2], mask=[False, True, False]) + copied = deepcopy(a) + assert_equal(copied.mask, a.mask) + assert_not_equal(id(a._mask), id(copied._mask)) + + copied[1] = 1 + assert_equal(copied.mask, [0, 0, 0]) + assert_equal(a.mask, [0, 1, 0]) + + copied = deepcopy(a) + assert_equal(copied.mask, a.mask) + copied.mask[1] = False + assert_equal(copied.mask, [0, 0, 0]) + assert_equal(a.mask, [0, 1, 0]) + + def test_format(self): + a = array([0, 1, 2], mask=[False, True, False]) + assert_equal(format(a), "[0 -- 2]") + assert_equal(format(masked), "--") + assert_equal(format(masked, ""), "--") + + # Postponed from PR #15410, perhaps address in the future. 
+ # assert_equal(format(masked, " >5"), " --") + # assert_equal(format(masked, " <5"), "-- ") + + # Expect a FutureWarning for using format_spec with MaskedElement + with assert_warns(FutureWarning): + with_format_string = format(masked, " >5") + assert_equal(with_format_string, "--") + + def test_str_repr(self): + a = array([0, 1, 2], mask=[False, True, False]) + assert_equal(str(a), '[0 -- 2]') + assert_equal( + repr(a), + textwrap.dedent('''\ + masked_array(data=[0, --, 2], + mask=[False, True, False], + fill_value=999999)''') + ) + + # arrays with a continuation + a = np.ma.arange(2000) + a[1:50] = np.ma.masked + assert_equal( + repr(a), + textwrap.dedent('''\ + masked_array(data=[0, --, --, ..., 1997, 1998, 1999], + mask=[False, True, True, ..., False, False, False], + fill_value=999999)''') + ) + + # line-wrapped 1d arrays are correctly aligned + a = np.ma.arange(20) + assert_equal( + repr(a), + textwrap.dedent('''\ + masked_array(data=[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19], + mask=False, + fill_value=999999)''') + ) + + # 2d arrays cause wrapping + a = array([[1, 2, 3], [4, 5, 6]], dtype=np.int8) + a[1, 1] = np.ma.masked + assert_equal( + repr(a), + textwrap.dedent(f'''\ + masked_array( + data=[[1, 2, 3], + [4, --, 6]], + mask=[[False, False, False], + [False, True, False]], + fill_value={np.array(999999)[()]!r}, + dtype=int8)''') + ) + + # but not it they're a row vector + assert_equal( + repr(a[:1]), + textwrap.dedent(f'''\ + masked_array(data=[[1, 2, 3]], + mask=[[False, False, False]], + fill_value={np.array(999999)[()]!r}, + dtype=int8)''') + ) + + # dtype=int is implied, so not shown + assert_equal( + repr(a.astype(int)), + textwrap.dedent('''\ + masked_array( + data=[[1, 2, 3], + [4, --, 6]], + mask=[[False, False, False], + [False, True, False]], + fill_value=999999)''') + ) + + def test_str_repr_legacy(self): + oldopts = np.get_printoptions() + np.set_printoptions(legacy='1.13') + try: + a = array([0, 1, 2], mask=[False, True, False]) + assert_equal(str(a), '[0 -- 2]') + assert_equal(repr(a), 'masked_array(data = [0 -- 2],\n' + ' mask = [False True False],\n' + ' fill_value = 999999)\n') + + a = np.ma.arange(2000) + a[1:50] = np.ma.masked + assert_equal( + repr(a), + 'masked_array(data = [0 -- -- ..., 1997 1998 1999],\n' + ' mask = [False True True ..., False False False],\n' + ' fill_value = 999999)\n' + ) + finally: + np.set_printoptions(**oldopts) + + def test_0d_unicode(self): + u = 'caf\xe9' + utype = type(u) + + arr_nomask = np.ma.array(u) + arr_masked = np.ma.array(u, mask=True) + + assert_equal(utype(arr_nomask), u) + assert_equal(utype(arr_masked), '--') + + def test_pickling(self): + # Tests pickling + for dtype in (int, float, str, object): + a = arange(10).astype(dtype) + a.fill_value = 999 + + masks = ([0, 0, 0, 1, 0, 1, 0, 1, 0, 1], # partially masked + True, # Fully masked + False) # Fully unmasked + + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + for mask in masks: + a.mask = mask + a_pickled = pickle.loads(pickle.dumps(a, protocol=proto)) + assert_equal(a_pickled._mask, a._mask) + assert_equal(a_pickled._data, a._data) + if dtype in (object, int): + assert_equal(a_pickled.fill_value, 999) + else: + assert_equal(a_pickled.fill_value, dtype(999)) + assert_array_equal(a_pickled.mask, mask) + + def test_pickling_subbaseclass(self): + # Test pickling w/ a subclass of ndarray + x = np.array([(1.0, 2), (3.0, 4)], + dtype=[('x', float), ('y', int)]).view(np.recarray) + a = masked_array(x, mask=[(True, False), (False, 
True)]) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + a_pickled = pickle.loads(pickle.dumps(a, protocol=proto)) + assert_equal(a_pickled._mask, a._mask) + assert_equal(a_pickled, a) + assert_(isinstance(a_pickled._data, np.recarray)) + + def test_pickling_maskedconstant(self): + # Test pickling MaskedConstant + mc = np.ma.masked + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + mc_pickled = pickle.loads(pickle.dumps(mc, protocol=proto)) + assert_equal(mc_pickled._baseclass, mc._baseclass) + assert_equal(mc_pickled._mask, mc._mask) + assert_equal(mc_pickled._data, mc._data) + + def test_pickling_wstructured(self): + # Tests pickling w/ structured array + a = array([(1, 1.), (2, 2.)], mask=[(0, 0), (0, 1)], + dtype=[('a', int), ('b', float)]) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + a_pickled = pickle.loads(pickle.dumps(a, protocol=proto)) + assert_equal(a_pickled._mask, a._mask) + assert_equal(a_pickled, a) + + def test_pickling_keepalignment(self): + # Tests pickling w/ F_CONTIGUOUS arrays + a = arange(10) + a.shape = (-1, 2) + b = a.T + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + test = pickle.loads(pickle.dumps(b, protocol=proto)) + assert_equal(test, b) + + def test_single_element_subscript(self): + # Tests single element subscripts of Maskedarrays. + a = array([1, 3, 2]) + b = array([1, 3, 2], mask=[1, 0, 1]) + assert_equal(a[0].shape, ()) + assert_equal(b[0].shape, ()) + assert_equal(b[1].shape, ()) + + def test_topython(self): + # Tests some communication issues with Python. + assert_equal(1, int(array(1))) + assert_equal(1.0, float(array(1))) + assert_equal(1, int(array([[[1]]]))) + assert_equal(1.0, float(array([[1]]))) + assert_raises(TypeError, float, array([1, 1])) + + with suppress_warnings() as sup: + sup.filter(UserWarning, 'Warning: converting a masked element') + assert_(np.isnan(float(array([1], mask=[1])))) + + a = array([1, 2, 3], mask=[1, 0, 0]) + assert_raises(TypeError, lambda: float(a)) + assert_equal(float(a[-1]), 3.) + assert_(np.isnan(float(a[0]))) + assert_raises(TypeError, int, a) + assert_equal(int(a[-1]), 3) + assert_raises(MAError, lambda: int(a[0])) + + def test_oddfeatures_1(self): + # Test of other odd features + x = arange(20) + x = x.reshape(4, 5) + x.flat[5] = 12 + assert_(x[1, 0] == 12) + z = x + 10j * x + assert_equal(z.real, x) + assert_equal(z.imag, 10 * x) + assert_equal((z * conjugate(z)).real, 101 * x * x) + z.imag[...] = 0.0 + + x = arange(10) + x[3] = masked + assert_(str(x[3]) == str(masked)) + c = x >= 8 + assert_(count(where(c, masked, masked)) == 0) + assert_(shape(where(c, masked, masked)) == c.shape) + + z = masked_where(c, x) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is not masked) + assert_(z[7] is not masked) + assert_(z[8] is masked) + assert_(z[9] is masked) + assert_equal(x, z) + + def test_oddfeatures_2(self): + # Tests some more features. 
+ x = array([1., 2., 3., 4., 5.]) + c = array([1, 1, 1, 0, 0]) + x[2] = masked + z = where(c, x, -x) + assert_equal(z, [1., 2., 0., -4., -5]) + c[0] = masked + z = where(c, x, -x) + assert_equal(z, [1., 2., 0., -4., -5]) + assert_(z[0] is masked) + assert_(z[1] is not masked) + assert_(z[2] is masked) + + @suppress_copy_mask_on_assignment + def test_oddfeatures_3(self): + # Tests some generic features + atest = array([10], mask=True) + btest = array([20]) + idx = atest.mask + atest[idx] = btest[idx] + assert_equal(atest, [20]) + + def test_filled_with_object_dtype(self): + a = np.ma.masked_all(1, dtype='O') + assert_equal(a.filled('x')[0], 'x') + + def test_filled_with_flexible_dtype(self): + # Test filled w/ flexible dtype + flexi = array([(1, 1, 1)], + dtype=[('i', int), ('s', '|S8'), ('f', float)]) + flexi[0] = masked + assert_equal(flexi.filled(), + np.array([(default_fill_value(0), + default_fill_value('0'), + default_fill_value(0.),)], dtype=flexi.dtype)) + flexi[0] = masked + assert_equal(flexi.filled(1), + np.array([(1, '1', 1.)], dtype=flexi.dtype)) + + def test_filled_with_mvoid(self): + # Test filled w/ mvoid + ndtype = [('a', int), ('b', float)] + a = mvoid((1, 2.), mask=[(0, 1)], dtype=ndtype) + # Filled using default + test = a.filled() + assert_equal(tuple(test), (1, default_fill_value(1.))) + # Explicit fill_value + test = a.filled((-1, -1)) + assert_equal(tuple(test), (1, -1)) + # Using predefined filling values + a.fill_value = (-999, -999) + assert_equal(tuple(a.filled()), (1, -999)) + + def test_filled_with_nested_dtype(self): + # Test filled w/ nested dtype + ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] + a = array([(1, (1, 1)), (2, (2, 2))], + mask=[(0, (1, 0)), (0, (0, 1))], dtype=ndtype) + test = a.filled(0) + control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype) + assert_equal(test, control) + + test = a['B'].filled(0) + control = np.array([(0, 1), (2, 0)], dtype=a['B'].dtype) + assert_equal(test, control) + + # test if mask gets set correctly (see #6760) + Z = numpy.ma.zeros(2, numpy.dtype([("A", "(2,2)i1,(2,2)i1", (2, 2))])) + assert_equal(Z.data.dtype, numpy.dtype([('A', [('f0', 'i1', (2, 2)), + ('f1', 'i1', (2, 2))], (2, 2))])) + assert_equal(Z.mask.dtype, numpy.dtype([('A', [('f0', '?', (2, 2)), + ('f1', '?', (2, 2))], (2, 2))])) + + def test_filled_with_f_order(self): + # Test filled w/ F-contiguous array + a = array(np.array([(0, 1, 2), (4, 5, 6)], order='F'), + mask=np.array([(0, 0, 1), (1, 0, 0)], order='F'), + order='F') # this is currently ignored + assert_(a.flags['F_CONTIGUOUS']) + assert_(a.filled(0).flags['F_CONTIGUOUS']) + + def test_optinfo_propagation(self): + # Checks that _optinfo dictionary isn't back-propagated + x = array([1, 2, 3, ], dtype=float) + x._optinfo['info'] = '???' + y = x.copy() + assert_equal(y._optinfo['info'], '???') + y._optinfo['info'] = '!!!' 
+ assert_equal(x._optinfo['info'], '???') + + def test_optinfo_forward_propagation(self): + a = array([1, 2, 2, 4]) + a._optinfo["key"] = "value" + assert_equal(a._optinfo["key"], (a == 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a != 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a > 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a >= 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a <= 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a + 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a - 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a * 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], (a / 2)._optinfo["key"]) + assert_equal(a._optinfo["key"], a[:2]._optinfo["key"]) + assert_equal(a._optinfo["key"], a[[0, 0, 2]]._optinfo["key"]) + assert_equal(a._optinfo["key"], np.exp(a)._optinfo["key"]) + assert_equal(a._optinfo["key"], np.abs(a)._optinfo["key"]) + assert_equal(a._optinfo["key"], array(a, copy=True)._optinfo["key"]) + assert_equal(a._optinfo["key"], np.zeros_like(a)._optinfo["key"]) + + def test_fancy_printoptions(self): + # Test printing a masked array w/ fancy dtype. + fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])]) + test = array([(1, (2, 3.0)), (4, (5, 6.0))], + mask=[(1, (0, 1)), (0, (1, 0))], + dtype=fancydtype) + control = "[(--, (2, --)) (4, (--, 6.0))]" + assert_equal(str(test), control) + + # Test 0-d array with multi-dimensional dtype + t_2d0 = masked_array(data=(0, [[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]], + 0.0), + mask=(False, [[True, False, True], + [False, False, True]], + False), + dtype="int, (2,3)float, float") + control = "(0, [[--, 0.0, --], [0.0, 0.0, --]], 0.0)" + assert_equal(str(t_2d0), control) + + def test_flatten_structured_array(self): + # Test flatten_structured_array on arrays + # On ndarray + ndtype = [('a', int), ('b', float)] + a = np.array([(1, 1), (2, 2)], dtype=ndtype) + test = flatten_structured_array(a) + control = np.array([[1., 1.], [2., 2.]], dtype=float) + assert_equal(test, control) + assert_equal(test.dtype, control.dtype) + # On masked_array + a = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) + test = flatten_structured_array(a) + control = array([[1., 1.], [2., 2.]], + mask=[[0, 1], [1, 0]], dtype=float) + assert_equal(test, control) + assert_equal(test.dtype, control.dtype) + assert_equal(test.mask, control.mask) + # On masked array with nested structure + ndtype = [('a', int), ('b', [('ba', int), ('bb', float)])] + a = array([(1, (1, 1.1)), (2, (2, 2.2))], + mask=[(0, (1, 0)), (1, (0, 1))], dtype=ndtype) + test = flatten_structured_array(a) + control = array([[1., 1., 1.1], [2., 2., 2.2]], + mask=[[0, 1, 0], [1, 0, 1]], dtype=float) + assert_equal(test, control) + assert_equal(test.dtype, control.dtype) + assert_equal(test.mask, control.mask) + # Keeping the initial shape + ndtype = [('a', int), ('b', float)] + a = np.array([[(1, 1), ], [(2, 2), ]], dtype=ndtype) + test = flatten_structured_array(a) + control = np.array([[[1., 1.], ], [[2., 2.], ]], dtype=float) + assert_equal(test, control) + assert_equal(test.dtype, control.dtype) + + def test_void0d(self): + # Test creating a mvoid object + ndtype = [('a', int), ('b', int)] + a = np.array([(1, 2,)], dtype=ndtype)[0] + f = mvoid(a) + assert_(isinstance(f, mvoid)) + + a = masked_array([(1, 2)], mask=[(1, 0)], dtype=ndtype)[0] + assert_(isinstance(a, mvoid)) + + a = masked_array([(1, 2), (1, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) + f = mvoid(a._data[0], a._mask[0]) + assert_(isinstance(f, 
mvoid)) + + def test_mvoid_getitem(self): + # Test mvoid.__getitem__ + ndtype = [('a', int), ('b', int)] + a = masked_array([(1, 2,), (3, 4)], mask=[(0, 0), (1, 0)], + dtype=ndtype) + # w/o mask + f = a[0] + assert_(isinstance(f, mvoid)) + assert_equal((f[0], f['a']), (1, 1)) + assert_equal(f['b'], 2) + # w/ mask + f = a[1] + assert_(isinstance(f, mvoid)) + assert_(f[0] is masked) + assert_(f['a'] is masked) + assert_equal(f[1], 4) + + # exotic dtype + A = masked_array(data=[([0, 1],)], + mask=[([True, False],)], + dtype=[("A", ">i2", (2,))]) + assert_equal(A[0]["A"], A["A"][0]) + assert_equal(A[0]["A"], masked_array(data=[0, 1], + mask=[True, False], dtype=">i2")) + + def test_mvoid_iter(self): + # Test iteration on __getitem__ + ndtype = [('a', int), ('b', int)] + a = masked_array([(1, 2,), (3, 4)], mask=[(0, 0), (1, 0)], + dtype=ndtype) + # w/o mask + assert_equal(list(a[0]), [1, 2]) + # w/ mask + assert_equal(list(a[1]), [masked, 4]) + + def test_mvoid_print(self): + # Test printing a mvoid + mx = array([(1, 1), (2, 2)], dtype=[('a', int), ('b', int)]) + assert_equal(str(mx[0]), "(1, 1)") + mx['b'][0] = masked + ini_display = masked_print_option._display + masked_print_option.set_display("-X-") + try: + assert_equal(str(mx[0]), "(1, -X-)") + assert_equal(repr(mx[0]), "(1, -X-)") + finally: + masked_print_option.set_display(ini_display) + + # also check if there are object datatypes (see gh-7493) + mx = array([(1,), (2,)], dtype=[('a', 'O')]) + assert_equal(str(mx[0]), "(1,)") + + def test_mvoid_multidim_print(self): + + # regression test for gh-6019 + t_ma = masked_array(data=[([1, 2, 3],)], + mask=[([False, True, False],)], + fill_value=([999999, 999999, 999999],), + dtype=[('a', ' 1: + assert_equal(np.concatenate((x, y), 1), concatenate((xm, ym), 1)) + assert_equal(np.add.reduce(x, 1), add.reduce(x, 1)) + assert_equal(np.sum(x, 1), sum(x, 1)) + assert_equal(np.prod(x, 1), product(x, 1)) + + def test_binops_d2D(self): + # Test binary operations on 2D data + a = array([[1.], [2.], [3.]], mask=[[False], [True], [True]]) + b = array([[2., 3.], [4., 5.], [6., 7.]]) + + test = a * b + control = array([[2., 3.], [2., 2.], [3., 3.]], + mask=[[0, 0], [1, 1], [1, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + test = b * a + control = array([[2., 3.], [4., 5.], [6., 7.]], + mask=[[0, 0], [1, 1], [1, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + a = array([[1.], [2.], [3.]]) + b = array([[2., 3.], [4., 5.], [6., 7.]], + mask=[[0, 0], [0, 0], [0, 1]]) + test = a * b + control = array([[2, 3], [8, 10], [18, 3]], + mask=[[0, 0], [0, 0], [0, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + test = b * a + control = array([[2, 3], [8, 10], [18, 7]], + mask=[[0, 0], [0, 0], [0, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + def test_domained_binops_d2D(self): + # Test domained binary operations on 2D data + a = array([[1.], [2.], [3.]], mask=[[False], [True], [True]]) + b = array([[2., 3.], [4., 5.], [6., 7.]]) + + test = a / b + control = array([[1. / 2., 1. / 3.], [2., 2.], [3., 3.]], + mask=[[0, 0], [1, 1], [1, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + test = b / a + control = array([[2. / 1., 3. 
/ 1.], [4., 5.], [6., 7.]], + mask=[[0, 0], [1, 1], [1, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + a = array([[1.], [2.], [3.]]) + b = array([[2., 3.], [4., 5.], [6., 7.]], + mask=[[0, 0], [0, 0], [0, 1]]) + test = a / b + control = array([[1. / 2, 1. / 3], [2. / 4, 2. / 5], [3. / 6, 3]], + mask=[[0, 0], [0, 0], [0, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + test = b / a + control = array([[2 / 1., 3 / 1.], [4 / 2., 5 / 2.], [6 / 3., 7]], + mask=[[0, 0], [0, 0], [0, 1]]) + assert_equal(test, control) + assert_equal(test.data, control.data) + assert_equal(test.mask, control.mask) + + def test_noshrinking(self): + # Check that we don't shrink a mask when not wanted + # Binary operations + a = masked_array([1., 2., 3.], mask=[False, False, False], + shrink=False) + b = a + 1 + assert_equal(b.mask, [0, 0, 0]) + # In place binary operation + a += 1 + assert_equal(a.mask, [0, 0, 0]) + # Domained binary operation + b = a / 1. + assert_equal(b.mask, [0, 0, 0]) + # In place binary operation + a /= 1. + assert_equal(a.mask, [0, 0, 0]) + + def test_ufunc_nomask(self): + # check the case ufuncs should set the mask to false + m = np.ma.array([1]) + # check we don't get array([False], dtype=bool) + assert_equal(np.true_divide(m, 5).mask.shape, ()) + + def test_noshink_on_creation(self): + # Check that the mask is not shrunk on array creation when not wanted + a = np.ma.masked_values([1., 2.5, 3.1], 1.5, shrink=False) + assert_equal(a.mask, [0, 0, 0]) + + def test_mod(self): + # Tests mod + (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d + assert_equal(mod(x, y), mod(xm, ym)) + test = mod(ym, xm) + assert_equal(test, np.mod(ym, xm)) + assert_equal(test.mask, mask_or(xm.mask, ym.mask)) + test = mod(xm, ym) + assert_equal(test, np.mod(xm, ym)) + assert_equal(test.mask, mask_or(mask_or(xm.mask, ym.mask), (ym == 0))) + + def test_TakeTransposeInnerOuter(self): + # Test of take, transpose, inner, outer products + x = arange(24) + y = np.arange(24) + x[5:6] = masked + x = x.reshape(2, 3, 4) + y = y.reshape(2, 3, 4) + assert_equal(np.transpose(y, (2, 0, 1)), transpose(x, (2, 0, 1))) + assert_equal(np.take(y, (2, 0, 1), 1), take(x, (2, 0, 1), 1)) + assert_equal(np.inner(filled(x, 0), filled(y, 0)), + inner(x, y)) + assert_equal(np.outer(filled(x, 0), filled(y, 0)), + outer(x, y)) + y = array(['abc', 1, 'def', 2, 3], object) + y[2] = masked + t = take(y, [0, 3, 4]) + assert_(t[0] == 'abc') + assert_(t[1] == 2) + assert_(t[2] == 3) + + def test_imag_real(self): + # Check complex + xx = array([1 + 10j, 20 + 2j], mask=[1, 0]) + assert_equal(xx.imag, [10, 2]) + assert_equal(xx.imag.filled(), [1e+20, 2]) + assert_equal(xx.imag.dtype, xx._data.imag.dtype) + assert_equal(xx.real, [1, 20]) + assert_equal(xx.real.filled(), [1e+20, 20]) + assert_equal(xx.real.dtype, xx._data.real.dtype) + + def test_methods_with_output(self): + xm = array(np.random.uniform(0, 10, 12)).reshape(3, 4) + xm[:, 0] = xm[0] = xm[-1, -1] = masked + + funclist = ('sum', 'prod', 'var', 'std', 'max', 'min', 'ptp', 'mean',) + + for funcname in funclist: + npfunc = getattr(np, funcname) + xmmeth = getattr(xm, funcname) + # A ndarray as explicit input + output = np.empty(4, dtype=float) + output.fill(-9999) + result = npfunc(xm, axis=0, out=output) + # ... 
the result should be the given output + assert_(result is output) + assert_equal(result, xmmeth(axis=0, out=output)) + + output = empty(4, dtype=int) + result = xmmeth(axis=0, out=output) + assert_(result is output) + assert_(output[0] is masked) + + def test_eq_on_structured(self): + # Test the equality of structured arrays + ndtype = [('A', int), ('B', int)] + a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + + test = (a == a) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + test = (a == a[0]) + assert_equal(test.data, [True, False]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) + test = (a == b) + assert_equal(test.data, [False, True]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + test = (a[0] == b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) + test = (a == b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + # complicated dtype, 2-dimensional array. + ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] + a = array([[(1, (1, 1)), (2, (2, 2))], + [(3, (3, 3)), (4, (4, 4))]], + mask=[[(0, (1, 0)), (0, (0, 1))], + [(1, (0, 0)), (1, (1, 1))]], dtype=ndtype) + test = (a[0, 0] == a) + assert_equal(test.data, [[True, False], [False, False]]) + assert_equal(test.mask, [[False, False], [False, True]]) + assert_(test.fill_value == True) + + def test_ne_on_structured(self): + # Test the equality of structured arrays + ndtype = [('A', int), ('B', int)] + a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + + test = (a != a) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + test = (a != a[0]) + assert_equal(test.data, [False, True]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) + test = (a != b) + assert_equal(test.data, [True, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + test = (a[0] != b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) + test = (a != b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + + # complicated dtype, 2-dimensional array. + ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] + a = array([[(1, (1, 1)), (2, (2, 2))], + [(3, (3, 3)), (4, (4, 4))]], + mask=[[(0, (1, 0)), (0, (0, 1))], + [(1, (0, 0)), (1, (1, 1))]], dtype=ndtype) + test = (a[0, 0] != a) + assert_equal(test.data, [[False, True], [True, True]]) + assert_equal(test.mask, [[False, False], [False, True]]) + assert_(test.fill_value == True) + + def test_eq_ne_structured_with_non_masked(self): + a = array([(1, 1), (2, 2), (3, 4)], + mask=[(0, 1), (0, 0), (1, 1)], dtype='i4,i4') + eq = a == a.data + ne = a.data != a + # Test the obvious. + assert_(np.all(eq)) + assert_(not np.any(ne)) + # Expect the mask set only for items with all fields masked. 
+ expected_mask = a.mask == np.ones((), a.mask.dtype) + assert_array_equal(eq.mask, expected_mask) + assert_array_equal(ne.mask, expected_mask) + # The masked element will indicated not equal, because the + # masks did not match. + assert_equal(eq.data, [True, True, False]) + assert_array_equal(eq.data, ~ne.data) + + def test_eq_ne_structured_extra(self): + # ensure simple examples are symmetric and make sense. + # from https://github.com/numpy/numpy/pull/8590#discussion_r101126465 + dt = np.dtype('i4,i4') + for m1 in (mvoid((1, 2), mask=(0, 0), dtype=dt), + mvoid((1, 2), mask=(0, 1), dtype=dt), + mvoid((1, 2), mask=(1, 0), dtype=dt), + mvoid((1, 2), mask=(1, 1), dtype=dt)): + ma1 = m1.view(MaskedArray) + r1 = ma1.view('2i4') + for m2 in (np.array((1, 1), dtype=dt), + mvoid((1, 1), dtype=dt), + mvoid((1, 0), mask=(0, 1), dtype=dt), + mvoid((3, 2), mask=(0, 1), dtype=dt)): + ma2 = m2.view(MaskedArray) + r2 = ma2.view('2i4') + eq_expected = (r1 == r2).all() + assert_equal(m1 == m2, eq_expected) + assert_equal(m2 == m1, eq_expected) + assert_equal(ma1 == m2, eq_expected) + assert_equal(m1 == ma2, eq_expected) + assert_equal(ma1 == ma2, eq_expected) + # Also check it is the same if we do it element by element. + el_by_el = [m1[name] == m2[name] for name in dt.names] + assert_equal(array(el_by_el, dtype=bool).all(), eq_expected) + ne_expected = (r1 != r2).any() + assert_equal(m1 != m2, ne_expected) + assert_equal(m2 != m1, ne_expected) + assert_equal(ma1 != m2, ne_expected) + assert_equal(m1 != ma2, ne_expected) + assert_equal(ma1 != ma2, ne_expected) + el_by_el = [m1[name] != m2[name] for name in dt.names] + assert_equal(array(el_by_el, dtype=bool).any(), ne_expected) + + @pytest.mark.parametrize('dt', ['S', 'U']) + @pytest.mark.parametrize('fill', [None, 'A']) + def test_eq_for_strings(self, dt, fill): + # Test the equality of structured arrays + a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill) + + test = (a == a) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a == a[0]) + assert_equal(test.data, [True, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill) + test = (a == b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + test = (a[0] == b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + test = (b == a[0]) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + @pytest.mark.parametrize('dt', ['S', 'U']) + @pytest.mark.parametrize('fill', [None, 'A']) + def test_ne_for_strings(self, dt, fill): + # Test the equality of structured arrays + a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill) + + test = (a != a) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a != a[0]) + assert_equal(test.data, [False, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill) + test = (a != b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + test = (a[0] != b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == 
+
+        test = (b != a[0])
+        assert_equal(test.data, [True, True])
+        assert_equal(test.mask, [True, False])
+        assert_(test.fill_value == True)
+
+    @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('fill', [None, 1])
+    def test_eq_for_numeric(self, dt1, dt2, fill):
+        # Test the equality of numeric arrays
+        a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+        test = (a == a)
+        assert_equal(test.data, [True, True])
+        assert_equal(test.mask, [False, True])
+        assert_(test.fill_value == True)
+
+        test = (a == a[0])
+        assert_equal(test.data, [True, False])
+        assert_equal(test.mask, [False, True])
+        assert_(test.fill_value == True)
+
+        b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill)
+        test = (a == b)
+        assert_equal(test.data, [False, False])
+        assert_equal(test.mask, [True, True])
+        assert_(test.fill_value == True)
+
+        test = (a[0] == b)
+        assert_equal(test.data, [False, False])
+        assert_equal(test.mask, [True, False])
+        assert_(test.fill_value == True)
+
+        test = (b == a[0])
+        assert_equal(test.data, [False, False])
+        assert_equal(test.mask, [True, False])
+        assert_(test.fill_value == True)
+
+    @pytest.mark.parametrize("op", [operator.eq, operator.lt])
+    def test_eq_broadcast_with_unmasked(self, op):
+        a = array([0, 1], mask=[0, 1])
+        b = np.arange(10).reshape(5, 2)
+        result = op(a, b)
+        assert_(result.mask.shape == b.shape)
+        assert_equal(result.mask, np.zeros(b.shape, bool) | a.mask)
+
+    @pytest.mark.parametrize("op", [operator.eq, operator.gt])
+    def test_comp_no_mask_not_broadcast(self, op):
+        # Regression test for failing doctest in MaskedArray.nonzero
+        # after gh-24556.
+        a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        result = op(a, 3)
+        assert_(not result.mask.shape)
+        assert_(result.mask is nomask)
+
+    @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('fill', [None, 1])
+    def test_ne_for_numeric(self, dt1, dt2, fill):
+        # Test the inequality of numeric arrays
+        a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+        test = (a != a)
+        assert_equal(test.data, [False, False])
+        assert_equal(test.mask, [False, True])
+        assert_(test.fill_value == True)
+
+        test = (a != a[0])
+        assert_equal(test.data, [False, True])
+        assert_equal(test.mask, [False, True])
+        assert_(test.fill_value == True)
+
+        b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill)
+        test = (a != b)
+        assert_equal(test.data, [True, True])
+        assert_equal(test.mask, [True, True])
+        assert_(test.fill_value == True)
+
+        test = (a[0] != b)
+        assert_equal(test.data, [True, True])
+        assert_equal(test.mask, [True, False])
+        assert_(test.fill_value == True)
+
+        test = (b != a[0])
+        assert_equal(test.data, [True, True])
+        assert_equal(test.mask, [True, False])
+        assert_(test.fill_value == True)
+
+    @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+    @pytest.mark.parametrize('fill', [None, 1])
+    @pytest.mark.parametrize('op',
+            [operator.le, operator.lt, operator.ge, operator.gt])
+    def test_comparisons_for_numeric(self, op, dt1, dt2, fill):
+        # Test ordering comparisons of numeric arrays
+        a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+        test = op(a, a)
+        assert_equal(test.data, op(a._data, a._data))
+        assert_equal(test.mask, [False, True])
+        assert_(test.fill_value == True)
+
+        test = op(a, a[0])
+        assert_equal(test.data, op(a._data, a._data[0]))
assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill) + test = op(a, b) + assert_equal(test.data, op(a._data, b._data)) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + test = op(a[0], b) + assert_equal(test.data, op(a._data[0], b._data)) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + test = op(b, a[0]) + assert_equal(test.data, op(b._data, a._data[0])) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + @pytest.mark.parametrize('op', + [operator.le, operator.lt, operator.ge, operator.gt]) + @pytest.mark.parametrize('fill', [None, "N/A"]) + def test_comparisons_strings(self, op, fill): + # See gh-21770, mask propagation is broken for strings (and some other + # cases) so we explicitly test strings here. + # In principle only == and != may need special handling... + ma1 = masked_array(["a", "b", "cde"], mask=[0, 1, 0], fill_value=fill) + ma2 = masked_array(["cde", "b", "a"], mask=[0, 1, 0], fill_value=fill) + assert_equal(op(ma1, ma2)._data, op(ma1._data, ma2._data)) + + def test_eq_with_None(self): + # Really, comparisons with None should not be done, but check them + # anyway. Note that pep8 will flag these tests. + # Deprecation is in place for arrays, and when it happens this + # test will fail (and have to be changed accordingly). + + # With partial mask + with suppress_warnings() as sup: + sup.filter(FutureWarning, "Comparison to `None`") + a = array([None, 1], mask=[0, 1]) + assert_equal(a == None, array([True, False], mask=[0, 1])) # noqa: E711 + assert_equal(a.data == None, [True, False]) # noqa: E711 + assert_equal(a != None, array([False, True], mask=[0, 1])) # noqa: E711 + # With nomask + a = array([None, 1], mask=False) + assert_equal(a == None, [True, False]) # noqa: E711 + assert_equal(a != None, [False, True]) # noqa: E711 + # With complete mask + a = array([None, 2], mask=True) + assert_equal(a == None, array([False, True], mask=True)) # noqa: E711 + assert_equal(a != None, array([True, False], mask=True)) # noqa: E711 + # Fully masked, even comparison to None should return "masked" + a = masked + assert_equal(a == None, masked) # noqa: E711 + + def test_eq_with_scalar(self): + a = array(1) + assert_equal(a == 1, True) + assert_equal(a == 0, False) + assert_equal(a != 1, False) + assert_equal(a != 0, True) + b = array(1, mask=True) + assert_equal(b == 0, masked) + assert_equal(b == 1, masked) + assert_equal(b != 0, masked) + assert_equal(b != 1, masked) + + def test_eq_different_dimensions(self): + m1 = array([1, 1], mask=[0, 1]) + # test comparison with both masked and regular arrays. 
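+        # (m1 has shape (2,) and broadcasts against the (2, 2) operands;
+        # its mask broadcasts the same way, which is why the second
+        # column of the result is masked in the checks below.)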
+ for m2 in (array([[0, 1], [1, 2]]), + np.array([[0, 1], [1, 2]])): + test = (m1 == m2) + assert_equal(test.data, [[False, False], + [True, False]]) + assert_equal(test.mask, [[False, True], + [False, True]]) + + def test_numpyarithmetic(self): + # Check that the mask is not back-propagated when using numpy functions + a = masked_array([-1, 0, 1, 2, 3], mask=[0, 0, 0, 0, 1]) + control = masked_array([np.nan, np.nan, 0, np.log(2), -1], + mask=[1, 1, 0, 0, 1]) + + test = log(a) + assert_equal(test, control) + assert_equal(test.mask, control.mask) + assert_equal(a.mask, [0, 0, 0, 0, 1]) + + test = np.log(a) + assert_equal(test, control) + assert_equal(test.mask, control.mask) + assert_equal(a.mask, [0, 0, 0, 0, 1]) + + +class TestMaskedArrayAttributes: + + def test_keepmask(self): + # Tests the keep mask flag + x = masked_array([1, 2, 3], mask=[1, 0, 0]) + mx = masked_array(x) + assert_equal(mx.mask, x.mask) + mx = masked_array(x, mask=[0, 1, 0], keep_mask=False) + assert_equal(mx.mask, [0, 1, 0]) + mx = masked_array(x, mask=[0, 1, 0], keep_mask=True) + assert_equal(mx.mask, [1, 1, 0]) + # We default to true + mx = masked_array(x, mask=[0, 1, 0]) + assert_equal(mx.mask, [1, 1, 0]) + + def test_hardmask(self): + # Test hard_mask + d = arange(5) + n = [0, 0, 0, 1, 1] + m = make_mask(n) + xh = array(d, mask=m, hard_mask=True) + # We need to copy, to avoid updating d in xh ! + xs = array(d, mask=m, hard_mask=False, copy=True) + xh[[1, 4]] = [10, 40] + xs[[1, 4]] = [10, 40] + assert_equal(xh._data, [0, 10, 2, 3, 4]) + assert_equal(xs._data, [0, 10, 2, 3, 40]) + assert_equal(xs.mask, [0, 0, 0, 1, 0]) + assert_(xh._hardmask) + assert_(not xs._hardmask) + xh[1:4] = [10, 20, 30] + xs[1:4] = [10, 20, 30] + assert_equal(xh._data, [0, 10, 20, 3, 4]) + assert_equal(xs._data, [0, 10, 20, 30, 40]) + assert_equal(xs.mask, nomask) + xh[0] = masked + xs[0] = masked + assert_equal(xh.mask, [1, 0, 0, 1, 1]) + assert_equal(xs.mask, [1, 0, 0, 0, 0]) + xh[:] = 1 + xs[:] = 1 + assert_equal(xh._data, [0, 1, 1, 3, 4]) + assert_equal(xs._data, [1, 1, 1, 1, 1]) + assert_equal(xh.mask, [1, 0, 0, 1, 1]) + assert_equal(xs.mask, nomask) + # Switch to soft mask + xh.soften_mask() + xh[:] = arange(5) + assert_equal(xh._data, [0, 1, 2, 3, 4]) + assert_equal(xh.mask, nomask) + # Switch back to hard mask + xh.harden_mask() + xh[xh < 3] = masked + assert_equal(xh._data, [0, 1, 2, 3, 4]) + assert_equal(xh._mask, [1, 1, 1, 0, 0]) + xh[filled(xh > 1, False)] = 5 + assert_equal(xh._data, [0, 1, 2, 5, 5]) + assert_equal(xh._mask, [1, 1, 1, 0, 0]) + + xh = array([[1, 2], [3, 4]], mask=[[1, 0], [0, 0]], hard_mask=True) + xh[0] = 0 + assert_equal(xh._data, [[1, 0], [3, 4]]) + assert_equal(xh._mask, [[1, 0], [0, 0]]) + xh[-1, -1] = 5 + assert_equal(xh._data, [[1, 0], [3, 5]]) + assert_equal(xh._mask, [[1, 0], [0, 0]]) + xh[filled(xh < 5, False)] = 2 + assert_equal(xh._data, [[1, 2], [2, 5]]) + assert_equal(xh._mask, [[1, 0], [0, 0]]) + + def test_hardmask_again(self): + # Another test of hardmask + d = arange(5) + n = [0, 0, 0, 1, 1] + m = make_mask(n) + xh = array(d, mask=m, hard_mask=True) + xh[4:5] = 999 + xh[0:1] = 999 + assert_equal(xh._data, [999, 1, 2, 3, 4]) + + def test_hardmask_oncemore_yay(self): + # OK, yet another test of hardmask + # Make sure that harden_mask/soften_mask//unshare_mask returns self + a = array([1, 2, 3], mask=[1, 0, 0]) + b = a.harden_mask() + assert_equal(a, b) + b[0] = 0 + assert_equal(a, b) + assert_equal(b, array([1, 2, 3], mask=[1, 0, 0])) + a = b.soften_mask() + a[0] = 0 + assert_equal(a, b) + 
assert_equal(b, array([0, 2, 3], mask=[0, 0, 0])) + + def test_smallmask(self): + # Checks the behaviour of _smallmask + a = arange(10) + a[1] = masked + a[1] = 1 + assert_equal(a._mask, nomask) + a = arange(10) + a._smallmask = False + a[1] = masked + a[1] = 1 + assert_equal(a._mask, zeros(10)) + + def test_shrink_mask(self): + # Tests .shrink_mask() + a = array([1, 2, 3], mask=[0, 0, 0]) + b = a.shrink_mask() + assert_equal(a, b) + assert_equal(a.mask, nomask) + + # Mask cannot be shrunk on structured types, so is a no-op + a = np.ma.array([(1, 2.0)], [('a', int), ('b', float)]) + b = a.copy() + a.shrink_mask() + assert_equal(a.mask, b.mask) + + def test_flat(self): + # Test that flat can return all types of items [#4585, #4615] + # test 2-D record array + # ... on structured array w/ masked records + x = array([[(1, 1.1, 'one'), (2, 2.2, 'two'), (3, 3.3, 'thr')], + [(4, 4.4, 'fou'), (5, 5.5, 'fiv'), (6, 6.6, 'six')]], + dtype=[('a', int), ('b', float), ('c', '|S8')]) + x['a'][0, 1] = masked + x['b'][1, 0] = masked + x['c'][0, 2] = masked + x[-1, -1] = masked + xflat = x.flat + assert_equal(xflat[0], x[0, 0]) + assert_equal(xflat[1], x[0, 1]) + assert_equal(xflat[2], x[0, 2]) + assert_equal(xflat[:3], x[0]) + assert_equal(xflat[3], x[1, 0]) + assert_equal(xflat[4], x[1, 1]) + assert_equal(xflat[5], x[1, 2]) + assert_equal(xflat[3:], x[1]) + assert_equal(xflat[-1], x[-1, -1]) + i = 0 + j = 0 + for xf in xflat: + assert_equal(xf, x[j, i]) + i += 1 + if i >= x.shape[-1]: + i = 0 + j += 1 + + def test_assign_dtype(self): + # check that the mask's dtype is updated when dtype is changed + a = np.zeros(4, dtype='f4,i4') + + m = np.ma.array(a) + m.dtype = np.dtype('f4') + repr(m) # raises? + assert_equal(m.dtype, np.dtype('f4')) + + # check that dtype changes that change shape of mask too much + # are not allowed + def assign(): + m = np.ma.array(a) + m.dtype = np.dtype('f8') + assert_raises(ValueError, assign) + + b = a.view(dtype='f4', type=np.ma.MaskedArray) # raises? 
+ assert_equal(b.dtype, np.dtype('f4')) + + # check that nomask is preserved + a = np.zeros(4, dtype='f4') + m = np.ma.array(a) + m.dtype = np.dtype('f4,i4') + assert_equal(m.dtype, np.dtype('f4,i4')) + assert_equal(m._mask, np.ma.nomask) + + +class TestFillingValues: + + def test_check_on_scalar(self): + # Test _check_fill_value set to valid and invalid values + _check_fill_value = np.ma.core._check_fill_value + + fval = _check_fill_value(0, int) + assert_equal(fval, 0) + fval = _check_fill_value(None, int) + assert_equal(fval, default_fill_value(0)) + + fval = _check_fill_value(0, "|S3") + assert_equal(fval, b"0") + fval = _check_fill_value(None, "|S3") + assert_equal(fval, default_fill_value(b"camelot!")) + assert_raises(TypeError, _check_fill_value, 1e+20, int) + assert_raises(TypeError, _check_fill_value, 'stuff', int) + + def test_check_on_fields(self): + # Tests _check_fill_value with records + _check_fill_value = np.ma.core._check_fill_value + ndtype = [('a', int), ('b', float), ('c', "|S3")] + # A check on a list should return a single record + fval = _check_fill_value([-999, -12345678.9, "???"], ndtype) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), [-999, -12345678.9, b"???"]) + # A check on None should output the defaults + fval = _check_fill_value(None, ndtype) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), [default_fill_value(0), + default_fill_value(0.), + asbytes(default_fill_value("0"))]) + #.....Using a structured type as fill_value should work + fill_val = np.array((-999, -12345678.9, "???"), dtype=ndtype) + fval = _check_fill_value(fill_val, ndtype) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), [-999, -12345678.9, b"???"]) + + #.....Using a flexible type w/ a different type shouldn't matter + # BEHAVIOR in 1.5 and earlier, and 1.13 and later: match structured + # types by position + fill_val = np.array((-999, -12345678.9, "???"), + dtype=[("A", int), ("B", float), ("C", "|S3")]) + fval = _check_fill_value(fill_val, ndtype) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), [-999, -12345678.9, b"???"]) + + #.....Using an object-array shouldn't matter either + fill_val = np.ndarray(shape=(1,), dtype=object) + fill_val[0] = (-999, -12345678.9, b"???") + fval = _check_fill_value(fill_val, object) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), [-999, -12345678.9, b"???"]) + # NOTE: This test was never run properly as "fill_value" rather than + # "fill_val" was assigned. Written properly, it fails. 
+ #fill_val = np.array((-999, -12345678.9, "???")) + #fval = _check_fill_value(fill_val, ndtype) + #assert_(isinstance(fval, ndarray)) + #assert_equal(fval.item(), [-999, -12345678.9, b"???"]) + #.....One-field-only flexible type should work as well + ndtype = [("a", int)] + fval = _check_fill_value(-999999999, ndtype) + assert_(isinstance(fval, ndarray)) + assert_equal(fval.item(), (-999999999,)) + + def test_fillvalue_conversion(self): + # Tests the behavior of fill_value during conversion + # We had a tailored comment to make sure special attributes are + # properly dealt with + a = array([b'3', b'4', b'5']) + a._optinfo.update({'comment': "updated!"}) + + b = array(a, dtype=int) + assert_equal(b._data, [3, 4, 5]) + assert_equal(b.fill_value, default_fill_value(0)) + + b = array(a, dtype=float) + assert_equal(b._data, [3, 4, 5]) + assert_equal(b.fill_value, default_fill_value(0.)) + + b = a.astype(int) + assert_equal(b._data, [3, 4, 5]) + assert_equal(b.fill_value, default_fill_value(0)) + assert_equal(b._optinfo['comment'], "updated!") + + b = a.astype([('a', '|S3')]) + assert_equal(b['a']._data, a._data) + assert_equal(b['a'].fill_value, a.fill_value) + + def test_default_fill_value(self): + # check all calling conventions + f1 = default_fill_value(1.) + f2 = default_fill_value(np.array(1.)) + f3 = default_fill_value(np.array(1.).dtype) + assert_equal(f1, f2) + assert_equal(f1, f3) + + def test_default_fill_value_structured(self): + fields = array([(1, 1, 1)], + dtype=[('i', int), ('s', '|S8'), ('f', float)]) + + f1 = default_fill_value(fields) + f2 = default_fill_value(fields.dtype) + expected = np.array((default_fill_value(0), + default_fill_value('0'), + default_fill_value(0.)), dtype=fields.dtype) + assert_equal(f1, expected) + assert_equal(f2, expected) + + def test_default_fill_value_void(self): + dt = np.dtype([('v', 'V7')]) + f = default_fill_value(dt) + assert_equal(f['v'], np.array(default_fill_value(dt['v']), dt['v'])) + + def test_fillvalue(self): + # Yet more fun with the fill_value + data = masked_array([1, 2, 3], fill_value=-999) + series = data[[0, 2, 1]] + assert_equal(series._fill_value, data._fill_value) + + mtype = [('f', float), ('s', '|S3')] + x = array([(1, 'a'), (2, 'b'), (pi, 'pi')], dtype=mtype) + x.fill_value = 999 + assert_equal(x.fill_value.item(), [999., b'999']) + assert_equal(x['f'].fill_value, 999) + assert_equal(x['s'].fill_value, b'999') + + x.fill_value = (9, '???') + assert_equal(x.fill_value.item(), (9, b'???')) + assert_equal(x['f'].fill_value, 9) + assert_equal(x['s'].fill_value, b'???') + + x = array([1, 2, 3.1]) + x.fill_value = 999 + assert_equal(np.asarray(x.fill_value).dtype, float) + assert_equal(x.fill_value, 999.) 
+        assert_equal(x._fill_value, np.array(999.))
+
+    def test_subarray_fillvalue(self):
+        # gh-10483 test multi-field index fill value
+        fields = array([(1, 1, 1)],
+                       dtype=[('i', int), ('s', '|S8'), ('f', float)])
+        with suppress_warnings() as sup:
+            sup.filter(FutureWarning, "Numpy has detected")
+            subfields = fields[['i', 'f']]
+            assert_equal(tuple(subfields.fill_value), (999999, 1.e+20))
+            # test comparison does not raise:
+            subfields[1:] == subfields[:-1]
+
+    def test_fillvalue_exotic_dtype(self):
+        # Tests yet more exotic flexible dtypes
+        _check_fill_value = np.ma.core._check_fill_value
+        ndtype = [('i', int), ('s', '|S8'), ('f', float)]
+        control = np.array((default_fill_value(0),
+                            default_fill_value('0'),
+                            default_fill_value(0.),),
+                           dtype=ndtype)
+        assert_equal(_check_fill_value(None, ndtype), control)
+        # The shape shouldn't matter
+        ndtype = [('f0', float, (2, 2))]
+        control = np.array((default_fill_value(0.),),
+                           dtype=[('f0', float)]).astype(ndtype)
+        assert_equal(_check_fill_value(None, ndtype), control)
+        control = np.array((0,), dtype=[('f0', float)]).astype(ndtype)
+        assert_equal(_check_fill_value(0, ndtype), control)
+
+        ndtype = np.dtype("int, (2,3)float, float")
+        control = np.array((default_fill_value(0),
+                            default_fill_value(0.),
+                            default_fill_value(0.),),
+                           dtype="int, float, float").astype(ndtype)
+        test = _check_fill_value(None, ndtype)
+        assert_equal(test, control)
+        control = np.array((0, 0, 0), dtype="int, float, float").astype(ndtype)
+        assert_equal(_check_fill_value(0, ndtype), control)
+        # but when indexing, fill value should become scalar not tuple
+        # See issue #6723
+        M = masked_array(control)
+        assert_equal(M["f1"].fill_value.ndim, 0)
+
+    def test_fillvalue_datetime_timedelta(self):
+        # Test default fillvalue for datetime64 and timedelta64 types.
+        # See issue #4476, this would return '?' which would cause errors
+        # elsewhere
+
+        for timecode in ("as", "fs", "ps", "ns", "us", "ms", "s", "m",
+                         "h", "D", "W", "M", "Y"):
+            control = numpy.datetime64("NaT", timecode)
+            test = default_fill_value(numpy.dtype("<M8[" + timecode + "]"))
+            np.testing.assert_equal(test, control)
+
+            control = numpy.timedelta64("NaT", timecode)
+            test = default_fill_value(numpy.dtype("<m8[" + timecode + "]"))
+            np.testing.assert_equal(test, control)
+
+
+class TestUfuncs:
+
+    def test_no_masked_nan_warnings(self):
+        # check that a nan in a masked position does not
+        # cause ufunc warnings
+
+        m = np.ma.array([0.5, np.nan], mask=[0, 1])
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings("error")
+
+            # test unary and binary ufuncs
+            exp(m)
+            add(m, 1)
+            m > 0
+
+            # test different unary domains
+            sqrt(m)
+            log(m)
+            tan(m)
+            arcsin(m)
+            arccos(m)
+            arccosh(m)
+
+            # test binary domains
+            divide(m, 2)
+
+            # also check that allclose uses ma ufuncs, to avoid warning
+            allclose(m, 0.5)
+
+    def test_masked_array_underflow(self):
+        x = np.arange(0, 3, 0.1)
+        X = np.ma.array(x)
+        with np.errstate(under="raise"):
+            X2 = X / 2.0
+            np.testing.assert_array_equal(X2, x / 2)
+
+
+class TestMaskedArrayInPlaceArithmetic:
+    # Test MaskedArray Arithmetic
+
+    def setup_method(self):
+        x = arange(10)
+        y = arange(10)
+        xm = arange(10)
+        xm[2] = masked
+        self.intdata = (x, y, xm)
+        self.floatdata = (x.astype(float), y.astype(float), xm.astype(float))
+        self.othertypes = np.typecodes['AllInteger'] + np.typecodes['AllFloat']
+        self.othertypes = [np.dtype(_).type for _ in self.othertypes]
+        self.uint8data = (
+            x.astype(np.uint8),
+            y.astype(np.uint8),
+            xm.astype(np.uint8)
+        )
+
+    def test_inplace_addition_scalar(self):
+        # Test of inplace additions
+        (x, y, xm) = self.intdata
+        xm[2] = masked
+        x += 1
+        assert_equal(x, y + 1)
+        xm += 1
+        assert_equal(xm, y + 1)
+
+        (x, _, xm) = self.floatdata
+        id1 = x.data.ctypes.data
+        x += 1.
+        assert_(id1 == x.data.ctypes.data)
+        assert_equal(x, y + 1.)
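+    # Illustrative sketch (not part of the tests): in-place arithmetic
+    # reuses the same data buffer (hence the ctypes.data check above)
+    # and leaves masked slots untouched, e.g.
+    #   >>> xm = np.ma.array([1., 2., 3.], mask=[0, 0, 1])
+    #   >>> xm += 1.
+    #   >>> print(xm)
+    #   [2.0 3.0 --]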
+ + def test_inplace_addition_array(self): + # Test of inplace additions + (x, y, xm) = self.intdata + m = xm.mask + a = arange(10, dtype=np.int16) + a[-1] = masked + x += a + xm += a + assert_equal(x, y + a) + assert_equal(xm, y + a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_subtraction_scalar(self): + # Test of inplace subtractions + (x, y, xm) = self.intdata + x -= 1 + assert_equal(x, y - 1) + xm -= 1 + assert_equal(xm, y - 1) + + def test_inplace_subtraction_array(self): + # Test of inplace subtractions + (x, y, xm) = self.floatdata + m = xm.mask + a = arange(10, dtype=float) + a[-1] = masked + x -= a + xm -= a + assert_equal(x, y - a) + assert_equal(xm, y - a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_multiplication_scalar(self): + # Test of inplace multiplication + (x, y, xm) = self.floatdata + x *= 2.0 + assert_equal(x, y * 2) + xm *= 2.0 + assert_equal(xm, y * 2) + + def test_inplace_multiplication_array(self): + # Test of inplace multiplication + (x, y, xm) = self.floatdata + m = xm.mask + a = arange(10, dtype=float) + a[-1] = masked + x *= a + xm *= a + assert_equal(x, y * a) + assert_equal(xm, y * a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_division_scalar_int(self): + # Test of inplace division + (x, y, xm) = self.intdata + x = arange(10) * 2 + xm = arange(10) * 2 + xm[2] = masked + x //= 2 + assert_equal(x, y) + xm //= 2 + assert_equal(xm, y) + + def test_inplace_division_scalar_float(self): + # Test of inplace division + (x, y, xm) = self.floatdata + x /= 2.0 + assert_equal(x, y / 2.0) + xm /= arange(10) + assert_equal(xm, ones((10,))) + + def test_inplace_division_array_float(self): + # Test of inplace division + (x, y, xm) = self.floatdata + m = xm.mask + a = arange(10, dtype=float) + a[-1] = masked + x /= a + xm /= a + assert_equal(x, y / a) + assert_equal(xm, y / a) + assert_equal(xm.mask, mask_or(mask_or(m, a.mask), (a == 0))) + + def test_inplace_division_misc(self): + + x = [1., 1., 1., -2., pi / 2., 4., 5., -10., 10., 1., 2., 3.] + y = [5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.] 
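+        # (Besides the union of m1 and m2 below, true division also masks
+        # the entries where the denominator is zero, indices 1 and 10 of
+        # y, which is why z gains two extra masked slots.)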
+ m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] + m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1] + xm = masked_array(x, mask=m1) + ym = masked_array(y, mask=m2) + + z = xm / ym + assert_equal(z._mask, [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1]) + assert_equal(z._data, + [1., 1., 1., -1., -pi / 2., 4., 5., 1., 1., 1., 2., 3.]) + + xm = xm.copy() + xm /= ym + assert_equal(xm._mask, [1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1]) + assert_equal(z._data, + [1., 1., 1., -1., -pi / 2., 4., 5., 1., 1., 1., 2., 3.]) + + def test_datafriendly_add(self): + # Test keeping data w/ (inplace) addition + x = array([1, 2, 3], mask=[0, 0, 1]) + # Test add w/ scalar + xx = x + 1 + assert_equal(xx.data, [2, 3, 3]) + assert_equal(xx.mask, [0, 0, 1]) + # Test iadd w/ scalar + x += 1 + assert_equal(x.data, [2, 3, 3]) + assert_equal(x.mask, [0, 0, 1]) + # Test add w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x + array([1, 2, 3], mask=[1, 0, 0]) + assert_equal(xx.data, [1, 4, 3]) + assert_equal(xx.mask, [1, 0, 1]) + # Test iadd w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + x += array([1, 2, 3], mask=[1, 0, 0]) + assert_equal(x.data, [1, 4, 3]) + assert_equal(x.mask, [1, 0, 1]) + + def test_datafriendly_sub(self): + # Test keeping data w/ (inplace) subtraction + # Test sub w/ scalar + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x - 1 + assert_equal(xx.data, [0, 1, 3]) + assert_equal(xx.mask, [0, 0, 1]) + # Test isub w/ scalar + x = array([1, 2, 3], mask=[0, 0, 1]) + x -= 1 + assert_equal(x.data, [0, 1, 3]) + assert_equal(x.mask, [0, 0, 1]) + # Test sub w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x - array([1, 2, 3], mask=[1, 0, 0]) + assert_equal(xx.data, [1, 0, 3]) + assert_equal(xx.mask, [1, 0, 1]) + # Test isub w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + x -= array([1, 2, 3], mask=[1, 0, 0]) + assert_equal(x.data, [1, 0, 3]) + assert_equal(x.mask, [1, 0, 1]) + + def test_datafriendly_mul(self): + # Test keeping data w/ (inplace) multiplication + # Test mul w/ scalar + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x * 2 + assert_equal(xx.data, [2, 4, 3]) + assert_equal(xx.mask, [0, 0, 1]) + # Test imul w/ scalar + x = array([1, 2, 3], mask=[0, 0, 1]) + x *= 2 + assert_equal(x.data, [2, 4, 3]) + assert_equal(x.mask, [0, 0, 1]) + # Test mul w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x * array([10, 20, 30], mask=[1, 0, 0]) + assert_equal(xx.data, [1, 40, 3]) + assert_equal(xx.mask, [1, 0, 1]) + # Test imul w/ array + x = array([1, 2, 3], mask=[0, 0, 1]) + x *= array([10, 20, 30], mask=[1, 0, 0]) + assert_equal(x.data, [1, 40, 3]) + assert_equal(x.mask, [1, 0, 1]) + + def test_datafriendly_div(self): + # Test keeping data w/ (inplace) division + # Test div on scalar + x = array([1, 2, 3], mask=[0, 0, 1]) + xx = x / 2. + assert_equal(xx.data, [1 / 2., 2 / 2., 3]) + assert_equal(xx.mask, [0, 0, 1]) + # Test idiv on scalar + x = array([1., 2., 3.], mask=[0, 0, 1]) + x /= 2. + assert_equal(x.data, [1 / 2., 2 / 2., 3]) + assert_equal(x.mask, [0, 0, 1]) + # Test div on array + x = array([1., 2., 3.], mask=[0, 0, 1]) + xx = x / array([10., 20., 30.], mask=[1, 0, 0]) + assert_equal(xx.data, [1., 2. 
/ 20., 3.]) + assert_equal(xx.mask, [1, 0, 1]) + # Test idiv on array + x = array([1., 2., 3.], mask=[0, 0, 1]) + x /= array([10., 20., 30.], mask=[1, 0, 0]) + assert_equal(x.data, [1., 2 / 20., 3.]) + assert_equal(x.mask, [1, 0, 1]) + + def test_datafriendly_pow(self): + # Test keeping data w/ (inplace) power + # Test pow on scalar + x = array([1., 2., 3.], mask=[0, 0, 1]) + xx = x ** 2.5 + assert_equal(xx.data, [1., 2. ** 2.5, 3.]) + assert_equal(xx.mask, [0, 0, 1]) + # Test ipow on scalar + x **= 2.5 + assert_equal(x.data, [1., 2. ** 2.5, 3]) + assert_equal(x.mask, [0, 0, 1]) + + def test_datafriendly_add_arrays(self): + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 0]) + a += b + assert_equal(a, [[2, 2], [4, 4]]) + if a.mask is not nomask: + assert_equal(a.mask, [[0, 0], [0, 0]]) + + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 1]) + a += b + assert_equal(a, [[2, 2], [4, 4]]) + assert_equal(a.mask, [[0, 1], [0, 1]]) + + def test_datafriendly_sub_arrays(self): + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 0]) + a -= b + assert_equal(a, [[0, 0], [2, 2]]) + if a.mask is not nomask: + assert_equal(a.mask, [[0, 0], [0, 0]]) + + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 1]) + a -= b + assert_equal(a, [[0, 0], [2, 2]]) + assert_equal(a.mask, [[0, 1], [0, 1]]) + + def test_datafriendly_mul_arrays(self): + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 0]) + a *= b + assert_equal(a, [[1, 1], [3, 3]]) + if a.mask is not nomask: + assert_equal(a.mask, [[0, 0], [0, 0]]) + + a = array([[1, 1], [3, 3]]) + b = array([1, 1], mask=[0, 1]) + a *= b + assert_equal(a, [[1, 1], [3, 3]]) + assert_equal(a.mask, [[0, 1], [0, 1]]) + + def test_inplace_addition_scalar_type(self): + # Test of inplace additions + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + xm[2] = masked + x += t(1) + assert_equal(x, y + t(1)) + xm += t(1) + assert_equal(xm, y + t(1)) + + def test_inplace_addition_array_type(self): + # Test of inplace additions + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + m = xm.mask + a = arange(10, dtype=t) + a[-1] = masked + x += a + xm += a + assert_equal(x, y + a) + assert_equal(xm, y + a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_subtraction_scalar_type(self): + # Test of inplace subtractions + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + x -= t(1) + assert_equal(x, y - t(1)) + xm -= t(1) + assert_equal(xm, y - t(1)) + + def test_inplace_subtraction_array_type(self): + # Test of inplace subtractions + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + m = xm.mask + a = arange(10, dtype=t) + a[-1] = masked + x -= a + xm -= a + assert_equal(x, y - a) + assert_equal(xm, y - a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_multiplication_scalar_type(self): + # Test of inplace multiplication + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + x *= t(2) + assert_equal(x, y * t(2)) + xm *= t(2) + assert_equal(xm, y * t(2)) + + def test_inplace_multiplication_array_type(self): + # Test of inplace 
multiplication + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + m = xm.mask + a = arange(10, dtype=t) + a[-1] = masked + x *= a + xm *= a + assert_equal(x, y * a) + assert_equal(xm, y * a) + assert_equal(xm.mask, mask_or(m, a.mask)) + + def test_inplace_floor_division_scalar_type(self): + # Test of inplace division + # Check for TypeError in case of unsupported types + unsupported = {np.dtype(t).type for t in np.typecodes["Complex"]} + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + x = arange(10, dtype=t) * t(2) + xm = arange(10, dtype=t) * t(2) + xm[2] = masked + try: + x //= t(2) + xm //= t(2) + assert_equal(x, y) + assert_equal(xm, y) + except TypeError: + msg = f"Supported type {t} throwing TypeError" + assert t in unsupported, msg + + def test_inplace_floor_division_array_type(self): + # Test of inplace division + # Check for TypeError in case of unsupported types + unsupported = {np.dtype(t).type for t in np.typecodes["Complex"]} + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + m = xm.mask + a = arange(10, dtype=t) + a[-1] = masked + try: + x //= a + xm //= a + assert_equal(x, y // a) + assert_equal(xm, y // a) + assert_equal( + xm.mask, + mask_or(mask_or(m, a.mask), (a == t(0))) + ) + except TypeError: + msg = f"Supported type {t} throwing TypeError" + assert t in unsupported, msg + + def test_inplace_division_scalar_type(self): + # Test of inplace division + for t in self.othertypes: + with suppress_warnings() as sup: + sup.record(UserWarning) + + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + x = arange(10, dtype=t) * t(2) + xm = arange(10, dtype=t) * t(2) + xm[2] = masked + + # May get a DeprecationWarning or a TypeError. + # + # This is a consequence of the fact that this is true divide + # and will require casting to float for calculation and + # casting back to the original type. This will only be raised + # with integers. Whether it is an error or warning is only + # dependent on how stringent the casting rules are. + # + # Will handle the same way. + try: + x /= t(2) + assert_equal(x, y) + except (DeprecationWarning, TypeError) as e: + warnings.warn(str(e), stacklevel=1) + try: + xm /= t(2) + assert_equal(xm, y) + except (DeprecationWarning, TypeError) as e: + warnings.warn(str(e), stacklevel=1) + + if issubclass(t, np.integer): + assert_equal(len(sup.log), 2, f'Failed on type={t}.') + else: + assert_equal(len(sup.log), 0, f'Failed on type={t}.') + + def test_inplace_division_array_type(self): + # Test of inplace division + for t in self.othertypes: + with suppress_warnings() as sup: + sup.record(UserWarning) + (x, y, xm) = (_.astype(t) for _ in self.uint8data) + m = xm.mask + a = arange(10, dtype=t) + a[-1] = masked + + # May get a DeprecationWarning or a TypeError. + # + # This is a consequence of the fact that this is true divide + # and will require casting to float for calculation and + # casting back to the original type. This will only be raised + # with integers. Whether it is an error or warning is only + # dependent on how stringent the casting rules are. + # + # Will handle the same way. 
+ try: + x /= a + assert_equal(x, y / a) + except (DeprecationWarning, TypeError) as e: + warnings.warn(str(e), stacklevel=1) + try: + xm /= a + assert_equal(xm, y / a) + assert_equal( + xm.mask, + mask_or(mask_or(m, a.mask), (a == t(0))) + ) + except (DeprecationWarning, TypeError) as e: + warnings.warn(str(e), stacklevel=1) + + if issubclass(t, np.integer): + assert_equal(len(sup.log), 2, f'Failed on type={t}.') + else: + assert_equal(len(sup.log), 0, f'Failed on type={t}.') + + def test_inplace_pow_type(self): + # Test keeping data w/ (inplace) power + for t in self.othertypes: + with warnings.catch_warnings(): + warnings.filterwarnings("error") + # Test pow on scalar + x = array([1, 2, 3], mask=[0, 0, 1], dtype=t) + xx = x ** t(2) + xx_r = array([1, 2 ** 2, 3], mask=[0, 0, 1], dtype=t) + assert_equal(xx.data, xx_r.data) + assert_equal(xx.mask, xx_r.mask) + # Test ipow on scalar + x **= t(2) + assert_equal(x.data, xx_r.data) + assert_equal(x.mask, xx_r.mask) + + +class TestMaskedArrayMethods: + # Test class for miscellaneous MaskedArrays methods. + def setup_method(self): + # Base data definition. + x = np.array([8.375, 7.545, 8.828, 8.5, 1.757, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479, + 7.189, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993]) + X = x.reshape(6, 6) + XX = x.reshape(3, 2, 2, 3) + + m = np.array([0, 1, 0, 1, 0, 0, + 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0]) + mx = array(data=x, mask=m) + mX = array(data=X, mask=m.reshape(X.shape)) + mXX = array(data=XX, mask=m.reshape(XX.shape)) + + m2 = np.array([1, 1, 0, 1, 0, 0, + 1, 1, 1, 1, 0, 1, + 0, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 1, 0, + 0, 0, 1, 0, 1, 1]) + m2x = array(data=x, mask=m2) + m2X = array(data=X, mask=m2.reshape(X.shape)) + m2XX = array(data=XX, mask=m2.reshape(XX.shape)) + self.d = (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) + + def test_generic_methods(self): + # Tests some MaskedArray methods. 
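+        # For fully unmasked inputs these simply agree with the plain
+        # ndarray results, e.g. (illustrative):
+        #   >>> a = np.ma.array([1, 3, 2])
+        #   >>> bool(a.argmax() == a._data.argmax())
+        #   True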
+ a = array([1, 3, 2]) + assert_equal(a.any(), a._data.any()) + assert_equal(a.all(), a._data.all()) + assert_equal(a.argmax(), a._data.argmax()) + assert_equal(a.argmin(), a._data.argmin()) + assert_equal(a.choose(0, 1, 2, 3, 4), a._data.choose(0, 1, 2, 3, 4)) + assert_equal(a.compress([1, 0, 1]), a._data.compress([1, 0, 1])) + assert_equal(a.conj(), a._data.conj()) + assert_equal(a.conjugate(), a._data.conjugate()) + + m = array([[1, 2], [3, 4]]) + assert_equal(m.diagonal(), m._data.diagonal()) + assert_equal(a.sum(), a._data.sum()) + assert_equal(a.take([1, 2]), a._data.take([1, 2])) + assert_equal(m.transpose(), m._data.transpose()) + + def test_allclose(self): + # Tests allclose on arrays + a = np.random.rand(10) + b = a + np.random.rand(10) * 1e-8 + assert_(allclose(a, b)) + # Test allclose w/ infs + a[0] = np.inf + assert_(not allclose(a, b)) + b[0] = np.inf + assert_(allclose(a, b)) + # Test allclose w/ masked + a = masked_array(a) + a[-1] = masked + assert_(allclose(a, b, masked_equal=True)) + assert_(not allclose(a, b, masked_equal=False)) + # Test comparison w/ scalar + a *= 1e-8 + a[0] = 0 + assert_(allclose(a, 0, masked_equal=True)) + + # Test that the function works for MIN_INT integer typed arrays + a = masked_array([np.iinfo(np.int_).min], dtype=np.int_) + assert_(allclose(a, a)) + + def test_allclose_timedelta(self): + # Allclose currently works for timedelta64 as long as `atol` is + # an integer or also a timedelta64 + a = np.array([[1, 2, 3, 4]], dtype="m8[ns]") + assert allclose(a, a, atol=0) + assert allclose(a, a, atol=np.timedelta64(1, "ns")) + + def test_allany(self): + # Checks the any/all methods/functions. + x = np.array([[0.13, 0.26, 0.90], + [0.28, 0.33, 0.63], + [0.31, 0.87, 0.70]]) + m = np.array([[True, False, False], + [False, False, False], + [True, True, False]], dtype=np.bool) + mx = masked_array(x, mask=m) + mxbig = (mx > 0.5) + mxsmall = (mx < 0.5) + + assert_(not mxbig.all()) + assert_(mxbig.any()) + assert_equal(mxbig.all(0), [False, False, True]) + assert_equal(mxbig.all(1), [False, False, True]) + assert_equal(mxbig.any(0), [False, False, True]) + assert_equal(mxbig.any(1), [True, True, True]) + + assert_(not mxsmall.all()) + assert_(mxsmall.any()) + assert_equal(mxsmall.all(0), [True, True, False]) + assert_equal(mxsmall.all(1), [False, False, False]) + assert_equal(mxsmall.any(0), [True, True, False]) + assert_equal(mxsmall.any(1), [True, True, False]) + + def test_allany_oddities(self): + # Some fun with all and any + store = empty((), dtype=bool) + full = array([1, 2, 3], mask=True) + + assert_(full.all() is masked) + full.all(out=store) + assert_(store) + assert_(store._mask, True) + assert_(store is not masked) + + store = empty((), dtype=bool) + assert_(full.any() is masked) + full.any(out=store) + assert_(not store) + assert_(store._mask, True) + assert_(store is not masked) + + def test_argmax_argmin(self): + # Tests argmin & argmax on MaskedArrays. 
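+        # Masked entries never win: they are filled with an extreme fill
+        # value before the search, e.g. (illustrative):
+        #   >>> int(np.ma.array([3, 1, 2], mask=[0, 1, 0]).argmin())
+        #   2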
+ (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + + assert_equal(mx.argmin(), 35) + assert_equal(mX.argmin(), 35) + assert_equal(m2x.argmin(), 4) + assert_equal(m2X.argmin(), 4) + assert_equal(mx.argmax(), 28) + assert_equal(mX.argmax(), 28) + assert_equal(m2x.argmax(), 31) + assert_equal(m2X.argmax(), 31) + + assert_equal(mX.argmin(0), [2, 2, 2, 5, 0, 5]) + assert_equal(m2X.argmin(0), [2, 2, 4, 5, 0, 4]) + assert_equal(mX.argmax(0), [0, 5, 0, 5, 4, 0]) + assert_equal(m2X.argmax(0), [5, 5, 0, 5, 1, 0]) + + assert_equal(mX.argmin(1), [4, 1, 0, 0, 5, 5, ]) + assert_equal(m2X.argmin(1), [4, 4, 0, 0, 5, 3]) + assert_equal(mX.argmax(1), [2, 4, 1, 1, 4, 1]) + assert_equal(m2X.argmax(1), [2, 4, 1, 1, 1, 1]) + + def test_clip(self): + # Tests clip on MaskedArrays. + x = np.array([8.375, 7.545, 8.828, 8.5, 1.757, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479, + 7.189, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993]) + m = np.array([0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0]) + mx = array(x, mask=m) + clipped = mx.clip(2, 8) + assert_equal(clipped.mask, mx.mask) + assert_equal(clipped._data, x.clip(2, 8)) + assert_equal(clipped._data, mx._data.clip(2, 8)) + + def test_clip_out(self): + # gh-14140 + a = np.arange(10) + m = np.ma.MaskedArray(a, mask=[0, 1] * 5) + m.clip(0, 5, out=m) + assert_equal(m.mask, [0, 1] * 5) + + def test_compress(self): + # test compress + a = masked_array([1., 2., 3., 4., 5.], fill_value=9999) + condition = (a > 1.5) & (a < 3.5) + assert_equal(a.compress(condition), [2., 3.]) + + a[[2, 3]] = masked + b = a.compress(condition) + assert_equal(b._data, [2., 3.]) + assert_equal(b._mask, [0, 1]) + assert_equal(b.fill_value, 9999) + assert_equal(b, a[condition]) + + condition = (a < 4.) 
+ b = a.compress(condition) + assert_equal(b._data, [1., 2., 3.]) + assert_equal(b._mask, [0, 0, 1]) + assert_equal(b.fill_value, 9999) + assert_equal(b, a[condition]) + + a = masked_array([[10, 20, 30], [40, 50, 60]], + mask=[[0, 0, 1], [1, 0, 0]]) + b = a.compress(a.ravel() >= 22) + assert_equal(b._data, [30, 40, 50, 60]) + assert_equal(b._mask, [1, 1, 0, 0]) + + x = np.array([3, 1, 2]) + b = a.compress(x >= 2, axis=1) + assert_equal(b._data, [[10, 30], [40, 60]]) + assert_equal(b._mask, [[0, 1], [1, 0]]) + + def test_compressed(self): + # Tests compressed + a = array([1, 2, 3, 4], mask=[0, 0, 0, 0]) + b = a.compressed() + assert_equal(b, a) + a[0] = masked + b = a.compressed() + assert_equal(b, [2, 3, 4]) + + def test_empty(self): + # Tests empty/like + datatype = [('a', int), ('b', float), ('c', '|S8')] + a = masked_array([(1, 1.1, '1.1'), (2, 2.2, '2.2'), (3, 3.3, '3.3')], + dtype=datatype) + assert_equal(len(a.fill_value.item()), len(datatype)) + + b = empty_like(a) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + b = empty(len(a), dtype=datatype) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + # check empty_like mask handling + a = masked_array([1, 2, 3], mask=[False, True, False]) + b = empty_like(a) + assert_(not np.may_share_memory(a.mask, b.mask)) + b = a.view(masked_array) + assert_(np.may_share_memory(a.mask, b.mask)) + + def test_zeros(self): + # Tests zeros/like + datatype = [('a', int), ('b', float), ('c', '|S8')] + a = masked_array([(1, 1.1, '1.1'), (2, 2.2, '2.2'), (3, 3.3, '3.3')], + dtype=datatype) + assert_equal(len(a.fill_value.item()), len(datatype)) + + b = zeros(len(a), dtype=datatype) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + b = zeros_like(a) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + # check zeros_like mask handling + a = masked_array([1, 2, 3], mask=[False, True, False]) + b = zeros_like(a) + assert_(not np.may_share_memory(a.mask, b.mask)) + b = a.view() + assert_(np.may_share_memory(a.mask, b.mask)) + + def test_ones(self): + # Tests ones/like + datatype = [('a', int), ('b', float), ('c', '|S8')] + a = masked_array([(1, 1.1, '1.1'), (2, 2.2, '2.2'), (3, 3.3, '3.3')], + dtype=datatype) + assert_equal(len(a.fill_value.item()), len(datatype)) + + b = ones(len(a), dtype=datatype) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + b = ones_like(a) + assert_equal(b.shape, a.shape) + assert_equal(b.fill_value, a.fill_value) + + # check ones_like mask handling + a = masked_array([1, 2, 3], mask=[False, True, False]) + b = ones_like(a) + assert_(not np.may_share_memory(a.mask, b.mask)) + b = a.view() + assert_(np.may_share_memory(a.mask, b.mask)) + + @suppress_copy_mask_on_assignment + def test_put(self): + # Tests put. 
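+        # put() writes at flat indices and updates the mask alongside the
+        # data: an unmasked value clears the mask at its slot, a masked
+        # value sets it (illustrative):
+        #   >>> x = np.ma.array([1, 2, 3], mask=[1, 0, 0])
+        #   >>> x.put([0], [9])
+        #   >>> x[0] is np.ma.masked
+        #   False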
+ d = arange(5) + n = [0, 0, 0, 1, 1] + m = make_mask(n) + x = array(d, mask=m) + assert_(x[3] is masked) + assert_(x[4] is masked) + x[[1, 4]] = [10, 40] + assert_(x[3] is masked) + assert_(x[4] is not masked) + assert_equal(x, [0, 10, 2, -1, 40]) + + x = masked_array(arange(10), mask=[1, 0, 0, 0, 0] * 2) + i = [0, 2, 4, 6] + x.put(i, [6, 4, 2, 0]) + assert_equal(x, asarray([6, 1, 4, 3, 2, 5, 0, 7, 8, 9, ])) + assert_equal(x.mask, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]) + x.put(i, masked_array([0, 2, 4, 6], [1, 0, 1, 0])) + assert_array_equal(x, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]) + assert_equal(x.mask, [1, 0, 0, 0, 1, 1, 0, 0, 0, 0]) + + x = masked_array(arange(10), mask=[1, 0, 0, 0, 0] * 2) + put(x, i, [6, 4, 2, 0]) + assert_equal(x, asarray([6, 1, 4, 3, 2, 5, 0, 7, 8, 9, ])) + assert_equal(x.mask, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]) + put(x, i, masked_array([0, 2, 4, 6], [1, 0, 1, 0])) + assert_array_equal(x, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]) + assert_equal(x.mask, [1, 0, 0, 0, 1, 1, 0, 0, 0, 0]) + + def test_put_nomask(self): + # GitHub issue 6425 + x = zeros(10) + z = array([3., -1.], mask=[False, True]) + + x.put([1, 2], z) + assert_(x[0] is not masked) + assert_equal(x[0], 0) + assert_(x[1] is not masked) + assert_equal(x[1], 3) + assert_(x[2] is masked) + assert_(x[3] is not masked) + assert_equal(x[3], 0) + + def test_put_hardmask(self): + # Tests put on hardmask + d = arange(5) + n = [0, 0, 0, 1, 1] + m = make_mask(n) + xh = array(d + 1, mask=m, hard_mask=True, copy=True) + xh.put([4, 2, 0, 1, 3], [1, 2, 3, 4, 5]) + assert_equal(xh._data, [3, 4, 2, 4, 5]) + + def test_putmask(self): + x = arange(6) + 1 + mx = array(x, mask=[0, 0, 0, 1, 1, 1]) + mask = [0, 0, 1, 0, 0, 1] + # w/o mask, w/o masked values + xx = x.copy() + putmask(xx, mask, 99) + assert_equal(xx, [1, 2, 99, 4, 5, 99]) + # w/ mask, w/o masked values + mxx = mx.copy() + putmask(mxx, mask, 99) + assert_equal(mxx._data, [1, 2, 99, 4, 5, 99]) + assert_equal(mxx._mask, [0, 0, 0, 1, 1, 0]) + # w/o mask, w/ masked values + values = array([10, 20, 30, 40, 50, 60], mask=[1, 1, 1, 0, 0, 0]) + xx = x.copy() + putmask(xx, mask, values) + assert_equal(xx._data, [1, 2, 30, 4, 5, 60]) + assert_equal(xx._mask, [0, 0, 1, 0, 0, 0]) + # w/ mask, w/ masked values + mxx = mx.copy() + putmask(mxx, mask, values) + assert_equal(mxx._data, [1, 2, 30, 4, 5, 60]) + assert_equal(mxx._mask, [0, 0, 1, 1, 1, 0]) + # w/ mask, w/ masked values + hardmask + mxx = mx.copy() + mxx.harden_mask() + putmask(mxx, mask, values) + assert_equal(mxx, [1, 2, 30, 4, 5, 60]) + + def test_ravel(self): + # Tests ravel + a = array([[1, 2, 3, 4, 5]], mask=[[0, 1, 0, 0, 0]]) + aravel = a.ravel() + assert_equal(aravel._mask.shape, aravel.shape) + a = array([0, 0], mask=[1, 1]) + aravel = a.ravel() + assert_equal(aravel._mask.shape, a.shape) + # Checks that small_mask is preserved + a = array([1, 2, 3, 4], mask=[0, 0, 0, 0], shrink=False) + assert_equal(a.ravel()._mask, [0, 0, 0, 0]) + # Test that the fill_value is preserved + a.fill_value = -99 + a.shape = (2, 2) + ar = a.ravel() + assert_equal(ar._mask, [0, 0, 0, 0]) + assert_equal(ar._data, [1, 2, 3, 4]) + assert_equal(ar.fill_value, -99) + # Test index ordering + assert_equal(a.ravel(order='C'), [1, 2, 3, 4]) + assert_equal(a.ravel(order='F'), [1, 3, 2, 4]) + + @pytest.mark.parametrize("order", "AKCF") + @pytest.mark.parametrize("data_order", "CF") + def test_ravel_order(self, order, data_order): + # Ravelling must ravel mask and data in the same order always to avoid + # misaligning the two in the ravel result. 
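+        # (The data and the mask below are constructed with opposite memory
+        # orders, so raveling each in its own memory order would misalign
+        # them; the mask must follow the data's logical order instead.)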
+        arr = np.ones((5, 10), order=data_order)
+        arr[0, :] = 0
+        mask = np.ones((10, 5), dtype=bool, order=data_order).T
+        mask[0, :] = False
+        x = array(arr, mask=mask)
+        assert x._data.flags.fnc != x._mask.flags.fnc
+        assert (x.filled(0) == 0).all()
+        raveled = x.ravel(order)
+        assert (raveled.filled(0) == 0).all()
+
+        # NOTE: Can be wrong if arr order is neither C nor F and `order="K"`
+        assert_array_equal(arr.ravel(order), x.ravel(order)._data)
+
+    def test_reshape(self):
+        # Tests reshape
+        x = arange(4)
+        x[0] = masked
+        y = x.reshape(2, 2)
+        assert_equal(y.shape, (2, 2,))
+        assert_equal(y._mask.shape, (2, 2,))
+        assert_equal(x.shape, (4,))
+        assert_equal(x._mask.shape, (4,))
+
+    def test_sort(self):
+        # Test sort
+        x = array([1, 4, 2, 3], mask=[0, 1, 0, 0], dtype=np.uint8)
+
+        sortedx = sort(x)
+        assert_equal(sortedx._data, [1, 2, 3, 4])
+        assert_equal(sortedx._mask, [0, 0, 0, 1])
+
+        sortedx = sort(x, endwith=False)
+        assert_equal(sortedx._data, [4, 1, 2, 3])
+        assert_equal(sortedx._mask, [1, 0, 0, 0])
+
+        x.sort()
+        assert_equal(x._data, [1, 2, 3, 4])
+        assert_equal(x._mask, [0, 0, 0, 1])
+
+        x = array([1, 4, 2, 3], mask=[0, 1, 0, 0], dtype=np.uint8)
+        x.sort(endwith=False)
+        assert_equal(x._data, [4, 1, 2, 3])
+        assert_equal(x._mask, [1, 0, 0, 0])
+
+        x = [1, 4, 2, 3]
+        sortedx = sort(x)
+        assert_(not isinstance(sortedx, MaskedArray))
+
+        x = array([0, 1, -1, -2, 2], mask=nomask, dtype=np.int8)
+        sortedx = sort(x, endwith=False)
+        assert_equal(sortedx._data, [-2, -1, 0, 1, 2])
+        x = array([0, 1, -1, -2, 2], mask=[0, 1, 0, 0, 1], dtype=np.int8)
+        sortedx = sort(x, endwith=False)
+        assert_equal(sortedx._data, [1, 2, -2, -1, 0])
+        assert_equal(sortedx._mask, [1, 1, 0, 0, 0])
+
+        x = array([0, -1], dtype=np.int8)
+        sortedx = sort(x, kind="stable")
+        assert_equal(sortedx, array([-1, 0], dtype=np.int8))
+
+    def test_stable_sort(self):
+        x = array([1, 2, 3, 1, 2, 3], dtype=np.uint8)
+        expected = array([0, 3, 1, 4, 2, 5])
+        computed = argsort(x, kind='stable')
+        assert_equal(computed, expected)
+
+    def test_argsort_matches_sort(self):
+        x = array([1, 4, 2, 3], mask=[0, 1, 0, 0], dtype=np.uint8)
+
+        for kwargs in [{},
+                       {"endwith": True},
+                       {"endwith": False},
+                       {"fill_value": 2},
+                       {"fill_value": 2, "endwith": True},
+                       {"fill_value": 2, "endwith": False}]:
+            sortedx = sort(x, **kwargs)
+            argsortedx = x[argsort(x, **kwargs)]
+            assert_equal(sortedx._data, argsortedx._data)
+            assert_equal(sortedx._mask, argsortedx._mask)
+
+    def test_sort_2d(self):
+        # Check sort of 2D array.
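+        # Masked entries sort to the end of each axis by default
+        # (endwith=True), e.g. (illustrative):
+        #   >>> m = np.ma.array([3, 1, 2], mask=[1, 0, 0])
+        #   >>> m.sort()
+        #   >>> print(m)
+        #   [1 2 --]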
+ # 2D array w/o mask + a = masked_array([[8, 4, 1], [2, 0, 9]]) + a.sort(0) + assert_equal(a, [[2, 0, 1], [8, 4, 9]]) + a = masked_array([[8, 4, 1], [2, 0, 9]]) + a.sort(1) + assert_equal(a, [[1, 4, 8], [0, 2, 9]]) + # 2D array w/mask + a = masked_array([[8, 4, 1], [2, 0, 9]], mask=[[1, 0, 0], [0, 0, 1]]) + a.sort(0) + assert_equal(a, [[2, 0, 1], [8, 4, 9]]) + assert_equal(a._mask, [[0, 0, 0], [1, 0, 1]]) + a = masked_array([[8, 4, 1], [2, 0, 9]], mask=[[1, 0, 0], [0, 0, 1]]) + a.sort(1) + assert_equal(a, [[1, 4, 8], [0, 2, 9]]) + assert_equal(a._mask, [[0, 0, 1], [0, 0, 1]]) + # 3D + a = masked_array([[[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[1, 2, 3], [7, 8, 9], [4, 5, 6]], + [[7, 8, 9], [1, 2, 3], [4, 5, 6]], + [[4, 5, 6], [1, 2, 3], [7, 8, 9]]]) + a[a % 4 == 0] = masked + am = a.copy() + an = a.filled(99) + am.sort(0) + an.sort(0) + assert_equal(am, an) + am = a.copy() + an = a.filled(99) + am.sort(1) + an.sort(1) + assert_equal(am, an) + am = a.copy() + an = a.filled(99) + am.sort(2) + an.sort(2) + assert_equal(am, an) + + def test_sort_flexible(self): + # Test sort on structured dtype. + a = array( + data=[(3, 3), (3, 2), (2, 2), (2, 1), (1, 0), (1, 1), (1, 2)], + mask=[(0, 0), (0, 1), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0)], + dtype=[('A', int), ('B', int)]) + mask_last = array( + data=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3), (3, 2), (1, 0)], + mask=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (1, 0)], + dtype=[('A', int), ('B', int)]) + mask_first = array( + data=[(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3)], + mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0)], + dtype=[('A', int), ('B', int)]) + + test = sort(a) + assert_equal(test, mask_last) + assert_equal(test.mask, mask_last.mask) + + test = sort(a, endwith=False) + assert_equal(test, mask_first) + assert_equal(test.mask, mask_first.mask) + + # Test sort on dtype with subarray (gh-8069) + # Just check that the sort does not error, structured array subarrays + # are treated as byte strings and that leads to differing behavior + # depending on endianness and `endwith`. + dt = np.dtype([('v', int, 2)]) + a = a.view(dt) + test = sort(a) + test = sort(a, endwith=False) + + def test_argsort(self): + # Test argsort + a = array([1, 5, 2, 4, 3], mask=[1, 0, 0, 1, 0]) + assert_equal(np.argsort(a), argsort(a)) + + def test_squeeze(self): + # Check squeeze + data = masked_array([[1, 2, 3]]) + assert_equal(data.squeeze(), [1, 2, 3]) + data = masked_array([[1, 2, 3]], mask=[[1, 1, 1]]) + assert_equal(data.squeeze(), [1, 2, 3]) + assert_equal(data.squeeze()._mask, [1, 1, 1]) + + # normal ndarrays return a view + arr = np.array([[1]]) + arr_sq = arr.squeeze() + assert_equal(arr_sq, 1) + arr_sq[...] = 2 + assert_equal(arr[0, 0], 2) + + # so maskedarrays should too + m_arr = masked_array([[1]], mask=True) + m_arr_sq = m_arr.squeeze() + assert_(m_arr_sq is not np.ma.masked) + assert_equal(m_arr_sq.mask, True) + m_arr_sq[...] = 2 + assert_equal(m_arr[0, 0], 2) + + def test_swapaxes(self): + # Tests swapaxes on MaskedArrays. 
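+        # swapaxes acts on data and mask in lockstep, so masked positions
+        # travel with their elements (illustrative):
+        #   >>> m = np.ma.array([[1, 2]], mask=[[0, 1]])
+        #   >>> m.swapaxes(0, 1).mask.tolist()
+        #   [[False], [True]]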
+ x = np.array([8.375, 7.545, 8.828, 8.5, 1.757, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479, + 7.189, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993]) + m = np.array([0, 1, 0, 1, 0, 0, + 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0]) + mX = array(x, mask=m).reshape(6, 6) + mXX = mX.reshape(3, 2, 2, 3) + + mXswapped = mX.swapaxes(0, 1) + assert_equal(mXswapped[-1], mX[:, -1]) + + mXXswapped = mXX.swapaxes(0, 2) + assert_equal(mXXswapped.shape, (2, 2, 3, 3)) + + def test_take(self): + # Tests take + x = masked_array([10, 20, 30, 40], [0, 1, 0, 1]) + assert_equal(x.take([0, 0, 3]), masked_array([10, 10, 40], [0, 0, 1])) + assert_equal(x.take([0, 0, 3]), x[[0, 0, 3]]) + assert_equal(x.take([[0, 1], [0, 1]]), + masked_array([[10, 20], [10, 20]], [[0, 1], [0, 1]])) + + # assert_equal crashes when passed np.ma.mask + assert_(x[1] is np.ma.masked) + assert_(x.take(1) is np.ma.masked) + + x = array([[10, 20, 30], [40, 50, 60]], mask=[[0, 0, 1], [1, 0, 0, ]]) + assert_equal(x.take([0, 2], axis=1), + array([[10, 30], [40, 60]], mask=[[0, 1], [1, 0]])) + assert_equal(take(x, [0, 2], axis=1), + array([[10, 30], [40, 60]], mask=[[0, 1], [1, 0]])) + + def test_take_masked_indices(self): + # Test take w/ masked indices + a = np.array((40, 18, 37, 9, 22)) + indices = np.arange(3)[None, :] + np.arange(5)[:, None] + mindices = array(indices, mask=(indices >= len(a))) + # No mask + test = take(a, mindices, mode='clip') + ctrl = array([[40, 18, 37], + [18, 37, 9], + [37, 9, 22], + [9, 22, 22], + [22, 22, 22]]) + assert_equal(test, ctrl) + # Masked indices + test = take(a, mindices) + ctrl = array([[40, 18, 37], + [18, 37, 9], + [37, 9, 22], + [9, 22, 40], + [22, 40, 40]]) + ctrl[3, 2] = ctrl[4, 1] = ctrl[4, 2] = masked + assert_equal(test, ctrl) + assert_equal(test.mask, ctrl.mask) + # Masked input + masked indices + a = array((40, 18, 37, 9, 22), mask=(0, 1, 0, 0, 0)) + test = take(a, mindices) + ctrl[0, 1] = ctrl[1, 0] = masked + assert_equal(test, ctrl) + assert_equal(test.mask, ctrl.mask) + + def test_tolist(self): + # Tests to list + # ... on 1D + x = array(np.arange(12)) + x[[1, -2]] = masked + xlist = x.tolist() + assert_(xlist[1] is None) + assert_(xlist[-2] is None) + # ... on 2D + x.shape = (3, 4) + xlist = x.tolist() + ctrl = [[0, None, 2, 3], [4, 5, 6, 7], [8, 9, None, 11]] + assert_equal(xlist[0], [0, None, 2, 3]) + assert_equal(xlist[1], [4, 5, 6, 7]) + assert_equal(xlist[2], [8, 9, None, 11]) + assert_equal(xlist, ctrl) + # ... on structured array w/ masked records + x = array(list(zip([1, 2, 3], + [1.1, 2.2, 3.3], + ['one', 'two', 'thr'])), + dtype=[('a', int), ('b', float), ('c', '|S8')]) + x[-1] = masked + assert_equal(x.tolist(), + [(1, 1.1, b'one'), + (2, 2.2, b'two'), + (None, None, None)]) + # ... on structured array w/ masked fields + a = array([(1, 2,), (3, 4)], mask=[(0, 1), (0, 0)], + dtype=[('a', int), ('b', int)]) + test = a.tolist() + assert_equal(test, [[1, None], [3, 4]]) + # ... 
on mvoid + a = a[0] + test = a.tolist() + assert_equal(test, [1, None]) + + def test_tolist_specialcase(self): + # Test mvoid.tolist: make sure we return a standard Python object + a = array([(0, 1), (2, 3)], dtype=[('a', int), ('b', int)]) + # w/o mask: each entry is a np.void whose elements are standard Python + for entry in a: + for item in entry.tolist(): + assert_(not isinstance(item, np.generic)) + # w/ mask: each entry is a ma.void whose elements should be + # standard Python + a.mask[0] = (0, 1) + for entry in a: + for item in entry.tolist(): + assert_(not isinstance(item, np.generic)) + + def test_toflex(self): + # Test the conversion to records + data = arange(10) + record = data.toflex() + assert_equal(record['_data'], data._data) + assert_equal(record['_mask'], data._mask) + + data[[0, 1, 2, -1]] = masked + record = data.toflex() + assert_equal(record['_data'], data._data) + assert_equal(record['_mask'], data._mask) + + ndtype = [('i', int), ('s', '|S3'), ('f', float)] + data = array(list(zip(np.arange(10), + 'ABCDEFGHIJKLM', + np.random.rand(10))), + dtype=ndtype) + data[[0, 1, 2, -1]] = masked + record = data.toflex() + assert_equal(record['_data'], data._data) + assert_equal(record['_mask'], data._mask) + + ndtype = np.dtype("int, (2,3)float, float") + data = array(list(zip(np.arange(10), + np.random.rand(10), + np.random.rand(10))), + dtype=ndtype) + data[[0, 1, 2, -1]] = masked + record = data.toflex() + assert_equal_records(record['_data'], data._data) + assert_equal_records(record['_mask'], data._mask) + + def test_fromflex(self): + # Test the reconstruction of a masked_array from a record + a = array([1, 2, 3]) + test = fromflex(a.toflex()) + assert_equal(test, a) + assert_equal(test.mask, a.mask) + + a = array([1, 2, 3], mask=[0, 0, 1]) + test = fromflex(a.toflex()) + assert_equal(test, a) + assert_equal(test.mask, a.mask) + + a = array([(1, 1.), (2, 2.), (3, 3.)], mask=[(1, 0), (0, 0), (0, 1)], + dtype=[('A', int), ('B', float)]) + test = fromflex(a.toflex()) + assert_equal(test, a) + assert_equal(test.data, a.data) + + def test_arraymethod(self): + # Test a _arraymethod w/ n argument + marray = masked_array([[1, 2, 3, 4, 5]], mask=[0, 0, 1, 0, 0]) + control = masked_array([[1], [2], [3], [4], [5]], + mask=[0, 0, 1, 0, 0]) + assert_equal(marray.T, control) + assert_equal(marray.transpose(), control) + + assert_equal(MaskedArray.cumsum(marray.T, 0), control.cumsum(0)) + + def test_arraymethod_0d(self): + # gh-9430 + x = np.ma.array(42, mask=True) + assert_equal(x.T.mask, x.mask) + assert_equal(x.T.data, x.data) + + def test_transpose_view(self): + x = np.ma.array([[1, 2, 3], [4, 5, 6]]) + x[0, 1] = np.ma.masked + xt = x.T + + xt[1, 0] = 10 + xt[0, 1] = np.ma.masked + + assert_equal(x.data, xt.T.data) + assert_equal(x.mask, xt.T.mask) + + def test_diagonal_view(self): + x = np.ma.zeros((3, 3)) + x[0, 0] = 10 + x[1, 1] = np.ma.masked + x[2, 2] = 20 + xd = x.diagonal() + x[1, 1] = 15 + assert_equal(xd.mask, x.diagonal().mask) + assert_equal(xd.data, x.diagonal().data) + + +class TestMaskedArrayMathMethods: + + def setup_method(self): + # Base data definition. 
+ x = np.array([8.375, 7.545, 8.828, 8.5, 1.757, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479, + 7.189, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993]) + X = x.reshape(6, 6) + XX = x.reshape(3, 2, 2, 3) + + m = np.array([0, 1, 0, 1, 0, 0, + 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0]) + mx = array(data=x, mask=m) + mX = array(data=X, mask=m.reshape(X.shape)) + mXX = array(data=XX, mask=m.reshape(XX.shape)) + + m2 = np.array([1, 1, 0, 1, 0, 0, + 1, 1, 1, 1, 0, 1, + 0, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 1, 0, + 0, 0, 1, 0, 1, 1]) + m2x = array(data=x, mask=m2) + m2X = array(data=X, mask=m2.reshape(X.shape)) + m2XX = array(data=XX, mask=m2.reshape(XX.shape)) + self.d = (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) + + def test_cumsumprod(self): + # Tests cumsum & cumprod on MaskedArrays. + (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + mXcp = mX.cumsum(0) + assert_equal(mXcp._data, mX.filled(0).cumsum(0)) + mXcp = mX.cumsum(1) + assert_equal(mXcp._data, mX.filled(0).cumsum(1)) + + mXcp = mX.cumprod(0) + assert_equal(mXcp._data, mX.filled(1).cumprod(0)) + mXcp = mX.cumprod(1) + assert_equal(mXcp._data, mX.filled(1).cumprod(1)) + + def test_cumsumprod_with_output(self): + # Tests cumsum/cumprod w/ output + xm = array(np.random.uniform(0, 10, 12)).reshape(3, 4) + xm[:, 0] = xm[0] = xm[-1, -1] = masked + + for funcname in ('cumsum', 'cumprod'): + npfunc = getattr(np, funcname) + xmmeth = getattr(xm, funcname) + + # A ndarray as explicit input + output = np.empty((3, 4), dtype=float) + output.fill(-9999) + result = npfunc(xm, axis=0, out=output) + # ... the result should be the given output + assert_(result is output) + assert_equal(result, xmmeth(axis=0, out=output)) + + output = empty((3, 4), dtype=int) + result = xmmeth(axis=0, out=output) + assert_(result is output) + + def test_ptp(self): + # Tests ptp on MaskedArrays. 
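+        # ptp ("peak to peak") is max - min computed over unmasked entries
+        # only, so each expected value below is np.ptp of a row/column with
+        # its masked elements dropped via .compressed().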
+ (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + (n, m) = X.shape + assert_equal(mx.ptp(), np.ptp(mx.compressed())) + rows = np.zeros(n, float) + cols = np.zeros(m, float) + for k in range(m): + cols[k] = np.ptp(mX[:, k].compressed()) + for k in range(n): + rows[k] = np.ptp(mX[k].compressed()) + assert_equal(mX.ptp(0), cols) + assert_equal(mX.ptp(1), rows) + + def test_add_object(self): + x = masked_array(['a', 'b'], mask=[1, 0], dtype=object) + y = x + 'x' + assert_equal(y[1], 'bx') + assert_(y.mask[0]) + + def test_sum_object(self): + # Test sum on object dtype + a = masked_array([1, 2, 3], mask=[1, 0, 0], dtype=object) + assert_equal(a.sum(), 5) + a = masked_array([[1, 2, 3], [4, 5, 6]], dtype=object) + assert_equal(a.sum(axis=0), [5, 7, 9]) + + def test_prod_object(self): + # Test prod on object dtype + a = masked_array([1, 2, 3], mask=[1, 0, 0], dtype=object) + assert_equal(a.prod(), 2 * 3) + a = masked_array([[1, 2, 3], [4, 5, 6]], dtype=object) + assert_equal(a.prod(axis=0), [4, 10, 18]) + + def test_meananom_object(self): + # Test mean/anom on object dtype + a = masked_array([1, 2, 3], dtype=object) + assert_equal(a.mean(), 2) + assert_equal(a.anom(), [-1, 0, 1]) + + def test_anom_shape(self): + a = masked_array([1, 2, 3]) + assert_equal(a.anom().shape, a.shape) + a.mask = True + assert_equal(a.anom().shape, a.shape) + assert_(np.ma.is_masked(a.anom())) + + def test_anom(self): + a = masked_array(np.arange(1, 7).reshape(2, 3)) + assert_almost_equal(a.anom(), + [[-2.5, -1.5, -0.5], [0.5, 1.5, 2.5]]) + assert_almost_equal(a.anom(axis=0), + [[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) + assert_almost_equal(a.anom(axis=1), + [[-1., 0., 1.], [-1., 0., 1.]]) + a.mask = [[0, 0, 1], [0, 1, 0]] + mval = -99 + assert_almost_equal(a.anom().filled(mval), + [[-2.25, -1.25, mval], [0.75, mval, 2.75]]) + assert_almost_equal(a.anom(axis=0).filled(mval), + [[-1.5, 0.0, mval], [1.5, mval, 0.0]]) + assert_almost_equal(a.anom(axis=1).filled(mval), + [[-0.5, 0.5, mval], [-1.0, mval, 1.0]]) + + def test_trace(self): + # Tests trace on MaskedArrays. + (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + mXdiag = mX.diagonal() + assert_equal(mX.trace(), mX.diagonal().compressed().sum()) + assert_almost_equal(mX.trace(), + X.trace() - sum(mXdiag.mask * X.diagonal(), + axis=0)) + assert_equal(np.trace(mX), mX.trace()) + + # gh-5560 + arr = np.arange(2 * 4 * 4).reshape(2, 4, 4) + m_arr = np.ma.masked_array(arr, False) + assert_equal(arr.trace(axis1=1, axis2=2), m_arr.trace(axis1=1, axis2=2)) + + def test_dot(self): + # Tests dot on MaskedArrays. 
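+        # With the default strict=False, masked entries contribute as 0 to
+        # the data (hence the comparisons against .filled(0)), and an output
+        # cell comes out masked only when every term feeding it involves a
+        # masked value -- which is why r.mask is nomask for the 1-D case but
+        # r.mask[1, 3] holds below.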
+ (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + fx = mx.filled(0) + r = mx.dot(mx) + assert_almost_equal(r.filled(0), fx.dot(fx)) + assert_(r.mask is nomask) + + fX = mX.filled(0) + r = mX.dot(mX) + assert_almost_equal(r.filled(0), fX.dot(fX)) + assert_(r.mask[1, 3]) + r1 = empty_like(r) + mX.dot(mX, out=r1) + assert_almost_equal(r, r1) + + mYY = mXX.swapaxes(-1, -2) + fXX, fYY = mXX.filled(0), mYY.filled(0) + r = mXX.dot(mYY) + assert_almost_equal(r.filled(0), fXX.dot(fYY)) + r1 = empty_like(r) + mXX.dot(mYY, out=r1) + assert_almost_equal(r, r1) + + def test_dot_shape_mismatch(self): + # regression test + x = masked_array([[1, 2], [3, 4]], mask=[[0, 1], [0, 0]]) + y = masked_array([[1, 2], [3, 4]], mask=[[0, 1], [0, 0]]) + z = masked_array([[0, 1], [3, 3]]) + x.dot(y, out=z) + assert_almost_equal(z.filled(0), [[1, 0], [15, 16]]) + assert_almost_equal(z.mask, [[0, 1], [0, 0]]) + + def test_varmean_nomask(self): + # gh-5769 + foo = array([1, 2, 3, 4], dtype='f8') + bar = array([1, 2, 3, 4], dtype='f8') + assert_equal(type(foo.mean()), np.float64) + assert_equal(type(foo.var()), np.float64) + assert (foo.mean() == bar.mean()) is np.bool(True) + + # check array type is preserved and out works + foo = array(np.arange(16).reshape((4, 4)), dtype='f8') + bar = empty(4, dtype='f4') + assert_equal(type(foo.mean(axis=1)), MaskedArray) + assert_equal(type(foo.var(axis=1)), MaskedArray) + assert_(foo.mean(axis=1, out=bar) is bar) + assert_(foo.var(axis=1, out=bar) is bar) + + def test_varstd(self): + # Tests var & std on MaskedArrays. + (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + assert_almost_equal(mX.var(axis=None), mX.compressed().var()) + assert_almost_equal(mX.std(axis=None), mX.compressed().std()) + assert_almost_equal(mX.std(axis=None, ddof=1), + mX.compressed().std(ddof=1)) + assert_almost_equal(mX.var(axis=None, ddof=1), + mX.compressed().var(ddof=1)) + assert_equal(mXX.var(axis=3).shape, XX.var(axis=3).shape) + assert_equal(mX.var().shape, X.var().shape) + (mXvar0, mXvar1) = (mX.var(axis=0), mX.var(axis=1)) + assert_almost_equal(mX.var(axis=None, ddof=2), + mX.compressed().var(ddof=2)) + assert_almost_equal(mX.std(axis=None, ddof=2), + mX.compressed().std(ddof=2)) + for k in range(6): + assert_almost_equal(mXvar1[k], mX[k].compressed().var()) + assert_almost_equal(mXvar0[k], mX[:, k].compressed().var()) + assert_almost_equal(np.sqrt(mXvar0[k]), + mX[:, k].compressed().std()) + + @suppress_copy_mask_on_assignment + def test_varstd_specialcases(self): + # Test a special case for var + nout = np.array(-1, dtype=float) + mout = array(-1, dtype=float) + + x = array(arange(10), mask=True) + for methodname in ('var', 'std'): + method = getattr(x, methodname) + assert_(method() is masked) + assert_(method(0) is masked) + assert_(method(-1) is masked) + # Using a masked array as explicit output + method(out=mout) + assert_(mout is not masked) + assert_equal(mout.mask, True) + # Using a ndarray as explicit output + method(out=nout) + assert_(np.isnan(nout)) + + x = array(arange(10), mask=True) + x[-1] = 9 + for methodname in ('var', 'std'): + method = getattr(x, methodname) + assert_(method(ddof=1) is masked) + assert_(method(0, ddof=1) is masked) + assert_(method(-1, ddof=1) is masked) + # Using a masked array as explicit output + method(out=mout, ddof=1) + assert_(mout is not masked) + assert_equal(mout.mask, True) + # Using a ndarray as explicit output + method(out=nout, ddof=1) + assert_(np.isnan(nout)) + + def test_varstd_ddof(self): + a = array([[1, 1, 0], [1, 1, 0]], mask=[[0, 
0, 1], [0, 0, 1]]) + test = a.std(axis=0, ddof=0) + assert_equal(test.filled(0), [0, 0, 0]) + assert_equal(test.mask, [0, 0, 1]) + test = a.std(axis=0, ddof=1) + assert_equal(test.filled(0), [0, 0, 0]) + assert_equal(test.mask, [0, 0, 1]) + test = a.std(axis=0, ddof=2) + assert_equal(test.filled(0), [0, 0, 0]) + assert_equal(test.mask, [1, 1, 1]) + + def test_diag(self): + # Test diag + x = arange(9).reshape((3, 3)) + x[1, 1] = masked + out = np.diag(x) + assert_equal(out, [0, 4, 8]) + out = diag(x) + assert_equal(out, [0, 4, 8]) + assert_equal(out.mask, [0, 1, 0]) + out = diag(out) + control = array([[0, 0, 0], [0, 4, 0], [0, 0, 8]], + mask=[[0, 0, 0], [0, 1, 0], [0, 0, 0]]) + assert_equal(out, control) + + def test_axis_methods_nomask(self): + # Test the combination nomask & methods w/ axis + a = array([[1, 2, 3], [4, 5, 6]]) + + assert_equal(a.sum(0), [5, 7, 9]) + assert_equal(a.sum(-1), [6, 15]) + assert_equal(a.sum(1), [6, 15]) + + assert_equal(a.prod(0), [4, 10, 18]) + assert_equal(a.prod(-1), [6, 120]) + assert_equal(a.prod(1), [6, 120]) + + assert_equal(a.min(0), [1, 2, 3]) + assert_equal(a.min(-1), [1, 4]) + assert_equal(a.min(1), [1, 4]) + + assert_equal(a.max(0), [4, 5, 6]) + assert_equal(a.max(-1), [3, 6]) + assert_equal(a.max(1), [3, 6]) + + @requires_memory(free_bytes=2 * 10000 * 1000 * 2) + def test_mean_overflow(self): + # Test overflow in masked arrays + # gh-20272 + a = masked_array(np.full((10000, 10000), 65535, dtype=np.uint16), + mask=np.zeros((10000, 10000))) + assert_equal(a.mean(), 65535.0) + + def test_diff_with_prepend(self): + # GH 22465 + x = np.array([1, 2, 2, 3, 4, 2, 1, 1]) + + a = np.ma.masked_equal(x[3:], value=2) + a_prep = np.ma.masked_equal(x[:3], value=2) + diff1 = np.ma.diff(a, prepend=a_prep, axis=0) + + b = np.ma.masked_equal(x, value=2) + diff2 = np.ma.diff(b, axis=0) + + assert_(np.ma.allequal(diff1, diff2)) + + def test_diff_with_append(self): + # GH 22465 + x = np.array([1, 2, 2, 3, 4, 2, 1, 1]) + + a = np.ma.masked_equal(x[:3], value=2) + a_app = np.ma.masked_equal(x[3:], value=2) + diff1 = np.ma.diff(a, append=a_app, axis=0) + + b = np.ma.masked_equal(x, value=2) + diff2 = np.ma.diff(b, axis=0) + + assert_(np.ma.allequal(diff1, diff2)) + + def test_diff_with_dim_0(self): + with pytest.raises( + ValueError, + match="diff requires input that is at least one dimensional" + ): + np.ma.diff(np.array(1)) + + def test_diff_with_n_0(self): + a = np.ma.masked_equal([1, 2, 2, 3, 4, 2, 1, 1], value=2) + diff = np.ma.diff(a, n=0, axis=0) + + assert_(np.ma.allequal(a, diff)) + + +class TestMaskedArrayMathMethodsComplex: + # Test class for miscellaneous MaskedArrays methods. + def setup_method(self): + # Base data definition. 
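+        # Same fixture layout as TestMaskedArrayMathMethods, except that
+        # several entries are purely imaginary, so var/std are exercised on
+        # complex data.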
+ x = np.array([8.375j, 7.545j, 8.828j, 8.5j, 1.757j, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479j, + 7.189j, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993j]) + X = x.reshape(6, 6) + XX = x.reshape(3, 2, 2, 3) + + m = np.array([0, 1, 0, 1, 0, 0, + 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0]) + mx = array(data=x, mask=m) + mX = array(data=X, mask=m.reshape(X.shape)) + mXX = array(data=XX, mask=m.reshape(XX.shape)) + + m2 = np.array([1, 1, 0, 1, 0, 0, + 1, 1, 1, 1, 0, 1, + 0, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 1, 0, + 0, 0, 1, 0, 1, 1]) + m2x = array(data=x, mask=m2) + m2X = array(data=X, mask=m2.reshape(X.shape)) + m2XX = array(data=XX, mask=m2.reshape(XX.shape)) + self.d = (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) + + def test_varstd(self): + # Tests var & std on MaskedArrays. + (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + assert_almost_equal(mX.var(axis=None), mX.compressed().var()) + assert_almost_equal(mX.std(axis=None), mX.compressed().std()) + assert_equal(mXX.var(axis=3).shape, XX.var(axis=3).shape) + assert_equal(mX.var().shape, X.var().shape) + (mXvar0, mXvar1) = (mX.var(axis=0), mX.var(axis=1)) + assert_almost_equal(mX.var(axis=None, ddof=2), + mX.compressed().var(ddof=2)) + assert_almost_equal(mX.std(axis=None, ddof=2), + mX.compressed().std(ddof=2)) + for k in range(6): + assert_almost_equal(mXvar1[k], mX[k].compressed().var()) + assert_almost_equal(mXvar0[k], mX[:, k].compressed().var()) + assert_almost_equal(np.sqrt(mXvar0[k]), + mX[:, k].compressed().std()) + + +class TestMaskedArrayFunctions: + # Test class for miscellaneous functions. + + def setup_method(self): + x = np.array([1., 1., 1., -2., pi / 2.0, 4., 5., -10., 10., 1., 2., 3.]) + y = np.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.]) + m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] + m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1] + xm = masked_array(x, mask=m1) + ym = masked_array(y, mask=m2) + xm.set_fill_value(1e+20) + self.info = (xm, ym) + + def test_masked_where_bool(self): + x = [1, 2] + y = masked_where(False, x) + assert_equal(y, [1, 2]) + assert_equal(y[1], 2) + + def test_masked_equal_wlist(self): + x = [1, 2, 3] + mx = masked_equal(x, 3) + assert_equal(mx, x) + assert_equal(mx._mask, [0, 0, 1]) + mx = masked_not_equal(x, 3) + assert_equal(mx, x) + assert_equal(mx._mask, [1, 1, 0]) + + def test_masked_equal_fill_value(self): + x = [1, 2, 3] + mx = masked_equal(x, 3) + assert_equal(mx._mask, [0, 0, 1]) + assert_equal(mx.fill_value, 3) + + def test_masked_where_condition(self): + # Tests masking functions. + x = array([1., 2., 3., 4., 5.]) + x[2] = masked + assert_equal(masked_where(greater(x, 2), x), masked_greater(x, 2)) + assert_equal(masked_where(greater_equal(x, 2), x), + masked_greater_equal(x, 2)) + assert_equal(masked_where(less(x, 2), x), masked_less(x, 2)) + assert_equal(masked_where(less_equal(x, 2), x), + masked_less_equal(x, 2)) + assert_equal(masked_where(not_equal(x, 2), x), masked_not_equal(x, 2)) + assert_equal(masked_where(equal(x, 2), x), masked_equal(x, 2)) + assert_equal(masked_where(not_equal(x, 2), x), masked_not_equal(x, 2)) + assert_equal(masked_where([1, 1, 0, 0, 0], [1, 2, 3, 4, 5]), + [99, 99, 3, 4, 5]) + + def test_masked_where_oddities(self): + # Tests some generic features. 
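+        # An all-False mask array should be a no-op: masking with it leaves
+        # the result equal to the unmasked input.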
+        atest = ones((10, 10, 10), dtype=float)
+        btest = zeros(atest.shape, MaskType)
+        ctest = masked_where(btest, atest)
+        assert_equal(atest, ctest)
+
+    def test_masked_where_shape_constraint(self):
+        a = arange(10)
+        with assert_raises(IndexError):
+            masked_equal(1, a)
+        test = masked_equal(a, 1)
+        assert_equal(test.mask, [0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
+
+    def test_masked_where_structured(self):
+        # test that masked_where on a structured array sets a structured
+        # mask (see issue #2972)
+        a = np.zeros(10, dtype=[("A", "<f2"), ("B", "<f4")])
+        am = np.ma.masked_where(a["A"] < 5, a)
+        assert_equal(am.mask.dtype.names, am.dtype.names)
+        assert_equal(am["A"],
+                     np.ma.masked_array(np.zeros(10), np.ones(10)))
+
+    def test_masked_where_mismatch(self):
+        # gh-4520
+        x = np.arange(10)
+        y = np.arange(5)
+        assert_raises(IndexError, np.ma.masked_where, y > 6, x)
+
+    def test_masked_otherfunctions(self):
+        assert_equal(masked_inside(list(range(5)), 1, 3),
+                     [0, 199, 199, 199, 4])
+        assert_equal(masked_outside(list(range(5)), 1, 3), [199, 1, 2, 3, 199])
+        assert_equal(masked_inside(array(list(range(5)),
+                                         mask=[1, 0, 0, 0, 0]), 1, 3).mask,
+                     [1, 1, 1, 1, 0])
+        assert_equal(masked_outside(array(list(range(5)),
+                                          mask=[0, 1, 0, 0, 0]), 1, 3).mask,
+                     [1, 1, 0, 0, 1])
+        assert_equal(masked_equal(array(list(range(5)),
+                                        mask=[1, 0, 0, 0, 0]), 2).mask,
+                     [1, 0, 1, 0, 0])
+        assert_equal(masked_not_equal(array([2, 2, 1, 2, 1],
+                                            mask=[1, 0, 0, 0, 0]), 2).mask,
+                     [1, 0, 1, 0, 1])
+
+    def test_round(self):
+        a = array([1.23456, 2.34567, 3.45678, 4.56789, 5.67890],
+                  mask=[0, 1, 0, 0, 0])
+        assert_equal(a.round(), [1., 2., 3., 5., 6.])
+        assert_equal(a.round(1), [1.2, 2.3, 3.5, 4.6, 5.7])
+        assert_equal(a.round(3), [1.235, 2.346, 3.457, 4.568, 5.679])
+        b = empty_like(a)
+        a.round(out=b)
+        assert_equal(b, [1., 2., 3., 5., 6.])
+
+        x = array([1., 2., 3., 4., 5.])
+        c = array([1, 1, 1, 0, 0])
+        x[2] = masked
+        z = where(c, x, -x)
+        assert_equal(z, [1., 2., 0., -4., -5])
+        c[0] = masked
+        z = where(c, x, -x)
+        assert_equal(z, [1., 2., 0., -4., -5])
+        assert_(z[0] is masked)
+        assert_(z[1] is not masked)
+        assert_(z[2] is masked)
+
+    def test_round_with_output(self):
+        # Testing round with an explicit output
+
+        xm = array(np.random.uniform(0, 10, 12)).reshape(3, 4)
+        xm[:, 0] = xm[0] = xm[-1, -1] = masked
+
+        # A ndarray as explicit input
+        output = np.empty((3, 4), dtype=float)
+        output.fill(-9999)
+        result = np.round(xm, decimals=2, out=output)
+        # ...
the result should be the given output + assert_(result is output) + assert_equal(result, xm.round(decimals=2, out=output)) + + output = empty((3, 4), dtype=float) + result = xm.round(decimals=2, out=output) + assert_(result is output) + + def test_round_with_scalar(self): + # Testing round with scalar/zero dimension input + # GH issue 2244 + a = array(1.1, mask=[False]) + assert_equal(a.round(), 1) + + a = array(1.1, mask=[True]) + assert_(a.round() is masked) + + a = array(1.1, mask=[False]) + output = np.empty(1, dtype=float) + output.fill(-9999) + a.round(out=output) + assert_equal(output, 1) + + a = array(1.1, mask=[False]) + output = array(-9999., mask=[True]) + a.round(out=output) + assert_equal(output[()], 1) + + a = array(1.1, mask=[True]) + output = array(-9999., mask=[False]) + a.round(out=output) + assert_(output[()] is masked) + + def test_identity(self): + a = identity(5) + assert_(isinstance(a, MaskedArray)) + assert_equal(a, np.identity(5)) + + def test_power(self): + x = -1.1 + assert_almost_equal(power(x, 2.), 1.21) + assert_(power(x, masked) is masked) + x = array([-1.1, -1.1, 1.1, 1.1, 0.]) + b = array([0.5, 2., 0.5, 2., -1.], mask=[0, 0, 0, 0, 1]) + y = power(x, b) + assert_almost_equal(y, [0, 1.21, 1.04880884817, 1.21, 0.]) + assert_equal(y._mask, [1, 0, 0, 0, 1]) + b.mask = nomask + y = power(x, b) + assert_equal(y._mask, [1, 0, 0, 0, 1]) + z = x ** b + assert_equal(z._mask, y._mask) + assert_almost_equal(z, y) + assert_almost_equal(z._data, y._data) + x **= b + assert_equal(x._mask, y._mask) + assert_almost_equal(x, y) + assert_almost_equal(x._data, y._data) + + def test_power_with_broadcasting(self): + # Test power w/ broadcasting + a2 = np.array([[1., 2., 3.], [4., 5., 6.]]) + a2m = array(a2, mask=[[1, 0, 0], [0, 0, 1]]) + b1 = np.array([2, 4, 3]) + b2 = np.array([b1, b1]) + b2m = array(b2, mask=[[0, 1, 0], [0, 1, 0]]) + + ctrl = array([[1 ** 2, 2 ** 4, 3 ** 3], [4 ** 2, 5 ** 4, 6 ** 3]], + mask=[[1, 1, 0], [0, 1, 1]]) + # No broadcasting, base & exp w/ mask + test = a2m ** b2m + assert_equal(test, ctrl) + assert_equal(test.mask, ctrl.mask) + # No broadcasting, base w/ mask, exp w/o mask + test = a2m ** b2 + assert_equal(test, ctrl) + assert_equal(test.mask, a2m.mask) + # No broadcasting, base w/o mask, exp w/ mask + test = a2 ** b2m + assert_equal(test, ctrl) + assert_equal(test.mask, b2m.mask) + + ctrl = array([[2 ** 2, 4 ** 4, 3 ** 3], [2 ** 2, 4 ** 4, 3 ** 3]], + mask=[[0, 1, 0], [0, 1, 0]]) + test = b1 ** b2m + assert_equal(test, ctrl) + assert_equal(test.mask, ctrl.mask) + test = b2m ** b1 + assert_equal(test, ctrl) + assert_equal(test.mask, ctrl.mask) + + @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm") + def test_where(self): + # Test the where function + x = np.array([1., 1., 1., -2., pi / 2.0, 4., 5., -10., 10., 1., 2., 3.]) + y = np.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.]) + m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] + m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1] + xm = masked_array(x, mask=m1) + ym = masked_array(y, mask=m2) + xm.set_fill_value(1e+20) + + d = where(xm > 2, xm, -9) + assert_equal(d, [-9., -9., -9., -9., -9., 4., + -9., -9., 10., -9., -9., 3.]) + assert_equal(d._mask, xm._mask) + d = where(xm > 2, -9, ym) + assert_equal(d, [5., 0., 3., 2., -1., -9., + -9., -10., -9., 1., 0., -9.]) + assert_equal(d._mask, [1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0]) + d = where(xm > 2, xm, masked) + assert_equal(d, [-9., -9., -9., -9., -9., 4., + -9., -9., 10., -9., -9., 3.]) + tmp = xm._mask.copy() + tmp[(xm <= 2).filled(True)] 
= True + assert_equal(d._mask, tmp) + + with np.errstate(invalid="warn"): + # The fill value is 1e20, it cannot be converted to `int`: + with pytest.warns(RuntimeWarning, match="invalid value"): + ixm = xm.astype(int) + d = where(ixm > 2, ixm, masked) + assert_equal(d, [-9, -9, -9, -9, -9, 4, -9, -9, 10, -9, -9, 3]) + assert_equal(d.dtype, ixm.dtype) + + def test_where_object(self): + a = np.array(None) + b = masked_array(None) + r = b.copy() + assert_equal(np.ma.where(True, a, a), r) + assert_equal(np.ma.where(True, b, b), r) + + def test_where_with_masked_choice(self): + x = arange(10) + x[3] = masked + c = x >= 8 + # Set False to masked + z = where(c, x, masked) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is masked) + assert_(z[7] is masked) + assert_(z[8] is not masked) + assert_(z[9] is not masked) + assert_equal(x, z) + # Set True to masked + z = where(c, masked, x) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is not masked) + assert_(z[7] is not masked) + assert_(z[8] is masked) + assert_(z[9] is masked) + + def test_where_with_masked_condition(self): + x = array([1., 2., 3., 4., 5.]) + c = array([1, 1, 1, 0, 0]) + x[2] = masked + z = where(c, x, -x) + assert_equal(z, [1., 2., 0., -4., -5]) + c[0] = masked + z = where(c, x, -x) + assert_equal(z, [1., 2., 0., -4., -5]) + assert_(z[0] is masked) + assert_(z[1] is not masked) + assert_(z[2] is masked) + + x = arange(1, 6) + x[-1] = masked + y = arange(1, 6) * 10 + y[2] = masked + c = array([1, 1, 1, 0, 0], mask=[1, 0, 0, 0, 0]) + cm = c.filled(1) + z = where(c, x, y) + zm = where(cm, x, y) + assert_equal(z, zm) + assert_(getmask(zm) is nomask) + assert_equal(zm, [1, 2, 3, 40, 50]) + z = where(c, masked, 1) + assert_equal(z, [99, 99, 99, 1, 1]) + z = where(c, 1, masked) + assert_equal(z, [99, 1, 1, 99, 99]) + + def test_where_type(self): + # Test the type conservation with where + x = np.arange(4, dtype=np.int32) + y = np.arange(4, dtype=np.float32) * 2.2 + test = where(x > 1.5, y, x).dtype + control = np.result_type(np.int32, np.float32) + assert_equal(test, control) + + def test_where_broadcast(self): + # Issue 8599 + x = np.arange(9).reshape(3, 3) + y = np.zeros(3) + core = np.where([1, 0, 1], x, y) + ma = where([1, 0, 1], x, y) + + assert_equal(core, ma) + assert_equal(core.dtype, ma.dtype) + + def test_where_structured(self): + # Issue 8600 + dt = np.dtype([('a', int), ('b', int)]) + x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt) + y = np.array((10, 20), dtype=dt) + core = np.where([0, 1, 1], x, y) + ma = np.where([0, 1, 1], x, y) + + assert_equal(core, ma) + assert_equal(core.dtype, ma.dtype) + + def test_where_structured_masked(self): + dt = np.dtype([('a', int), ('b', int)]) + x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt) + + ma = where([0, 1, 1], x, masked) + expected = masked_where([1, 0, 0], x) + + assert_equal(ma.dtype, expected.dtype) + assert_equal(ma, expected) + assert_equal(ma.mask, expected.mask) + + def test_masked_invalid_error(self): + a = np.arange(5, dtype=object) + a[3] = np.inf + a[2] = np.nan + with pytest.raises(TypeError, + match="not supported for the input types"): + np.ma.masked_invalid(a) + + def test_masked_invalid_pandas(self): + # getdata() used to be bad for pandas series due to its _data + # attribute. This test is a regression test mainly and may be + # removed if getdata() is adjusted. 
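+        # The stub below mimics just enough of a pandas Series to trigger the
+        # old bug: a bogus ._data attribute plus an __array__ that yields the
+        # real values.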
+ class Series: + _data = "nonsense" + + def __array__(self, dtype=None, copy=None): + return np.array([5, np.nan, np.inf]) + + arr = np.ma.masked_invalid(Series()) + assert_array_equal(arr._data, np.array(Series())) + assert_array_equal(arr._mask, [False, True, True]) + + @pytest.mark.parametrize("copy", [True, False]) + def test_masked_invalid_full_mask(self, copy): + # Matplotlib relied on masked_invalid always returning a full mask + # (Also astropy projects, but were ok with it gh-22720 and gh-22842) + a = np.ma.array([1, 2, 3, 4]) + assert a._mask is nomask + res = np.ma.masked_invalid(a, copy=copy) + assert res.mask is not nomask + # mask of a should not be mutated + assert a.mask is nomask + assert np.may_share_memory(a._data, res._data) != copy + + def test_choose(self): + # Test choose + choices = [[0, 1, 2, 3], [10, 11, 12, 13], + [20, 21, 22, 23], [30, 31, 32, 33]] + chosen = choose([2, 3, 1, 0], choices) + assert_equal(chosen, array([20, 31, 12, 3])) + chosen = choose([2, 4, 1, 0], choices, mode='clip') + assert_equal(chosen, array([20, 31, 12, 3])) + chosen = choose([2, 4, 1, 0], choices, mode='wrap') + assert_equal(chosen, array([20, 1, 12, 3])) + # Check with some masked indices + indices_ = array([2, 4, 1, 0], mask=[1, 0, 0, 1]) + chosen = choose(indices_, choices, mode='wrap') + assert_equal(chosen, array([99, 1, 12, 99])) + assert_equal(chosen.mask, [1, 0, 0, 1]) + # Check with some masked choices + choices = array(choices, mask=[[0, 0, 0, 1], [1, 1, 0, 1], + [1, 0, 0, 0], [0, 0, 0, 0]]) + indices_ = [2, 3, 1, 0] + chosen = choose(indices_, choices, mode='wrap') + assert_equal(chosen, array([20, 31, 12, 3])) + assert_equal(chosen.mask, [1, 0, 0, 1]) + + def test_choose_with_out(self): + # Test choose with an explicit out keyword + choices = [[0, 1, 2, 3], [10, 11, 12, 13], + [20, 21, 22, 23], [30, 31, 32, 33]] + store = empty(4, dtype=int) + chosen = choose([2, 3, 1, 0], choices, out=store) + assert_equal(store, array([20, 31, 12, 3])) + assert_(store is chosen) + # Check with some masked indices + out + store = empty(4, dtype=int) + indices_ = array([2, 3, 1, 0], mask=[1, 0, 0, 1]) + chosen = choose(indices_, choices, mode='wrap', out=store) + assert_equal(store, array([99, 31, 12, 99])) + assert_equal(store.mask, [1, 0, 0, 1]) + # Check with some masked choices + out ina ndarray ! 
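+        # (A plain ndarray output has nowhere to store a mask, so masked
+        # selections are written out as the default int fill value, 999999.)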
+ choices = array(choices, mask=[[0, 0, 0, 1], [1, 1, 0, 1], + [1, 0, 0, 0], [0, 0, 0, 0]]) + indices_ = [2, 3, 1, 0] + store = empty(4, dtype=int).view(ndarray) + chosen = choose(indices_, choices, mode='wrap', out=store) + assert_equal(store, array([999999, 31, 12, 999999])) + + def test_reshape(self): + a = arange(10) + a[0] = masked + # Try the default + b = a.reshape((5, 2)) + assert_equal(b.shape, (5, 2)) + assert_(b.flags['C']) + # Try w/ arguments as list instead of tuple + b = a.reshape(5, 2) + assert_equal(b.shape, (5, 2)) + assert_(b.flags['C']) + # Try w/ order + b = a.reshape((5, 2), order='F') + assert_equal(b.shape, (5, 2)) + assert_(b.flags['F']) + # Try w/ order + b = a.reshape(5, 2, order='F') + assert_equal(b.shape, (5, 2)) + assert_(b.flags['F']) + + c = np.reshape(a, (2, 5)) + assert_(isinstance(c, MaskedArray)) + assert_equal(c.shape, (2, 5)) + assert_(c[0, 0] is masked) + assert_(c.flags['C']) + + def test_make_mask_descr(self): + # Flexible + ntype = [('a', float), ('b', float)] + test = make_mask_descr(ntype) + assert_equal(test, [('a', bool), ('b', bool)]) + assert_(test is make_mask_descr(test)) + + # Standard w/ shape + ntype = (float, 2) + test = make_mask_descr(ntype) + assert_equal(test, (bool, 2)) + assert_(test is make_mask_descr(test)) + + # Standard standard + ntype = float + test = make_mask_descr(ntype) + assert_equal(test, np.dtype(bool)) + assert_(test is make_mask_descr(test)) + + # Nested + ntype = [('a', float), ('b', [('ba', float), ('bb', float)])] + test = make_mask_descr(ntype) + control = np.dtype([('a', 'b1'), ('b', [('ba', 'b1'), ('bb', 'b1')])]) + assert_equal(test, control) + assert_(test is make_mask_descr(test)) + + # Named+ shape + ntype = [('a', (float, 2))] + test = make_mask_descr(ntype) + assert_equal(test, np.dtype([('a', (bool, 2))])) + assert_(test is make_mask_descr(test)) + + # 2 names + ntype = [(('A', 'a'), float)] + test = make_mask_descr(ntype) + assert_equal(test, np.dtype([(('A', 'a'), bool)])) + assert_(test is make_mask_descr(test)) + + # nested boolean types should preserve identity + base_type = np.dtype([('a', int, 3)]) + base_mtype = make_mask_descr(base_type) + sub_type = np.dtype([('a', int), ('b', base_mtype)]) + test = make_mask_descr(sub_type) + assert_equal(test, np.dtype([('a', bool), ('b', [('a', bool, 3)])])) + assert_(test.fields['b'][0] is base_mtype) + + def test_make_mask(self): + # Test make_mask + # w/ a list as an input + mask = [0, 1] + test = make_mask(mask) + assert_equal(test.dtype, MaskType) + assert_equal(test, [0, 1]) + # w/ a ndarray as an input + mask = np.array([0, 1], dtype=bool) + test = make_mask(mask) + assert_equal(test.dtype, MaskType) + assert_equal(test, [0, 1]) + # w/ a flexible-type ndarray as an input - use default + mdtype = [('a', bool), ('b', bool)] + mask = np.array([(0, 0), (0, 1)], dtype=mdtype) + test = make_mask(mask) + assert_equal(test.dtype, MaskType) + assert_equal(test, [1, 1]) + # w/ a flexible-type ndarray as an input - use input dtype + mdtype = [('a', bool), ('b', bool)] + mask = np.array([(0, 0), (0, 1)], dtype=mdtype) + test = make_mask(mask, dtype=mask.dtype) + assert_equal(test.dtype, mdtype) + assert_equal(test, mask) + # w/ a flexible-type ndarray as an input - use input dtype + mdtype = [('a', float), ('b', float)] + bdtype = [('a', bool), ('b', bool)] + mask = np.array([(0, 0), (0, 1)], dtype=mdtype) + test = make_mask(mask, dtype=mask.dtype) + assert_equal(test.dtype, bdtype) + assert_equal(test, np.array([(0, 0), (0, 1)], dtype=bdtype)) + # Ensure this 
also works for void + mask = np.array((False, True), dtype='?,?')[()] + assert_(isinstance(mask, np.void)) + test = make_mask(mask, dtype=mask.dtype) + assert_equal(test, mask) + assert_(test is not mask) + mask = np.array((0, 1), dtype='i4,i4')[()] + test2 = make_mask(mask, dtype=mask.dtype) + assert_equal(test2, test) + # test that nomask is returned when m is nomask. + bools = [True, False] + dtypes = [MaskType, float] + msgformat = 'copy=%s, shrink=%s, dtype=%s' + for cpy, shr, dt in itertools.product(bools, bools, dtypes): + res = make_mask(nomask, copy=cpy, shrink=shr, dtype=dt) + assert_(res is nomask, msgformat % (cpy, shr, dt)) + + def test_mask_or(self): + # Initialize + mtype = [('a', bool), ('b', bool)] + mask = np.array([(0, 0), (0, 1), (1, 0), (0, 0)], dtype=mtype) + # Test using nomask as input + test = mask_or(mask, nomask) + assert_equal(test, mask) + test = mask_or(nomask, mask) + assert_equal(test, mask) + # Using False as input + test = mask_or(mask, False) + assert_equal(test, mask) + # Using another array w / the same dtype + other = np.array([(0, 1), (0, 1), (0, 1), (0, 1)], dtype=mtype) + test = mask_or(mask, other) + control = np.array([(0, 1), (0, 1), (1, 1), (0, 1)], dtype=mtype) + assert_equal(test, control) + # Using another array w / a different dtype + othertype = [('A', bool), ('B', bool)] + other = np.array([(0, 1), (0, 1), (0, 1), (0, 1)], dtype=othertype) + try: + test = mask_or(mask, other) + except ValueError: + pass + # Using nested arrays + dtype = [('a', bool), ('b', [('ba', bool), ('bb', bool)])] + amask = np.array([(0, (1, 0)), (0, (1, 0))], dtype=dtype) + bmask = np.array([(1, (0, 1)), (0, (0, 0))], dtype=dtype) + cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype) + assert_equal(mask_or(amask, bmask), cntrl) + + a = np.array([False, False]) + assert mask_or(a, a) is nomask # gh-27360 + + def test_allequal(self): + x = array([1, 2, 3], mask=[0, 0, 0]) + y = array([1, 2, 3], mask=[1, 0, 0]) + z = array([[1, 2, 3], [4, 5, 6]], mask=[[0, 0, 0], [1, 1, 1]]) + + assert allequal(x, y) + assert not allequal(x, y, fill_value=False) + assert allequal(x, z) + + # test allequal for the same input, with mask=nomask, this test is for + # the scenario raised in https://github.com/numpy/numpy/issues/27201 + assert allequal(x, x) + assert allequal(x, x, fill_value=False) + + assert allequal(y, y) + assert not allequal(y, y, fill_value=False) + + def test_flatten_mask(self): + # Tests flatten mask + # Standard dtype + mask = np.array([0, 0, 1], dtype=bool) + assert_equal(flatten_mask(mask), mask) + # Flexible dtype + mask = np.array([(0, 0), (0, 1)], dtype=[('a', bool), ('b', bool)]) + test = flatten_mask(mask) + control = np.array([0, 0, 0, 1], dtype=bool) + assert_equal(test, control) + + mdtype = [('a', bool), ('b', [('ba', bool), ('bb', bool)])] + data = [(0, (0, 0)), (0, (0, 1))] + mask = np.array(data, dtype=mdtype) + test = flatten_mask(mask) + control = np.array([0, 0, 0, 0, 0, 1], dtype=bool) + assert_equal(test, control) + + def test_on_ndarray(self): + # Test functions on ndarrays + a = np.array([1, 2, 3, 4]) + m = array(a, mask=False) + test = anom(a) + assert_equal(test, m.anom()) + test = reshape(a, (2, 2)) + assert_equal(test, m.reshape(2, 2)) + + def test_compress(self): + # Test compress function on ndarray and masked array + # Address Github #2495. 
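+        # Like np.compress, the condition comes first: entries of arr where
+        # cond is True are kept along the given axis.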
+ arr = np.arange(8) + arr.shape = 4, 2 + cond = np.array([True, False, True, True]) + control = arr[[0, 2, 3]] + test = np.ma.compress(cond, arr, axis=0) + assert_equal(test, control) + marr = np.ma.array(arr) + test = np.ma.compress(cond, marr, axis=0) + assert_equal(test, control) + + def test_compressed(self): + # Test ma.compressed function. + # Address gh-4026 + a = np.ma.array([1, 2]) + test = np.ma.compressed(a) + assert_(type(test) is np.ndarray) + + # Test case when input data is ndarray subclass + class A(np.ndarray): + pass + + a = np.ma.array(A(shape=0)) + test = np.ma.compressed(a) + assert_(type(test) is A) + + # Test that compress flattens + test = np.ma.compressed([[1], [2]]) + assert_equal(test.ndim, 1) + test = np.ma.compressed([[[[[1]]]]]) + assert_equal(test.ndim, 1) + + # Test case when input is MaskedArray subclass + class M(MaskedArray): + pass + + test = np.ma.compressed(M([[[]], [[]]])) + assert_equal(test.ndim, 1) + + # with .compressed() overridden + class M(MaskedArray): + def compressed(self): + return 42 + + test = np.ma.compressed(M([[[]], [[]]])) + assert_equal(test, 42) + + def test_convolve(self): + a = masked_equal(np.arange(5), 2) + b = np.array([1, 1]) + + result = masked_equal([0, 1, -1, -1, 7, 4], -1) + test = np.ma.convolve(a, b, mode='full') + assert_equal(test, result) + + test = np.ma.convolve(a, b, mode='same') + assert_equal(test, result[:-1]) + + test = np.ma.convolve(a, b, mode='valid') + assert_equal(test, result[1:-1]) + + result = masked_equal([0, 1, 1, 3, 7, 4], -1) + test = np.ma.convolve(a, b, mode='full', propagate_mask=False) + assert_equal(test, result) + + test = np.ma.convolve(a, b, mode='same', propagate_mask=False) + assert_equal(test, result[:-1]) + + test = np.ma.convolve(a, b, mode='valid', propagate_mask=False) + assert_equal(test, result[1:-1]) + + test = np.ma.convolve([1, 1], [1, 1, 1]) + assert_equal(test, masked_equal([1, 2, 2, 1], -1)) + + a = [1, 1] + b = masked_equal([1, -1, -1, 1], -1) + test = np.ma.convolve(a, b, propagate_mask=False) + assert_equal(test, masked_equal([1, 1, -1, 1, 1], -1)) + test = np.ma.convolve(a, b, propagate_mask=True) + assert_equal(test, masked_equal([-1, -1, -1, -1, -1], -1)) + + +class TestMaskedFields: + + def setup_method(self): + ilist = [1, 2, 3, 4, 5] + flist = [1.1, 2.2, 3.3, 4.4, 5.5] + slist = ['one', 'two', 'three', 'four', 'five'] + ddtype = [('a', int), ('b', float), ('c', '|S8')] + mdtype = [('a', bool), ('b', bool), ('c', bool)] + mask = [0, 1, 0, 0, 1] + base = array(list(zip(ilist, flist, slist)), mask=mask, dtype=ddtype) + self.data = {"base": base, "mask": mask, "ddtype": ddtype, "mdtype": mdtype} + + def test_set_records_masks(self): + base = self.data['base'] + mdtype = self.data['mdtype'] + # Set w/ nomask or masked + base.mask = nomask + assert_equal_records(base._mask, np.zeros(base.shape, dtype=mdtype)) + base.mask = masked + assert_equal_records(base._mask, np.ones(base.shape, dtype=mdtype)) + # Set w/ simple boolean + base.mask = False + assert_equal_records(base._mask, np.zeros(base.shape, dtype=mdtype)) + base.mask = True + assert_equal_records(base._mask, np.ones(base.shape, dtype=mdtype)) + # Set w/ list + base.mask = [0, 0, 0, 1, 1] + assert_equal_records(base._mask, + np.array([(x, x, x) for x in [0, 0, 0, 1, 1]], + dtype=mdtype)) + + def test_set_record_element(self): + # Check setting an element of a record) + base = self.data['base'] + (base_a, base_b, base_c) = (base['a'], base['b'], base['c']) + base[0] = (pi, pi, 'pi') + + assert_equal(base_a.dtype, 
int) + assert_equal(base_a._data, [3, 2, 3, 4, 5]) + + assert_equal(base_b.dtype, float) + assert_equal(base_b._data, [pi, 2.2, 3.3, 4.4, 5.5]) + + assert_equal(base_c.dtype, '|S8') + assert_equal(base_c._data, + [b'pi', b'two', b'three', b'four', b'five']) + + def test_set_record_slice(self): + base = self.data['base'] + (base_a, base_b, base_c) = (base['a'], base['b'], base['c']) + base[:3] = (pi, pi, 'pi') + + assert_equal(base_a.dtype, int) + assert_equal(base_a._data, [3, 3, 3, 4, 5]) + + assert_equal(base_b.dtype, float) + assert_equal(base_b._data, [pi, pi, pi, 4.4, 5.5]) + + assert_equal(base_c.dtype, '|S8') + assert_equal(base_c._data, + [b'pi', b'pi', b'pi', b'four', b'five']) + + def test_mask_element(self): + "Check record access" + base = self.data['base'] + base[0] = masked + + for n in ('a', 'b', 'c'): + assert_equal(base[n].mask, [1, 1, 0, 0, 1]) + assert_equal(base[n]._data, base._data[n]) + + def test_getmaskarray(self): + # Test getmaskarray on flexible dtype + ndtype = [('a', int), ('b', float)] + test = empty(3, dtype=ndtype) + assert_equal(getmaskarray(test), + np.array([(0, 0), (0, 0), (0, 0)], + dtype=[('a', '|b1'), ('b', '|b1')])) + test[:] = masked + assert_equal(getmaskarray(test), + np.array([(1, 1), (1, 1), (1, 1)], + dtype=[('a', '|b1'), ('b', '|b1')])) + + def test_view(self): + # Test view w/ flexible dtype + iterator = list(zip(np.arange(10), np.random.rand(10))) + data = np.array(iterator) + a = array(iterator, dtype=[('a', float), ('b', float)]) + a.mask[0] = (1, 0) + controlmask = np.array([1] + 19 * [0], dtype=bool) + # Transform globally to simple dtype + test = a.view(float) + assert_equal(test, data.ravel()) + assert_equal(test.mask, controlmask) + # Transform globally to dty + test = a.view((float, 2)) + assert_equal(test, data) + assert_equal(test.mask, controlmask.reshape(-1, 2)) + + def test_getitem(self): + ndtype = [('a', float), ('b', float)] + a = array(list(zip(np.random.rand(10), np.arange(10))), dtype=ndtype) + a.mask = np.array(list(zip([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 0])), + dtype=[('a', bool), ('b', bool)]) + + def _test_index(i): + assert_equal(type(a[i]), mvoid) + assert_equal_records(a[i]._data, a._data[i]) + assert_equal_records(a[i]._mask, a._mask[i]) + + assert_equal(type(a[i, ...]), MaskedArray) + assert_equal_records(a[i, ...]._data, a._data[i, ...]) + assert_equal_records(a[i, ...]._mask, a._mask[i, ...]) + + _test_index(1) # No mask + _test_index(0) # One element masked + _test_index(-2) # All element masked + + def test_setitem(self): + # Issue 4866: check that one can set individual items in [record][col] + # and [col][record] order + ndtype = np.dtype([('a', float), ('b', int)]) + ma = np.ma.MaskedArray([(1.0, 1), (2.0, 2)], dtype=ndtype) + ma['a'][1] = 3.0 + assert_equal(ma['a'], np.array([1.0, 3.0])) + ma[1]['a'] = 4.0 + assert_equal(ma['a'], np.array([1.0, 4.0])) + # Issue 2403 + mdtype = np.dtype([('a', bool), ('b', bool)]) + # soft mask + control = np.array([(False, True), (True, True)], dtype=mdtype) + a = np.ma.masked_all((2,), dtype=ndtype) + a['a'][0] = 2 + assert_equal(a.mask, control) + a = np.ma.masked_all((2,), dtype=ndtype) + a[0]['a'] = 2 + assert_equal(a.mask, control) + # hard mask + control = np.array([(True, True), (True, True)], dtype=mdtype) + a = np.ma.masked_all((2,), dtype=ndtype) + a.harden_mask() + a['a'][0] = 2 + assert_equal(a.mask, control) + a = np.ma.masked_all((2,), dtype=ndtype) + a.harden_mask() + a[0]['a'] = 2 + assert_equal(a.mask, control) + + def 
test_setitem_scalar(self): + # 8510 + mask_0d = np.ma.masked_array(1, mask=True) + arr = np.ma.arange(3) + arr[0] = mask_0d + assert_array_equal(arr.mask, [True, False, False]) + + def test_element_len(self): + # check that len() works for mvoid (Github issue #576) + for rec in self.data['base']: + assert_equal(len(rec), len(self.data['ddtype'])) + + +class TestMaskedObjectArray: + + def test_getitem(self): + arr = np.ma.array([None, None]) + for dt in [float, object]: + a0 = np.eye(2).astype(dt) + a1 = np.eye(3).astype(dt) + arr[0] = a0 + arr[1] = a1 + + assert_(arr[0] is a0) + assert_(arr[1] is a1) + assert_(isinstance(arr[0, ...], MaskedArray)) + assert_(isinstance(arr[1, ...], MaskedArray)) + assert_(arr[0, ...][()] is a0) + assert_(arr[1, ...][()] is a1) + + arr[0] = np.ma.masked + + assert_(arr[1] is a1) + assert_(isinstance(arr[0, ...], MaskedArray)) + assert_(isinstance(arr[1, ...], MaskedArray)) + assert_equal(arr[0, ...].mask, True) + assert_(arr[1, ...][()] is a1) + + # gh-5962 - object arrays of arrays do something special + assert_equal(arr[0].data, a0) + assert_equal(arr[0].mask, True) + assert_equal(arr[0, ...][()].data, a0) + assert_equal(arr[0, ...][()].mask, True) + + def test_nested_ma(self): + + arr = np.ma.array([None, None]) + # set the first object to be an unmasked masked constant. A little fiddly + arr[0, ...] = np.array([np.ma.masked], object)[0, ...] + + # check the above line did what we were aiming for + assert_(arr.data[0] is np.ma.masked) + + # test that getitem returned the value by identity + assert_(arr[0] is np.ma.masked) + + # now mask the masked value! + arr[0] = np.ma.masked + assert_(arr[0] is np.ma.masked) + + +class TestMaskedView: + + def setup_method(self): + iterator = list(zip(np.arange(10), np.random.rand(10))) + data = np.array(iterator) + a = array(iterator, dtype=[('a', float), ('b', float)]) + a.mask[0] = (1, 0) + controlmask = np.array([1] + 19 * [0], dtype=bool) + self.data = (data, a, controlmask) + + def test_view_to_nothing(self): + (data, a, controlmask) = self.data + test = a.view() + assert_(isinstance(test, MaskedArray)) + assert_equal(test._data, a._data) + assert_equal(test._mask, a._mask) + + def test_view_to_type(self): + (data, a, controlmask) = self.data + test = a.view(np.ndarray) + assert_(not isinstance(test, MaskedArray)) + assert_equal(test, a._data) + assert_equal_records(test, data.view(a.dtype).squeeze()) + + def test_view_to_simple_dtype(self): + (data, a, controlmask) = self.data + # View globally + test = a.view(float) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, data.ravel()) + assert_equal(test.mask, controlmask) + + def test_view_to_flexible_dtype(self): + (data, a, controlmask) = self.data + + test = a.view([('A', float), ('B', float)]) + assert_equal(test.mask.dtype.names, ('A', 'B')) + assert_equal(test['A'], a['a']) + assert_equal(test['B'], a['b']) + + test = a[0].view([('A', float), ('B', float)]) + assert_(isinstance(test, MaskedArray)) + assert_equal(test.mask.dtype.names, ('A', 'B')) + assert_equal(test['A'], a['a'][0]) + assert_equal(test['B'], a['b'][0]) + + test = a[-1].view([('A', float), ('B', float)]) + assert_(isinstance(test, MaskedArray)) + assert_equal(test.dtype.names, ('A', 'B')) + assert_equal(test['A'], a['a'][-1]) + assert_equal(test['B'], a['b'][-1]) + + def test_view_to_subdtype(self): + (data, a, controlmask) = self.data + # View globally + test = a.view((float, 2)) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, data) + assert_equal(test.mask, 
controlmask.reshape(-1, 2)) + # View on 1 masked element + test = a[0].view((float, 2)) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, data[0]) + assert_equal(test.mask, (1, 0)) + # View on 1 unmasked element + test = a[-1].view((float, 2)) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, data[-1]) + + def test_view_to_dtype_and_type(self): + (data, a, controlmask) = self.data + + test = a.view((float, 2), np.recarray) + assert_equal(test, data) + assert_(isinstance(test, np.recarray)) + assert_(not isinstance(test, MaskedArray)) + + +class TestOptionalArgs: + def test_ndarrayfuncs(self): + # test axis arg behaves the same as ndarray (including multiple axes) + + d = np.arange(24.0).reshape((2, 3, 4)) + m = np.zeros(24, dtype=bool).reshape((2, 3, 4)) + # mask out last element of last dimension + m[:, :, -1] = True + a = np.ma.array(d, mask=m) + + def testaxis(f, a, d): + numpy_f = numpy.__getattribute__(f) + ma_f = np.ma.__getattribute__(f) + + # test axis arg + assert_equal(ma_f(a, axis=1)[..., :-1], numpy_f(d[..., :-1], axis=1)) + assert_equal(ma_f(a, axis=(0, 1))[..., :-1], + numpy_f(d[..., :-1], axis=(0, 1))) + + def testkeepdims(f, a, d): + numpy_f = numpy.__getattribute__(f) + ma_f = np.ma.__getattribute__(f) + + # test keepdims arg + assert_equal(ma_f(a, keepdims=True).shape, + numpy_f(d, keepdims=True).shape) + assert_equal(ma_f(a, keepdims=False).shape, + numpy_f(d, keepdims=False).shape) + + # test both at once + assert_equal(ma_f(a, axis=1, keepdims=True)[..., :-1], + numpy_f(d[..., :-1], axis=1, keepdims=True)) + assert_equal(ma_f(a, axis=(0, 1), keepdims=True)[..., :-1], + numpy_f(d[..., :-1], axis=(0, 1), keepdims=True)) + + for f in ['sum', 'prod', 'mean', 'var', 'std']: + testaxis(f, a, d) + testkeepdims(f, a, d) + + for f in ['min', 'max']: + testaxis(f, a, d) + + d = (np.arange(24).reshape((2, 3, 4)) % 2 == 0) + a = np.ma.array(d, mask=m) + for f in ['all', 'any']: + testaxis(f, a, d) + testkeepdims(f, a, d) + + def test_count(self): + # test np.ma.count specially + + d = np.arange(24.0).reshape((2, 3, 4)) + m = np.zeros(24, dtype=bool).reshape((2, 3, 4)) + m[:, 0, :] = True + a = np.ma.array(d, mask=m) + + assert_equal(count(a), 16) + assert_equal(count(a, axis=1), 2 * ones((2, 4))) + assert_equal(count(a, axis=(0, 1)), 4 * ones((4,))) + assert_equal(count(a, keepdims=True), 16 * ones((1, 1, 1))) + assert_equal(count(a, axis=1, keepdims=True), 2 * ones((2, 1, 4))) + assert_equal(count(a, axis=(0, 1), keepdims=True), 4 * ones((1, 1, 4))) + assert_equal(count(a, axis=-2), 2 * ones((2, 4))) + assert_raises(ValueError, count, a, axis=(1, 1)) + assert_raises(AxisError, count, a, axis=3) + + # check the 'nomask' path + a = np.ma.array(d, mask=nomask) + + assert_equal(count(a), 24) + assert_equal(count(a, axis=1), 3 * ones((2, 4))) + assert_equal(count(a, axis=(0, 1)), 6 * ones((4,))) + assert_equal(count(a, keepdims=True), 24 * ones((1, 1, 1))) + assert_equal(np.ndim(count(a, keepdims=True)), 3) + assert_equal(count(a, axis=1, keepdims=True), 3 * ones((2, 1, 4))) + assert_equal(count(a, axis=(0, 1), keepdims=True), 6 * ones((1, 1, 4))) + assert_equal(count(a, axis=-2), 3 * ones((2, 4))) + assert_raises(ValueError, count, a, axis=(1, 1)) + assert_raises(AxisError, count, a, axis=3) + + # check the 'masked' singleton + assert_equal(count(np.ma.masked), 0) + + # check 0-d arrays do not allow axis > 0 + assert_raises(AxisError, count, np.ma.array(1), axis=1) + + +class TestMaskedConstant: + def _do_add_test(self, add): + # sanity check + 
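+        # Adding a scalar to np.ma.masked yields the masked singleton, but
+        # adding an array must give a fully masked array of that shape, never
+        # the singleton itself -- that is what the vector checks below assert.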
assert_(add(np.ma.masked, 1) is np.ma.masked) + + # now try with a vector + vector = np.array([1, 2, 3]) + result = add(np.ma.masked, vector) + + # lots of things could go wrong here + assert_(result is not np.ma.masked) + assert_(not isinstance(result, np.ma.core.MaskedConstant)) + assert_equal(result.shape, vector.shape) + assert_equal(np.ma.getmask(result), np.ones(vector.shape, dtype=bool)) + + def test_ufunc(self): + self._do_add_test(np.add) + + def test_operator(self): + self._do_add_test(lambda a, b: a + b) + + def test_ctor(self): + m = np.ma.array(np.ma.masked) + + # most importantly, we do not want to create a new MaskedConstant + # instance + assert_(not isinstance(m, np.ma.core.MaskedConstant)) + assert_(m is not np.ma.masked) + + def test_repr(self): + # copies should not exist, but if they do, it should be obvious that + # something is wrong + assert_equal(repr(np.ma.masked), 'masked') + + # create a new instance in a weird way + masked2 = np.ma.MaskedArray.__new__(np.ma.core.MaskedConstant) + assert_not_equal(repr(masked2), 'masked') + + def test_pickle(self): + from io import BytesIO + + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + with BytesIO() as f: + pickle.dump(np.ma.masked, f, protocol=proto) + f.seek(0) + res = pickle.load(f) + assert_(res is np.ma.masked) + + def test_copy(self): + # gh-9328 + # copy is a no-op, like it is with np.True_ + assert_equal( + np.ma.masked.copy() is np.ma.masked, + np.True_.copy() is np.True_) + + def test__copy(self): + import copy + assert_( + copy.copy(np.ma.masked) is np.ma.masked) + + def test_deepcopy(self): + import copy + assert_( + copy.deepcopy(np.ma.masked) is np.ma.masked) + + def test_immutable(self): + orig = np.ma.masked + assert_raises(np.ma.core.MaskError, operator.setitem, orig, (), 1) + assert_raises(ValueError, operator.setitem, orig.data, (), 1) + assert_raises(ValueError, operator.setitem, orig.mask, (), False) + + view = np.ma.masked.view(np.ma.MaskedArray) + assert_raises(ValueError, operator.setitem, view, (), 1) + assert_raises(ValueError, operator.setitem, view.data, (), 1) + assert_raises(ValueError, operator.setitem, view.mask, (), False) + + def test_coercion_int(self): + a_i = np.zeros((), int) + assert_raises(MaskError, operator.setitem, a_i, (), np.ma.masked) + assert_raises(MaskError, int, np.ma.masked) + + def test_coercion_float(self): + a_f = np.zeros((), float) + assert_warns(UserWarning, operator.setitem, a_f, (), np.ma.masked) + assert_(np.isnan(a_f[()])) + + @pytest.mark.xfail(reason="See gh-9750") + def test_coercion_unicode(self): + a_u = np.zeros((), 'U10') + a_u[()] = np.ma.masked + assert_equal(a_u[()], '--') + + @pytest.mark.xfail(reason="See gh-9750") + def test_coercion_bytes(self): + a_b = np.zeros((), 'S10') + a_b[()] = np.ma.masked + assert_equal(a_b[()], b'--') + + def test_subclass(self): + # https://github.com/astropy/astropy/issues/6645 + class Sub(type(np.ma.masked)): + pass + + a = Sub() + assert_(a is Sub()) + assert_(a is not np.ma.masked) + assert_not_equal(repr(a), 'masked') + + def test_attributes_readonly(self): + assert_raises(AttributeError, setattr, np.ma.masked, 'shape', (1,)) + assert_raises(AttributeError, setattr, np.ma.masked, 'dtype', np.int64) + + +class TestMaskedWhereAliases: + + # TODO: Test masked_object, masked_equal, ... 
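+    # masked_values masks entries approximately equal (via np.isclose) to the
+    # given value; with shrink=True an all-False mask collapses to nomask.
+    # A rough sketch:
+    #
+    #     >>> np.ma.masked_values([1.0, 1.1, 2.0], 1.1).mask
+    #     array([False,  True, False])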
+ + def test_masked_values(self): + res = masked_values(np.array([-32768.0]), np.int16(-32768)) + assert_equal(res.mask, [True]) + + res = masked_values(np.inf, np.inf) + assert_equal(res.mask, True) + + res = np.ma.masked_values(np.inf, -np.inf) + assert_equal(res.mask, False) + + res = np.ma.masked_values([1, 2, 3, 4], 5, shrink=True) + assert_(res.mask is np.ma.nomask) + + res = np.ma.masked_values([1, 2, 3, 4], 5, shrink=False) + assert_equal(res.mask, [False] * 4) + + +def test_masked_array(): + a = np.ma.array([0, 1, 2, 3], mask=[0, 0, 1, 0]) + assert_equal(np.argwhere(a), [[1], [3]]) + +def test_masked_array_no_copy(): + # check nomask array is updated in place + a = np.ma.array([1, 2, 3, 4]) + _ = np.ma.masked_where(a == 3, a, copy=False) + assert_array_equal(a.mask, [False, False, True, False]) + # check masked array is updated in place + a = np.ma.array([1, 2, 3, 4], mask=[1, 0, 0, 0]) + _ = np.ma.masked_where(a == 3, a, copy=False) + assert_array_equal(a.mask, [True, False, True, False]) + # check masked array with masked_invalid is updated in place + a = np.ma.array([np.inf, 1, 2, 3, 4]) + _ = np.ma.masked_invalid(a, copy=False) + assert_array_equal(a.mask, [True, False, False, False, False]) + +def test_append_masked_array(): + a = np.ma.masked_equal([1, 2, 3], value=2) + b = np.ma.masked_equal([4, 3, 2], value=2) + + result = np.ma.append(a, b) + expected_data = [1, 2, 3, 4, 3, 2] + expected_mask = [False, True, False, False, False, True] + assert_array_equal(result.data, expected_data) + assert_array_equal(result.mask, expected_mask) + + a = np.ma.masked_all((2, 2)) + b = np.ma.ones((3, 1)) + + result = np.ma.append(a, b) + expected_data = [1] * 3 + expected_mask = [True] * 4 + [False] * 3 + assert_array_equal(result.data[-3], expected_data) + assert_array_equal(result.mask, expected_mask) + + result = np.ma.append(a, b, axis=None) + assert_array_equal(result.data[-3], expected_data) + assert_array_equal(result.mask, expected_mask) + + +def test_append_masked_array_along_axis(): + a = np.ma.masked_equal([1, 2, 3], value=2) + b = np.ma.masked_values([[4, 5, 6], [7, 8, 9]], 7) + + # When `axis` is specified, `values` must have the correct shape. + assert_raises(ValueError, np.ma.append, a, b, axis=0) + + result = np.ma.append(a[np.newaxis, :], b, axis=0) + expected = np.ma.arange(1, 10) + expected[[1, 6]] = np.ma.masked + expected = expected.reshape((3, 3)) + assert_array_equal(result.data, expected.data) + assert_array_equal(result.mask, expected.mask) + +def test_default_fill_value_complex(): + # regression test for Python 3, where 'unicode' was not defined + assert_(default_fill_value(1 + 1j) == 1.e20 + 0.0j) + + +def test_ufunc_with_output(): + # check that giving an output argument always returns that output. + # Regression test for gh-8416. 
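+    # With a MaskedArray passed as out=, the ufunc must hand back that exact
+    # object (y is x below), updating its data and mask in place.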
+ x = array([1., 2., 3.], mask=[0, 0, 1]) + y = np.add(x, 1., out=x) + assert_(y is x) + + +def test_ufunc_with_out_varied(): + """ Test that masked arrays are immune to gh-10459 """ + # the mask of the output should not affect the result, however it is passed + a = array([ 1, 2, 3], mask=[1, 0, 0]) + b = array([10, 20, 30], mask=[1, 0, 0]) + out = array([ 0, 0, 0], mask=[0, 0, 1]) + expected = array([11, 22, 33], mask=[1, 0, 0]) + + out_pos = out.copy() + res_pos = np.add(a, b, out_pos) + + out_kw = out.copy() + res_kw = np.add(a, b, out=out_kw) + + out_tup = out.copy() + res_tup = np.add(a, b, out=(out_tup,)) + + assert_equal(res_kw.mask, expected.mask) + assert_equal(res_kw.data, expected.data) + assert_equal(res_tup.mask, expected.mask) + assert_equal(res_tup.data, expected.data) + assert_equal(res_pos.mask, expected.mask) + assert_equal(res_pos.data, expected.data) + + +def test_astype_mask_ordering(): + descr = np.dtype([('v', int, 3), ('x', [('y', float)])]) + x = array([ + [([1, 2, 3], (1.0,)), ([1, 2, 3], (2.0,))], + [([1, 2, 3], (3.0,)), ([1, 2, 3], (4.0,))]], dtype=descr) + x[0]['v'][0] = np.ma.masked + + x_a = x.astype(descr) + assert x_a.dtype.names == np.dtype(descr).names + assert x_a.mask.dtype.names == np.dtype(descr).names + assert_equal(x, x_a) + + assert_(x is x.astype(x.dtype, copy=False)) + assert_equal(type(x.astype(x.dtype, subok=False)), np.ndarray) + + x_f = x.astype(x.dtype, order='F') + assert_(x_f.flags.f_contiguous) + assert_(x_f.mask.flags.f_contiguous) + + # Also test the same indirectly, via np.array + x_a2 = np.array(x, dtype=descr, subok=True) + assert x_a2.dtype.names == np.dtype(descr).names + assert x_a2.mask.dtype.names == np.dtype(descr).names + assert_equal(x, x_a2) + + assert_(x is np.array(x, dtype=descr, copy=None, subok=True)) + + x_f2 = np.array(x, dtype=x.dtype, order='F', subok=True) + assert_(x_f2.flags.f_contiguous) + assert_(x_f2.mask.flags.f_contiguous) + + +@pytest.mark.parametrize('dt1', num_dts, ids=num_ids) +@pytest.mark.parametrize('dt2', num_dts, ids=num_ids) +@pytest.mark.filterwarnings('ignore::numpy.exceptions.ComplexWarning') +def test_astype_basic(dt1, dt2): + # See gh-12070 + src = np.ma.array(ones(3, dt1), fill_value=1) + dst = src.astype(dt2) + + assert_(src.fill_value == 1) + assert_(src.dtype == dt1) + assert_(src.fill_value.dtype == dt1) + + assert_(dst.fill_value == 1) + assert_(dst.dtype == dt2) + assert_(dst.fill_value.dtype == dt2) + + assert_equal(src, dst) + + +def test_fieldless_void(): + dt = np.dtype([]) # a void dtype with no fields + x = np.empty(4, dt) + + # these arrays contain no values, so there's little to test - but this + # shouldn't crash + mx = np.ma.array(x) + assert_equal(mx.dtype, x.dtype) + assert_equal(mx.shape, x.shape) + + mx = np.ma.array(x, mask=x) + assert_equal(mx.dtype, x.dtype) + assert_equal(mx.shape, x.shape) + + +def test_mask_shape_assignment_does_not_break_masked(): + a = np.ma.masked + b = np.ma.array(1, mask=a.mask) + b.shape = (1,) + assert_equal(a.mask.shape, ()) + +@pytest.mark.skipif(sys.flags.optimize > 1, + reason="no docstrings present to inspect when PYTHONOPTIMIZE/Py_OptimizeFlag > 1") # noqa: E501 +def test_doc_note(): + def method(self): + """This docstring + + Has multiple lines + + And notes + + Notes + ----- + original note + """ + pass + + expected_doc = """This docstring + +Has multiple lines + +And notes + +Notes +----- +note + +original note""" + + assert_equal(np.ma.core.doc_note(method.__doc__, "note"), expected_doc) + + +def test_gh_22556(): + source = 
np.ma.array([0, [0, 1, 2]], dtype=object) + deepcopy = copy.deepcopy(source) + deepcopy[1].append('this should not appear in source') + assert len(source[1]) == 3 + + +def test_gh_21022(): + # testing for absence of reported error + source = np.ma.masked_array(data=[-1, -1], mask=True, dtype=np.float64) + axis = np.array(0) + result = np.prod(source, axis=axis, keepdims=False) + result = np.ma.masked_array(result, + mask=np.ones(result.shape, dtype=np.bool)) + array = np.ma.masked_array(data=-1, mask=True, dtype=np.float64) + copy.deepcopy(array) + copy.deepcopy(result) + + +def test_deepcopy_2d_obj(): + source = np.ma.array([[0, "dog"], + [1, 1], + [[1, 2], "cat"]], + mask=[[0, 1], + [0, 0], + [0, 0]], + dtype=object) + deepcopy = copy.deepcopy(source) + deepcopy[2, 0].extend(['this should not appear in source', 3]) + assert len(source[2, 0]) == 2 + assert len(deepcopy[2, 0]) == 4 + assert_equal(deepcopy._mask, source._mask) + deepcopy._mask[0, 0] = 1 + assert source._mask[0, 0] == 0 + + +def test_deepcopy_0d_obj(): + source = np.ma.array(0, mask=[0], dtype=object) + deepcopy = copy.deepcopy(source) + deepcopy[...] = 17 + assert_equal(source, 0) + assert_equal(deepcopy, 17) + + +def test_uint_fill_value_and_filled(): + # See also gh-27269 + a = np.ma.MaskedArray([1, 1], [True, False], dtype="uint16") + # the fill value should likely not be 99999, but for now guarantee it: + assert a.fill_value == 999999 + # However, it's type is uint: + assert a.fill_value.dtype.kind == "u" + # And this ensures things like filled work: + np.testing.assert_array_equal( + a.filled(), np.array([999999, 1]).astype("uint16"), strict=True) diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_deprecations.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_deprecations.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc8b9c72bb95217cb53f97c7a5bad43e52c8137 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_deprecations.py @@ -0,0 +1,87 @@ +"""Test deprecation and future warnings. 
+ +""" +import io +import textwrap + +import pytest + +import numpy as np +from numpy.ma.core import MaskedArrayFutureWarning +from numpy.ma.testutils import assert_equal +from numpy.testing import assert_warns + + +class TestArgsort: + """ gh-8701 """ + def _test_base(self, argsort, cls): + arr_0d = np.array(1).view(cls) + argsort(arr_0d) + + arr_1d = np.array([1, 2, 3]).view(cls) + argsort(arr_1d) + + # argsort has a bad default for >1d arrays + arr_2d = np.array([[1, 2], [3, 4]]).view(cls) + result = assert_warns( + np.ma.core.MaskedArrayFutureWarning, argsort, arr_2d) + assert_equal(result, argsort(arr_2d, axis=None)) + + # should be no warnings for explicitly specifying it + argsort(arr_2d, axis=None) + argsort(arr_2d, axis=-1) + + def test_function_ndarray(self): + return self._test_base(np.ma.argsort, np.ndarray) + + def test_function_maskedarray(self): + return self._test_base(np.ma.argsort, np.ma.MaskedArray) + + def test_method(self): + return self._test_base(np.ma.MaskedArray.argsort, np.ma.MaskedArray) + + +class TestMinimumMaximum: + + def test_axis_default(self): + # NumPy 1.13, 2017-05-06 + + data1d = np.ma.arange(6) + data2d = data1d.reshape(2, 3) + + ma_min = np.ma.minimum.reduce + ma_max = np.ma.maximum.reduce + + # check that the default axis is still None, but warns on 2d arrays + result = assert_warns(MaskedArrayFutureWarning, ma_max, data2d) + assert_equal(result, ma_max(data2d, axis=None)) + + result = assert_warns(MaskedArrayFutureWarning, ma_min, data2d) + assert_equal(result, ma_min(data2d, axis=None)) + + # no warnings on 1d, as both new and old defaults are equivalent + result = ma_min(data1d) + assert_equal(result, ma_min(data1d, axis=None)) + assert_equal(result, ma_min(data1d, axis=0)) + + result = ma_max(data1d) + assert_equal(result, ma_max(data1d, axis=None)) + assert_equal(result, ma_max(data1d, axis=0)) + + +class TestFromtextfile: + def test_fromtextfile_delimitor(self): + # NumPy 1.22.0, 2021-09-23 + + textfile = io.StringIO(textwrap.dedent( + """ + A,B,C,D + 'string 1';1;1.0;'mixed column' + 'string 2';2;2.0; + 'string 3';3;3.0;123 + 'string 4';4;4.0;3.14 + """ + )) + + with pytest.warns(DeprecationWarning): + result = np.ma.mrecords.fromtextfile(textfile, delimitor=';') diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_extras.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_extras.py new file mode 100644 index 0000000000000000000000000000000000000000..3d10e839cbc959058a6bbc81a0483f1ac5762043 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_extras.py @@ -0,0 +1,1998 @@ +"""Tests suite for MaskedArray. 
+Adapted from the original test_ma by Pierre Gerard-Marchant + +:author: Pierre Gerard-Marchant +:contact: pierregm_at_uga_dot_edu + +""" +import itertools +import warnings + +import pytest + +import numpy as np +from numpy._core.numeric import normalize_axis_tuple +from numpy.ma.core import ( + MaskedArray, + arange, + array, + count, + getmaskarray, + masked, + masked_array, + nomask, + ones, + shape, + zeros, +) +from numpy.ma.extras import ( + _covhelper, + apply_along_axis, + apply_over_axes, + atleast_1d, + atleast_2d, + atleast_3d, + average, + clump_masked, + clump_unmasked, + compress_nd, + compress_rowcols, + corrcoef, + cov, + diagflat, + dot, + ediff1d, + flatnotmasked_contiguous, + in1d, + intersect1d, + isin, + mask_rowcols, + masked_all, + masked_all_like, + median, + mr_, + ndenumerate, + notmasked_contiguous, + notmasked_edges, + polyfit, + setdiff1d, + setxor1d, + stack, + union1d, + unique, + vstack, +) +from numpy.ma.testutils import ( + assert_, + assert_almost_equal, + assert_array_equal, + assert_equal, +) +from numpy.testing import assert_warns, suppress_warnings + + +class TestGeneric: + # + def test_masked_all(self): + # Tests masked_all + # Standard dtype + test = masked_all((2,), dtype=float) + control = array([1, 1], mask=[1, 1], dtype=float) + assert_equal(test, control) + # Flexible dtype + dt = np.dtype({'names': ['a', 'b'], 'formats': ['f', 'f']}) + test = masked_all((2,), dtype=dt) + control = array([(0, 0), (0, 0)], mask=[(1, 1), (1, 1)], dtype=dt) + assert_equal(test, control) + test = masked_all((2, 2), dtype=dt) + control = array([[(0, 0), (0, 0)], [(0, 0), (0, 0)]], + mask=[[(1, 1), (1, 1)], [(1, 1), (1, 1)]], + dtype=dt) + assert_equal(test, control) + # Nested dtype + dt = np.dtype([('a', 'f'), ('b', [('ba', 'f'), ('bb', 'f')])]) + test = masked_all((2,), dtype=dt) + control = array([(1, (1, 1)), (1, (1, 1))], + mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt) + assert_equal(test, control) + test = masked_all((2,), dtype=dt) + control = array([(1, (1, 1)), (1, (1, 1))], + mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt) + assert_equal(test, control) + test = masked_all((1, 1), dtype=dt) + control = array([[(1, (1, 1))]], mask=[[(1, (1, 1))]], dtype=dt) + assert_equal(test, control) + + def test_masked_all_with_object_nested(self): + # Test masked_all works with nested array with dtype of an 'object' + # refers to issue #15895 + my_dtype = np.dtype([('b', ([('c', object)], (1,)))]) + masked_arr = np.ma.masked_all((1,), my_dtype) + + assert_equal(type(masked_arr['b']), np.ma.core.MaskedArray) + assert_equal(type(masked_arr['b']['c']), np.ma.core.MaskedArray) + assert_equal(len(masked_arr['b']['c']), 1) + assert_equal(masked_arr['b']['c'].shape, (1, 1)) + assert_equal(masked_arr['b']['c']._fill_value.shape, ()) + + def test_masked_all_with_object(self): + # same as above except that the array is not nested + my_dtype = np.dtype([('b', (object, (1,)))]) + masked_arr = np.ma.masked_all((1,), my_dtype) + + assert_equal(type(masked_arr['b']), np.ma.core.MaskedArray) + assert_equal(len(masked_arr['b']), 1) + assert_equal(masked_arr['b'].shape, (1, 1)) + assert_equal(masked_arr['b']._fill_value.shape, ()) + + def test_masked_all_like(self): + # Tests masked_all + # Standard dtype + base = array([1, 2], dtype=float) + test = masked_all_like(base) + control = array([1, 1], mask=[1, 1], dtype=float) + assert_equal(test, control) + # Flexible dtype + dt = np.dtype({'names': ['a', 'b'], 'formats': ['f', 'f']}) + base = array([(0, 0), (0, 0)], mask=[(1, 1), (1, 1)], dtype=dt) 
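+        # Editor-added aside (illustrative sketch, not part of the upstream
+        # suite): masked_all_like takes only shape and dtype from its
+        # prototype and masks every element.
+        proto = np.zeros((2, 3))
+        assert_(masked_all_like(proto).mask.all())
+        assert_equal(masked_all_like(proto).shape, (2, 3))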
+ test = masked_all_like(base) + control = array([(10, 10), (10, 10)], mask=[(1, 1), (1, 1)], dtype=dt) + assert_equal(test, control) + # Nested dtype + dt = np.dtype([('a', 'f'), ('b', [('ba', 'f'), ('bb', 'f')])]) + control = array([(1, (1, 1)), (1, (1, 1))], + mask=[(1, (1, 1)), (1, (1, 1))], dtype=dt) + test = masked_all_like(control) + assert_equal(test, control) + + def check_clump(self, f): + for i in range(1, 7): + for j in range(2**i): + k = np.arange(i, dtype=int) + ja = np.full(i, j, dtype=int) + a = masked_array(2**k) + a.mask = (ja & (2**k)) != 0 + s = 0 + for sl in f(a): + s += a.data[sl].sum() + if f == clump_unmasked: + assert_equal(a.compressed().sum(), s) + else: + a.mask = ~a.mask + assert_equal(a.compressed().sum(), s) + + def test_clump_masked(self): + # Test clump_masked + a = masked_array(np.arange(10)) + a[[0, 1, 2, 6, 8, 9]] = masked + # + test = clump_masked(a) + control = [slice(0, 3), slice(6, 7), slice(8, 10)] + assert_equal(test, control) + + self.check_clump(clump_masked) + + def test_clump_unmasked(self): + # Test clump_unmasked + a = masked_array(np.arange(10)) + a[[0, 1, 2, 6, 8, 9]] = masked + test = clump_unmasked(a) + control = [slice(3, 6), slice(7, 8), ] + assert_equal(test, control) + + self.check_clump(clump_unmasked) + + def test_flatnotmasked_contiguous(self): + # Test flatnotmasked_contiguous + a = arange(10) + # No mask + test = flatnotmasked_contiguous(a) + assert_equal(test, [slice(0, a.size)]) + # mask of all false + a.mask = np.zeros(10, dtype=bool) + assert_equal(test, [slice(0, a.size)]) + # Some mask + a[(a < 3) | (a > 8) | (a == 5)] = masked + test = flatnotmasked_contiguous(a) + assert_equal(test, [slice(3, 5), slice(6, 9)]) + # + a[:] = masked + test = flatnotmasked_contiguous(a) + assert_equal(test, []) + + +class TestAverage: + # Several tests of average. Why so many ? Good point... + def test_testAverage1(self): + # Test of average. + ott = array([0., 1., 2., 3.], mask=[True, False, False, False]) + assert_equal(2.0, average(ott, axis=0)) + assert_equal(2.0, average(ott, weights=[1., 1., 2., 1.])) + result, wts = average(ott, weights=[1., 1., 2., 1.], returned=True) + assert_equal(2.0, result) + assert_(wts == 4.0) + ott[:] = masked + assert_equal(average(ott, axis=0).mask, [True]) + ott = array([0., 1., 2., 3.], mask=[True, False, False, False]) + ott = ott.reshape(2, 2) + ott[:, 1] = masked + assert_equal(average(ott, axis=0), [2.0, 0.0]) + assert_equal(average(ott, axis=1).mask[0], [True]) + assert_equal([2., 0.], average(ott, axis=0)) + result, wts = average(ott, axis=0, returned=True) + assert_equal(wts, [1., 0.]) + + def test_testAverage2(self): + # More tests of average. + w1 = [0, 1, 1, 1, 1, 0] + w2 = [[0, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 1]] + x = arange(6, dtype=np.float64) + assert_equal(average(x, axis=0), 2.5) + assert_equal(average(x, axis=0, weights=w1), 2.5) + y = array([arange(6, dtype=np.float64), 2.0 * arange(6)]) + assert_equal(average(y, None), np.add.reduce(np.arange(6)) * 3. / 12.) + assert_equal(average(y, axis=0), np.arange(6) * 3. / 2.) + assert_equal(average(y, axis=1), + [average(x, axis=0), average(x, axis=0) * 2.0]) + assert_equal(average(y, None, weights=w2), 20. / 6.) 
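+        # Editor-added illustration (hedged, not upstream): a masked element
+        # drops out of the weighted sum entirely, whatever weight it carries.
+        demo = masked_array([1., 2.], mask=[0, 1])
+        assert_equal(average(demo, weights=[1., 99.]), 1.0)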
+ assert_equal(average(y, axis=0, weights=w2), + [0., 1., 2., 3., 4., 10.]) + assert_equal(average(y, axis=1), + [average(x, axis=0), average(x, axis=0) * 2.0]) + m1 = zeros(6) + m2 = [0, 0, 1, 1, 0, 0] + m3 = [[0, 0, 1, 1, 0, 0], [0, 1, 1, 1, 1, 0]] + m4 = ones(6) + m5 = [0, 1, 1, 1, 1, 1] + assert_equal(average(masked_array(x, m1), axis=0), 2.5) + assert_equal(average(masked_array(x, m2), axis=0), 2.5) + assert_equal(average(masked_array(x, m4), axis=0).mask, [True]) + assert_equal(average(masked_array(x, m5), axis=0), 0.0) + assert_equal(count(average(masked_array(x, m4), axis=0)), 0) + z = masked_array(y, m3) + assert_equal(average(z, None), 20. / 6.) + assert_equal(average(z, axis=0), [0., 1., 99., 99., 4.0, 7.5]) + assert_equal(average(z, axis=1), [2.5, 5.0]) + assert_equal(average(z, axis=0, weights=w2), + [0., 1., 99., 99., 4.0, 10.0]) + + def test_testAverage3(self): + # Yet more tests of average! + a = arange(6) + b = arange(6) * 3 + r1, w1 = average([[a, b], [b, a]], axis=1, returned=True) + assert_equal(shape(r1), shape(w1)) + assert_equal(r1.shape, w1.shape) + r2, w2 = average(ones((2, 2, 3)), axis=0, weights=[3, 1], returned=True) + assert_equal(shape(w2), shape(r2)) + r2, w2 = average(ones((2, 2, 3)), returned=True) + assert_equal(shape(w2), shape(r2)) + r2, w2 = average(ones((2, 2, 3)), weights=ones((2, 2, 3)), returned=True) + assert_equal(shape(w2), shape(r2)) + a2d = array([[1, 2], [0, 4]], float) + a2dm = masked_array(a2d, [[False, False], [True, False]]) + a2da = average(a2d, axis=0) + assert_equal(a2da, [0.5, 3.0]) + a2dma = average(a2dm, axis=0) + assert_equal(a2dma, [1.0, 3.0]) + a2dma = average(a2dm, axis=None) + assert_equal(a2dma, 7. / 3.) + a2dma = average(a2dm, axis=1) + assert_equal(a2dma, [1.5, 4.0]) + + def test_testAverage4(self): + # Test that `keepdims` works with average + x = np.array([2, 3, 4]).reshape(3, 1) + b = np.ma.array(x, mask=[[False], [False], [True]]) + w = np.array([4, 5, 6]).reshape(3, 1) + actual = average(b, weights=w, axis=1, keepdims=True) + desired = masked_array([[2.], [3.], [4.]], [[False], [False], [True]]) + assert_equal(actual, desired) + + def test_weight_and_input_dims_different(self): + # this test mirrors a test for np.average() + # in lib/test/test_function_base.py + y = np.arange(12).reshape(2, 2, 3) + w = np.array([0., 0., 1., .5, .5, 0., 0., .5, .5, 1., 0., 0.])\ + .reshape(2, 2, 3) + + m = np.full((2, 2, 3), False) + yma = np.ma.array(y, mask=m) + subw0 = w[:, :, 0] + + actual = average(yma, axis=(0, 1), weights=subw0) + desired = masked_array([7., 8., 9.], mask=[False, False, False]) + assert_almost_equal(actual, desired) + + m = np.full((2, 2, 3), False) + m[:, :, 0] = True + m[0, 0, 1] = True + yma = np.ma.array(y, mask=m) + actual = average(yma, axis=(0, 1), weights=subw0) + desired = masked_array( + [np.nan, 8., 9.], + mask=[True, False, False]) + assert_almost_equal(actual, desired) + + m = np.full((2, 2, 3), False) + yma = np.ma.array(y, mask=m) + + subw1 = w[1, :, :] + actual = average(yma, axis=(1, 2), weights=subw1) + desired = masked_array([2.25, 8.25], mask=[False, False]) + assert_almost_equal(actual, desired) + + # here the weights have the wrong shape for the specified axes + with pytest.raises( + ValueError, + match="Shape of weights must be consistent with " + "shape of a along specified axis"): + average(yma, axis=(0, 1, 2), weights=subw0) + + with pytest.raises( + ValueError, + match="Shape of weights must be consistent with " + "shape of a along specified axis"): + average(yma, axis=(0, 1), 
weights=subw1) + + # swapping the axes should be same as transposing weights + actual = average(yma, axis=(1, 0), weights=subw0) + desired = average(yma, axis=(0, 1), weights=subw0.T) + assert_almost_equal(actual, desired) + + def test_onintegers_with_mask(self): + # Test average on integers with mask + a = average(array([1, 2])) + assert_equal(a, 1.5) + a = average(array([1, 2, 3, 4], mask=[False, False, True, True])) + assert_equal(a, 1.5) + + def test_complex(self): + # Test with complex data. + # (Regression test for https://github.com/numpy/numpy/issues/2684) + mask = np.array([[0, 0, 0, 1, 0], + [0, 1, 0, 0, 0]], dtype=bool) + a = masked_array([[0, 1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j], + [9j, 0 + 1j, 2 + 3j, 4 + 5j, 7 + 7j]], + mask=mask) + + av = average(a) + expected = np.average(a.compressed()) + assert_almost_equal(av.real, expected.real) + assert_almost_equal(av.imag, expected.imag) + + av0 = average(a, axis=0) + expected0 = average(a.real, axis=0) + average(a.imag, axis=0) * 1j + assert_almost_equal(av0.real, expected0.real) + assert_almost_equal(av0.imag, expected0.imag) + + av1 = average(a, axis=1) + expected1 = average(a.real, axis=1) + average(a.imag, axis=1) * 1j + assert_almost_equal(av1.real, expected1.real) + assert_almost_equal(av1.imag, expected1.imag) + + # Test with the 'weights' argument. + wts = np.array([[0.5, 1.0, 2.0, 1.0, 0.5], + [1.0, 1.0, 1.0, 1.0, 1.0]]) + wav = average(a, weights=wts) + expected = np.average(a.compressed(), weights=wts[~mask]) + assert_almost_equal(wav.real, expected.real) + assert_almost_equal(wav.imag, expected.imag) + + wav0 = average(a, weights=wts, axis=0) + expected0 = (average(a.real, weights=wts, axis=0) + + average(a.imag, weights=wts, axis=0) * 1j) + assert_almost_equal(wav0.real, expected0.real) + assert_almost_equal(wav0.imag, expected0.imag) + + wav1 = average(a, weights=wts, axis=1) + expected1 = (average(a.real, weights=wts, axis=1) + + average(a.imag, weights=wts, axis=1) * 1j) + assert_almost_equal(wav1.real, expected1.real) + assert_almost_equal(wav1.imag, expected1.imag) + + @pytest.mark.parametrize( + 'x, axis, expected_avg, weights, expected_wavg, expected_wsum', + [([1, 2, 3], None, [2.0], [3, 4, 1], [1.75], [8.0]), + ([[1, 2, 5], [1, 6, 11]], 0, [[1.0, 4.0, 8.0]], + [1, 3], [[1.0, 5.0, 9.5]], [[4, 4, 4]])], + ) + def test_basic_keepdims(self, x, axis, expected_avg, + weights, expected_wavg, expected_wsum): + avg = np.ma.average(x, axis=axis, keepdims=True) + assert avg.shape == np.shape(expected_avg) + assert_array_equal(avg, expected_avg) + + wavg = np.ma.average(x, axis=axis, weights=weights, keepdims=True) + assert wavg.shape == np.shape(expected_wavg) + assert_array_equal(wavg, expected_wavg) + + wavg, wsum = np.ma.average(x, axis=axis, weights=weights, + returned=True, keepdims=True) + assert wavg.shape == np.shape(expected_wavg) + assert_array_equal(wavg, expected_wavg) + assert wsum.shape == np.shape(expected_wsum) + assert_array_equal(wsum, expected_wsum) + + def test_masked_weights(self): + # Test with masked weights. 
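+        # Editor's note: a weight that is itself masked removes its sample
+        # from both numerator and denominator, exactly like a masked datum.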
+ # (Regression test for https://github.com/numpy/numpy/issues/10438) + a = np.ma.array(np.arange(9).reshape(3, 3), + mask=[[1, 0, 0], [1, 0, 0], [0, 0, 0]]) + weights_unmasked = masked_array([5, 28, 31], mask=False) + weights_masked = masked_array([5, 28, 31], mask=[1, 0, 0]) + + avg_unmasked = average(a, axis=0, + weights=weights_unmasked, returned=False) + expected_unmasked = np.array([6.0, 5.21875, 6.21875]) + assert_almost_equal(avg_unmasked, expected_unmasked) + + avg_masked = average(a, axis=0, weights=weights_masked, returned=False) + expected_masked = np.array([6.0, 5.576271186440678, 6.576271186440678]) + assert_almost_equal(avg_masked, expected_masked) + + # weights should be masked if needed + # depending on the array mask. This is to avoid summing + # masked nan or other values that are not cancelled by a zero + a = np.ma.array([1.0, 2.0, 3.0, 4.0], + mask=[False, False, True, True]) + avg_unmasked = average(a, weights=[1, 1, 1, np.nan]) + + assert_almost_equal(avg_unmasked, 1.5) + + a = np.ma.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 1.0, 2.0, 3.0], + ], mask=[ + [False, True, True, False], + [True, False, True, True], + [True, False, True, False], + ]) + + avg_masked = np.ma.average(a, weights=[1, np.nan, 1], axis=0) + avg_expected = np.ma.array([1.0, np.nan, np.nan, 3.5], + mask=[False, True, True, False]) + + assert_almost_equal(avg_masked, avg_expected) + assert_equal(avg_masked.mask, avg_expected.mask) + + +class TestConcatenator: + # Tests for mr_, the equivalent of r_ for masked arrays. + + def test_1d(self): + # Tests mr_ on 1D arrays. + assert_array_equal(mr_[1, 2, 3, 4, 5, 6], array([1, 2, 3, 4, 5, 6])) + b = ones(5) + m = [1, 0, 0, 0, 0] + d = masked_array(b, mask=m) + c = mr_[d, 0, 0, d] + assert_(isinstance(c, MaskedArray)) + assert_array_equal(c, [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1]) + assert_array_equal(c.mask, mr_[m, 0, 0, m]) + + def test_2d(self): + # Tests mr_ on 2D arrays. + a_1 = np.random.rand(5, 5) + a_2 = np.random.rand(5, 5) + m_1 = np.round(np.random.rand(5, 5), 0) + m_2 = np.round(np.random.rand(5, 5), 0) + b_1 = masked_array(a_1, mask=m_1) + b_2 = masked_array(a_2, mask=m_2) + # append columns + d = mr_['1', b_1, b_2] + assert_(d.shape == (5, 10)) + assert_array_equal(d[:, :5], b_1) + assert_array_equal(d[:, 5:], b_2) + assert_array_equal(d.mask, np.r_['1', m_1, m_2]) + d = mr_[b_1, b_2] + assert_(d.shape == (10, 5)) + assert_array_equal(d[:5, :], b_1) + assert_array_equal(d[5:, :], b_2) + assert_array_equal(d.mask, np.r_[m_1, m_2]) + + def test_masked_constant(self): + actual = mr_[np.ma.masked, 1] + assert_equal(actual.mask, [True, False]) + assert_equal(actual.data[1], 1) + + actual = mr_[[1, 2], np.ma.masked] + assert_equal(actual.mask, [False, False, True]) + assert_equal(actual.data[:2], [1, 2]) + + +class TestNotMasked: + # Tests notmasked_edges and notmasked_contiguous. 
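+    # Editor-added sketch (illustrative, not upstream): for a 1-D masked
+    # array the two helpers report the first/last unmasked indices and the
+    # runs of consecutive unmasked data, respectively:
+    #   >>> m = np.ma.array([0, 1, 2, 3], mask=[1, 0, 0, 1])
+    #   >>> np.ma.notmasked_edges(m)
+    #   array([1, 2])
+    #   >>> np.ma.notmasked_contiguous(m)
+    #   [slice(1, 3, None)]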
+ + def test_edges(self): + # Tests unmasked_edges + data = masked_array(np.arange(25).reshape(5, 5), + mask=[[0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], + [1, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 0, 0]],) + test = notmasked_edges(data, None) + assert_equal(test, [0, 24]) + test = notmasked_edges(data, 0) + assert_equal(test[0], [(0, 0, 1, 0, 0), (0, 1, 2, 3, 4)]) + assert_equal(test[1], [(3, 3, 3, 4, 4), (0, 1, 2, 3, 4)]) + test = notmasked_edges(data, 1) + assert_equal(test[0], [(0, 1, 2, 3, 4), (0, 0, 2, 0, 3)]) + assert_equal(test[1], [(0, 1, 2, 3, 4), (4, 2, 4, 4, 4)]) + # + test = notmasked_edges(data.data, None) + assert_equal(test, [0, 24]) + test = notmasked_edges(data.data, 0) + assert_equal(test[0], [(0, 0, 0, 0, 0), (0, 1, 2, 3, 4)]) + assert_equal(test[1], [(4, 4, 4, 4, 4), (0, 1, 2, 3, 4)]) + test = notmasked_edges(data.data, -1) + assert_equal(test[0], [(0, 1, 2, 3, 4), (0, 0, 0, 0, 0)]) + assert_equal(test[1], [(0, 1, 2, 3, 4), (4, 4, 4, 4, 4)]) + # + data[-2] = masked + test = notmasked_edges(data, 0) + assert_equal(test[0], [(0, 0, 1, 0, 0), (0, 1, 2, 3, 4)]) + assert_equal(test[1], [(1, 1, 2, 4, 4), (0, 1, 2, 3, 4)]) + test = notmasked_edges(data, -1) + assert_equal(test[0], [(0, 1, 2, 4), (0, 0, 2, 3)]) + assert_equal(test[1], [(0, 1, 2, 4), (4, 2, 4, 4)]) + + def test_contiguous(self): + # Tests notmasked_contiguous + a = masked_array(np.arange(24).reshape(3, 8), + mask=[[0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 1, 0]]) + tmp = notmasked_contiguous(a, None) + assert_equal(tmp, [ + slice(0, 4, None), + slice(16, 22, None), + slice(23, 24, None) + ]) + + tmp = notmasked_contiguous(a, 0) + assert_equal(tmp, [ + [slice(0, 1, None), slice(2, 3, None)], + [slice(0, 1, None), slice(2, 3, None)], + [slice(0, 1, None), slice(2, 3, None)], + [slice(0, 1, None), slice(2, 3, None)], + [slice(2, 3, None)], + [slice(2, 3, None)], + [], + [slice(2, 3, None)] + ]) + # + tmp = notmasked_contiguous(a, 1) + assert_equal(tmp, [ + [slice(0, 4, None)], + [], + [slice(0, 6, None), slice(7, 8, None)] + ]) + + +class TestCompressFunctions: + + def test_compress_nd(self): + # Tests compress_nd + x = np.array(list(range(3 * 4 * 5))).reshape(3, 4, 5) + m = np.zeros((3, 4, 5)).astype(bool) + m[1, 1, 1] = True + x = array(x, mask=m) + + # axis=None + a = compress_nd(x) + assert_equal(a, [[[ 0, 2, 3, 4], + [10, 12, 13, 14], + [15, 17, 18, 19]], + [[40, 42, 43, 44], + [50, 52, 53, 54], + [55, 57, 58, 59]]]) + + # axis=0 + a = compress_nd(x, 0) + assert_equal(a, [[[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]]) + + # axis=1 + a = compress_nd(x, 1) + assert_equal(a, [[[ 0, 1, 2, 3, 4], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[20, 21, 22, 23, 24], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39]], + [[40, 41, 42, 43, 44], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]]) + + a2 = compress_nd(x, (1,)) + a3 = compress_nd(x, -2) + a4 = compress_nd(x, (-2,)) + assert_equal(a, a2) + assert_equal(a, a3) + assert_equal(a, a4) + + # axis=2 + a = compress_nd(x, 2) + assert_equal(a, [[[ 0, 2, 3, 4], + [ 5, 7, 8, 9], + [10, 12, 13, 14], + [15, 17, 18, 19]], + [[20, 22, 23, 24], + [25, 27, 28, 29], + [30, 32, 33, 34], + [35, 37, 38, 39]], + [[40, 42, 43, 44], + [45, 47, 48, 49], + [50, 52, 53, 54], + [55, 57, 58, 59]]]) + + a2 = compress_nd(x, (2,)) + a3 = compress_nd(x, -1) + a4 = compress_nd(x, (-1,)) + assert_equal(a, a2) + 
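+        # Editor's note: compress_nd drops every slice along the requested
+        # axes that contains at least one masked value; negative axes are
+        # normalized first, so (2,), -1 and (-1,) are all equivalent here.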
assert_equal(a, a3) + assert_equal(a, a4) + + # axis=(0, 1) + a = compress_nd(x, (0, 1)) + assert_equal(a, [[[ 0, 1, 2, 3, 4], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[40, 41, 42, 43, 44], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]]) + a2 = compress_nd(x, (0, -2)) + assert_equal(a, a2) + + # axis=(1, 2) + a = compress_nd(x, (1, 2)) + assert_equal(a, [[[ 0, 2, 3, 4], + [10, 12, 13, 14], + [15, 17, 18, 19]], + [[20, 22, 23, 24], + [30, 32, 33, 34], + [35, 37, 38, 39]], + [[40, 42, 43, 44], + [50, 52, 53, 54], + [55, 57, 58, 59]]]) + + a2 = compress_nd(x, (-2, 2)) + a3 = compress_nd(x, (1, -1)) + a4 = compress_nd(x, (-2, -1)) + assert_equal(a, a2) + assert_equal(a, a3) + assert_equal(a, a4) + + # axis=(0, 2) + a = compress_nd(x, (0, 2)) + assert_equal(a, [[[ 0, 2, 3, 4], + [ 5, 7, 8, 9], + [10, 12, 13, 14], + [15, 17, 18, 19]], + [[40, 42, 43, 44], + [45, 47, 48, 49], + [50, 52, 53, 54], + [55, 57, 58, 59]]]) + + a2 = compress_nd(x, (0, -1)) + assert_equal(a, a2) + + def test_compress_rowcols(self): + # Tests compress_rowcols + x = array(np.arange(9).reshape(3, 3), + mask=[[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert_equal(compress_rowcols(x), [[4, 5], [7, 8]]) + assert_equal(compress_rowcols(x, 0), [[3, 4, 5], [6, 7, 8]]) + assert_equal(compress_rowcols(x, 1), [[1, 2], [4, 5], [7, 8]]) + x = array(x._data, mask=[[0, 0, 0], [0, 1, 0], [0, 0, 0]]) + assert_equal(compress_rowcols(x), [[0, 2], [6, 8]]) + assert_equal(compress_rowcols(x, 0), [[0, 1, 2], [6, 7, 8]]) + assert_equal(compress_rowcols(x, 1), [[0, 2], [3, 5], [6, 8]]) + x = array(x._data, mask=[[1, 0, 0], [0, 1, 0], [0, 0, 0]]) + assert_equal(compress_rowcols(x), [[8]]) + assert_equal(compress_rowcols(x, 0), [[6, 7, 8]]) + assert_equal(compress_rowcols(x, 1,), [[2], [5], [8]]) + x = array(x._data, mask=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert_equal(compress_rowcols(x).size, 0) + assert_equal(compress_rowcols(x, 0).size, 0) + assert_equal(compress_rowcols(x, 1).size, 0) + + def test_mask_rowcols(self): + # Tests mask_rowcols. 
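+        # Editor-added sketch (illustrative, not upstream): a single masked
+        # entry at (0, 0) poisons all of row 0 and column 0 when axis=None.
+        demo = array(np.zeros((2, 2)), mask=[[1, 0], [0, 0]])
+        assert_equal(mask_rowcols(demo).mask, [[1, 1], [1, 0]])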
+ x = array(np.arange(9).reshape(3, 3), + mask=[[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert_equal(mask_rowcols(x).mask, + [[1, 1, 1], [1, 0, 0], [1, 0, 0]]) + assert_equal(mask_rowcols(x, 0).mask, + [[1, 1, 1], [0, 0, 0], [0, 0, 0]]) + assert_equal(mask_rowcols(x, 1).mask, + [[1, 0, 0], [1, 0, 0], [1, 0, 0]]) + x = array(x._data, mask=[[0, 0, 0], [0, 1, 0], [0, 0, 0]]) + assert_equal(mask_rowcols(x).mask, + [[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + assert_equal(mask_rowcols(x, 0).mask, + [[0, 0, 0], [1, 1, 1], [0, 0, 0]]) + assert_equal(mask_rowcols(x, 1).mask, + [[0, 1, 0], [0, 1, 0], [0, 1, 0]]) + x = array(x._data, mask=[[1, 0, 0], [0, 1, 0], [0, 0, 0]]) + assert_equal(mask_rowcols(x).mask, + [[1, 1, 1], [1, 1, 1], [1, 1, 0]]) + assert_equal(mask_rowcols(x, 0).mask, + [[1, 1, 1], [1, 1, 1], [0, 0, 0]]) + assert_equal(mask_rowcols(x, 1,).mask, + [[1, 1, 0], [1, 1, 0], [1, 1, 0]]) + x = array(x._data, mask=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert_(mask_rowcols(x).all() is masked) + assert_(mask_rowcols(x, 0).all() is masked) + assert_(mask_rowcols(x, 1).all() is masked) + assert_(mask_rowcols(x).mask.all()) + assert_(mask_rowcols(x, 0).mask.all()) + assert_(mask_rowcols(x, 1).mask.all()) + + @pytest.mark.parametrize("axis", [None, 0, 1]) + @pytest.mark.parametrize(["func", "rowcols_axis"], + [(np.ma.mask_rows, 0), (np.ma.mask_cols, 1)]) + def test_mask_row_cols_axis_deprecation(self, axis, func, rowcols_axis): + # Test deprecation of the axis argument to `mask_rows` and `mask_cols` + x = array(np.arange(9).reshape(3, 3), + mask=[[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + + with assert_warns(DeprecationWarning): + res = func(x, axis=axis) + assert_equal(res, mask_rowcols(x, rowcols_axis)) + + def test_dot(self): + # Tests dot product + n = np.arange(1, 7) + # + m = [1, 0, 0, 0, 0, 0] + a = masked_array(n, mask=m).reshape(2, 3) + b = masked_array(n, mask=m).reshape(3, 2) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[1, 1], [1, 0]]) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[1, 1, 1], [1, 0, 0], [1, 0, 0]]) + c = dot(a, b, strict=False) + assert_equal(c, np.dot(a.filled(0), b.filled(0))) + c = dot(b, a, strict=False) + assert_equal(c, np.dot(b.filled(0), a.filled(0))) + # + m = [0, 0, 0, 0, 0, 1] + a = masked_array(n, mask=m).reshape(2, 3) + b = masked_array(n, mask=m).reshape(3, 2) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[0, 1], [1, 1]]) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[0, 0, 1], [0, 0, 1], [1, 1, 1]]) + c = dot(a, b, strict=False) + assert_equal(c, np.dot(a.filled(0), b.filled(0))) + assert_equal(c, dot(a, b)) + c = dot(b, a, strict=False) + assert_equal(c, np.dot(b.filled(0), a.filled(0))) + # + m = [0, 0, 0, 0, 0, 0] + a = masked_array(n, mask=m).reshape(2, 3) + b = masked_array(n, mask=m).reshape(3, 2) + c = dot(a, b) + assert_equal(c.mask, nomask) + c = dot(b, a) + assert_equal(c.mask, nomask) + # + a = masked_array(n, mask=[1, 0, 0, 0, 0, 0]).reshape(2, 3) + b = masked_array(n, mask=[0, 0, 0, 0, 0, 0]).reshape(3, 2) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[1, 1], [0, 0]]) + c = dot(a, b, strict=False) + assert_equal(c, np.dot(a.filled(0), b.filled(0))) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[1, 0, 0], [1, 0, 0], [1, 0, 0]]) + c = dot(b, a, strict=False) + assert_equal(c, np.dot(b.filled(0), a.filled(0))) + # + a = masked_array(n, mask=[0, 0, 0, 0, 0, 1]).reshape(2, 3) + b = masked_array(n, mask=[0, 0, 0, 0, 0, 0]).reshape(3, 2) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[0, 0], [1, 1]]) + c = dot(a, b) + 
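+        # Editor's note: with strict=False (the default) masked entries act
+        # as zeros, so the product matches np.dot on the filled data, which
+        # is exactly what the next assertion checks.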
assert_equal(c, np.dot(a.filled(0), b.filled(0))) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + c = dot(b, a, strict=False) + assert_equal(c, np.dot(b.filled(0), a.filled(0))) + # + a = masked_array(n, mask=[0, 0, 0, 0, 0, 1]).reshape(2, 3) + b = masked_array(n, mask=[0, 0, 1, 0, 0, 0]).reshape(3, 2) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[1, 0], [1, 1]]) + c = dot(a, b, strict=False) + assert_equal(c, np.dot(a.filled(0), b.filled(0))) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[0, 0, 1], [1, 1, 1], [0, 0, 1]]) + c = dot(b, a, strict=False) + assert_equal(c, np.dot(b.filled(0), a.filled(0))) + # + a = masked_array(np.arange(8).reshape(2, 2, 2), + mask=[[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + b = masked_array(np.arange(8).reshape(2, 2, 2), + mask=[[[0, 0], [0, 0]], [[0, 0], [0, 1]]]) + c = dot(a, b, strict=True) + assert_equal(c.mask, + [[[[1, 1], [1, 1]], [[0, 0], [0, 1]]], + [[[0, 0], [0, 1]], [[0, 0], [0, 1]]]]) + c = dot(a, b, strict=False) + assert_equal(c.mask, + [[[[0, 0], [0, 1]], [[0, 0], [0, 0]]], + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) + c = dot(b, a, strict=True) + assert_equal(c.mask, + [[[[1, 0], [0, 0]], [[1, 0], [0, 0]]], + [[[1, 0], [0, 0]], [[1, 1], [1, 1]]]]) + c = dot(b, a, strict=False) + assert_equal(c.mask, + [[[[0, 0], [0, 0]], [[0, 0], [0, 0]]], + [[[0, 0], [0, 0]], [[1, 0], [0, 0]]]]) + # + a = masked_array(np.arange(8).reshape(2, 2, 2), + mask=[[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + b = 5. + c = dot(a, b, strict=True) + assert_equal(c.mask, [[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + c = dot(a, b, strict=False) + assert_equal(c.mask, [[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + c = dot(b, a, strict=True) + assert_equal(c.mask, [[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + c = dot(b, a, strict=False) + assert_equal(c.mask, [[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + # + a = masked_array(np.arange(8).reshape(2, 2, 2), + mask=[[[1, 0], [0, 0]], [[0, 0], [0, 0]]]) + b = masked_array(np.arange(2), mask=[0, 1]) + c = dot(a, b, strict=True) + assert_equal(c.mask, [[1, 1], [1, 1]]) + c = dot(a, b, strict=False) + assert_equal(c.mask, [[1, 0], [0, 0]]) + + def test_dot_returns_maskedarray(self): + # See gh-6611 + a = np.eye(3) + b = array(a) + assert_(type(dot(a, a)) is MaskedArray) + assert_(type(dot(a, b)) is MaskedArray) + assert_(type(dot(b, a)) is MaskedArray) + assert_(type(dot(b, b)) is MaskedArray) + + def test_dot_out(self): + a = array(np.eye(3)) + out = array(np.zeros((3, 3))) + res = dot(a, a, out=out) + assert_(res is out) + assert_equal(a, res) + + +class TestApplyAlongAxis: + # Tests 2D functions + def test_3d(self): + a = arange(12.).reshape(2, 2, 3) + + def myfunc(b): + return b[1] + + xa = apply_along_axis(myfunc, 2, a) + assert_equal(xa, [[1, 4], [7, 10]]) + + # Tests kwargs functions + def test_3d_kwargs(self): + a = arange(12).reshape(2, 2, 3) + + def myfunc(b, offset=0): + return b[1 + offset] + + xa = apply_along_axis(myfunc, 2, a, offset=1) + assert_equal(xa, [[2, 5], [8, 11]]) + + +class TestApplyOverAxes: + # Tests apply_over_axes + def test_basic(self): + a = arange(24).reshape(2, 3, 4) + test = apply_over_axes(np.sum, a, [0, 2]) + ctrl = np.array([[[60], [92], [124]]]) + assert_equal(test, ctrl) + a[(a % 2).astype(bool)] = masked + test = apply_over_axes(np.sum, a, [0, 2]) + ctrl = np.array([[[28], [44], [60]]]) + assert_equal(test, ctrl) + + +class TestMedian: + def test_pytype(self): + r = np.ma.median([[np.inf, np.inf], [np.inf, np.inf]], axis=-1) + assert_equal(r, np.inf) + + def 
test_inf(self): + # test that even which computes handles inf / x = masked + r = np.ma.median(np.ma.masked_array([[np.inf, np.inf], + [np.inf, np.inf]]), axis=-1) + assert_equal(r, np.inf) + r = np.ma.median(np.ma.masked_array([[np.inf, np.inf], + [np.inf, np.inf]]), axis=None) + assert_equal(r, np.inf) + # all masked + r = np.ma.median(np.ma.masked_array([[np.inf, np.inf], + [np.inf, np.inf]], mask=True), + axis=-1) + assert_equal(r.mask, True) + r = np.ma.median(np.ma.masked_array([[np.inf, np.inf], + [np.inf, np.inf]], mask=True), + axis=None) + assert_equal(r.mask, True) + + def test_non_masked(self): + x = np.arange(9) + assert_equal(np.ma.median(x), 4.) + assert_(type(np.ma.median(x)) is not MaskedArray) + x = range(8) + assert_equal(np.ma.median(x), 3.5) + assert_(type(np.ma.median(x)) is not MaskedArray) + x = 5 + assert_equal(np.ma.median(x), 5.) + assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = np.arange(9 * 8).reshape(9, 8) + assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0)) + assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1)) + assert_(np.ma.median(x, axis=1) is not MaskedArray) + # float + x = np.arange(9 * 8.).reshape(9, 8) + assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0)) + assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1)) + assert_(np.ma.median(x, axis=1) is not MaskedArray) + + def test_docstring_examples(self): + "test the examples given in the docstring of ma.median" + x = array(np.arange(8), mask=[0] * 4 + [1] * 4) + assert_equal(np.ma.median(x), 1.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + x = array(np.arange(10).reshape(2, 5), mask=[0] * 6 + [1] * 4) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + ma_x = np.ma.median(x, axis=-1, overwrite_input=True) + assert_equal(ma_x, [2., 5.]) + assert_equal(ma_x.shape, (2,), "shape mismatch") + assert_(type(ma_x) is MaskedArray) + + def test_axis_argument_errors(self): + msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s" + for ndmin in range(5): + for mask in [False, True]: + x = array(1, ndmin=ndmin, mask=mask) + + # Valid axis values should not raise exception + args = itertools.product(range(-ndmin, ndmin), [False, True]) + for axis, over in args: + try: + np.ma.median(x, axis=axis, overwrite_input=over) + except Exception: + raise AssertionError(msg % (mask, ndmin, axis, over)) + + # Invalid axis values should raise exception + args = itertools.product([-(ndmin + 1), ndmin], [False, True]) + for axis, over in args: + try: + np.ma.median(x, axis=axis, overwrite_input=over) + except np.exceptions.AxisError: + pass + else: + raise AssertionError(msg % (mask, ndmin, axis, over)) + + def test_masked_0d(self): + # Check values + x = array(1, mask=False) + assert_equal(np.ma.median(x), 1) + x = array(1, mask=True) + assert_equal(np.ma.median(x), np.ma.masked) + + def test_masked_1d(self): + x = array(np.arange(5), mask=True) + assert_equal(np.ma.median(x), np.ma.masked) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is np.ma.core.MaskedConstant) + x = array(np.arange(5), mask=False) + assert_equal(np.ma.median(x), 2.) 
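+        # Editor's note: masked entries are dropped before the median is
+        # taken, so each case below reduces to np.median of the unmasked
+        # data (e.g. mask [0, 1, 0, 0, 0] leaves [0, 2, 3, 4] -> 2.5).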
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + x = array(np.arange(5), mask=[0, 1, 0, 0, 0]) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + x = array(np.arange(5), mask=[0, 1, 1, 1, 1]) + assert_equal(np.ma.median(x), 0.) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = array(np.arange(5), mask=[0, 1, 1, 0, 0]) + assert_equal(np.ma.median(x), 3.) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # float + x = array(np.arange(5.), mask=[0, 1, 1, 0, 0]) + assert_equal(np.ma.median(x), 3.) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = array(np.arange(6), mask=[0, 1, 1, 1, 1, 0]) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # float + x = array(np.arange(6.), mask=[0, 1, 1, 1, 1, 0]) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + + def test_1d_shape_consistency(self): + assert_equal(np.ma.median(array([1, 2, 3], mask=[0, 0, 0])).shape, + np.ma.median(array([1, 2, 3], mask=[0, 1, 0])).shape) + + def test_2d(self): + # Tests median w/ 2D + (n, p) = (101, 30) + x = masked_array(np.linspace(-1., 1., n),) + x[:10] = x[-10:] = masked + z = masked_array(np.empty((n, p), dtype=float)) + z[:, 0] = x[:] + idx = np.arange(len(x)) + for i in range(1, p): + np.random.shuffle(idx) + z[:, i] = x[idx] + assert_equal(median(z[:, 0]), 0) + assert_equal(median(z), 0) + assert_equal(median(z, axis=0), np.zeros(p)) + assert_equal(median(z.T, axis=1), np.zeros(p)) + + def test_2d_waxis(self): + # Tests median w/ 2D arrays and different axis. + x = masked_array(np.arange(30).reshape(10, 3)) + x[:3] = x[-3:] = masked + assert_equal(median(x), 14.5) + assert_(type(np.ma.median(x)) is not MaskedArray) + assert_equal(median(x, axis=0), [13.5, 14.5, 15.5]) + assert_(type(np.ma.median(x, axis=0)) is MaskedArray) + assert_equal(median(x, axis=1), [0, 0, 0, 10, 13, 16, 19, 0, 0, 0]) + assert_(type(np.ma.median(x, axis=1)) is MaskedArray) + assert_equal(median(x, axis=1).mask, [1, 1, 1, 0, 0, 0, 0, 1, 1, 1]) + + def test_3d(self): + # Tests median w/ 3D + x = np.ma.arange(24).reshape(3, 4, 2) + x[x % 3 == 0] = masked + assert_equal(median(x, 0), [[12, 9], [6, 15], [12, 9], [18, 15]]) + x.shape = (4, 3, 2) + assert_equal(median(x, 0), [[99, 10], [11, 99], [13, 14]]) + x = np.ma.arange(24).reshape(4, 3, 2) + x[x % 5 == 0] = masked + assert_equal(median(x, 0), [[12, 10], [8, 9], [16, 17]]) + + def test_neg_axis(self): + x = masked_array(np.arange(30).reshape(10, 3)) + x[:3] = x[-3:] = masked + assert_equal(median(x, axis=-1), median(x, axis=1)) + + def test_out_1d(self): + # integer float even odd + for v in (30, 30., 31, 31.): + x = masked_array(np.arange(v)) + x[:3] = x[-3:] = masked + out = masked_array(np.ones(())) + r = median(x, out=out) + if v == 30: + assert_equal(out, 14.5) + else: + assert_equal(out, 15.) 
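+            # Editor's note: as with np.median, out= receives the result and
+            # the same object is returned, which the identity check relies on.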
+ assert_(r is out) + assert_(type(r) is MaskedArray) + + def test_out(self): + # integer float even odd + for v in (40, 40., 30, 30.): + x = masked_array(np.arange(v).reshape(10, -1)) + x[:3] = x[-3:] = masked + out = masked_array(np.ones(10)) + r = median(x, axis=1, out=out) + if v == 30: + e = masked_array([0.] * 3 + [10, 13, 16, 19] + [0.] * 3, + mask=[True] * 3 + [False] * 4 + [True] * 3) + else: + e = masked_array([0.] * 3 + [13.5, 17.5, 21.5, 25.5] + [0.] * 3, + mask=[True] * 3 + [False] * 4 + [True] * 3) + assert_equal(r, e) + assert_(r is out) + assert_(type(r) is MaskedArray) + + @pytest.mark.parametrize( + argnames='axis', + argvalues=[ + None, + 1, + (1, ), + (0, 1), + (-3, -1), + ] + ) + def test_keepdims_out(self, axis): + mask = np.zeros((3, 5, 7, 11), dtype=bool) + # Randomly set some elements to True: + w = np.random.random((4, 200)) * np.array(mask.shape)[:, None] + w = w.astype(np.intp) + mask[tuple(w)] = np.nan + d = masked_array(np.ones(mask.shape), mask=mask) + if axis is None: + shape_out = (1,) * d.ndim + else: + axis_norm = normalize_axis_tuple(axis, d.ndim) + shape_out = tuple( + 1 if i in axis_norm else d.shape[i] for i in range(d.ndim)) + out = masked_array(np.empty(shape_out)) + result = median(d, axis=axis, keepdims=True, out=out) + assert result is out + assert_equal(result.shape, shape_out) + + def test_single_non_masked_value_on_axis(self): + data = [[1., 0.], + [0., 3.], + [0., 0.]] + masked_arr = np.ma.masked_equal(data, 0) + expected = [1., 3.] + assert_array_equal(np.ma.median(masked_arr, axis=0), + expected) + + def test_nan(self): + for mask in (False, np.zeros(6, dtype=bool)): + dm = np.ma.array([[1, np.nan, 3], [1, 2, 3]]) + dm.mask = mask + + # scalar result + r = np.ma.median(dm, axis=None) + assert_(np.isscalar(r)) + assert_array_equal(r, np.nan) + r = np.ma.median(dm.ravel(), axis=0) + assert_(np.isscalar(r)) + assert_array_equal(r, np.nan) + + r = np.ma.median(dm, axis=0) + assert_equal(type(r), MaskedArray) + assert_array_equal(r, [1, np.nan, 3]) + r = np.ma.median(dm, axis=1) + assert_equal(type(r), MaskedArray) + assert_array_equal(r, [np.nan, 2]) + r = np.ma.median(dm, axis=-1) + assert_equal(type(r), MaskedArray) + assert_array_equal(r, [np.nan, 2]) + + dm = np.ma.array([[1, np.nan, 3], [1, 2, 3]]) + dm[:, 2] = np.ma.masked + assert_array_equal(np.ma.median(dm, axis=None), np.nan) + assert_array_equal(np.ma.median(dm, axis=0), [1, np.nan, 3]) + assert_array_equal(np.ma.median(dm, axis=1), [np.nan, 1.5]) + + def test_out_nan(self): + o = np.ma.masked_array(np.zeros((4,))) + d = np.ma.masked_array(np.ones((3, 4))) + d[2, 1] = np.nan + d[2, 2] = np.ma.masked + assert_equal(np.ma.median(d, 0, out=o), o) + o = np.ma.masked_array(np.zeros((3,))) + assert_equal(np.ma.median(d, 1, out=o), o) + o = np.ma.masked_array(np.zeros(())) + assert_equal(np.ma.median(d, out=o), o) + + def test_nan_behavior(self): + a = np.ma.masked_array(np.arange(24, dtype=float)) + a[::3] = np.ma.masked + a[2] = np.nan + assert_array_equal(np.ma.median(a), np.nan) + assert_array_equal(np.ma.median(a, axis=0), np.nan) + + a = np.ma.masked_array(np.arange(24, dtype=float).reshape(2, 3, 4)) + a.mask = np.arange(a.size) % 2 == 1 + aorig = a.copy() + a[1, 2, 3] = np.nan + a[1, 1, 2] = np.nan + + # no axis + assert_array_equal(np.ma.median(a), np.nan) + assert_(np.isscalar(np.ma.median(a))) + + # axis0 + b = np.ma.median(aorig, axis=0) + b[2, 3] = np.nan + b[1, 2] = np.nan + assert_equal(np.ma.median(a, 0), b) + + # axis1 + b = np.ma.median(aorig, axis=1) + b[1, 3] = np.nan + 
b[1, 2] = np.nan + assert_equal(np.ma.median(a, 1), b) + + # axis02 + b = np.ma.median(aorig, axis=(0, 2)) + b[1] = np.nan + b[2] = np.nan + assert_equal(np.ma.median(a, (0, 2)), b) + + def test_ambigous_fill(self): + # 255 is max value, used as filler for sort + a = np.array([[3, 3, 255], [3, 3, 255]], dtype=np.uint8) + a = np.ma.masked_array(a, mask=a == 3) + assert_array_equal(np.ma.median(a, axis=1), 255) + assert_array_equal(np.ma.median(a, axis=1).mask, False) + assert_array_equal(np.ma.median(a, axis=0), a[0]) + assert_array_equal(np.ma.median(a), 255) + + def test_special(self): + for inf in [np.inf, -np.inf]: + a = np.array([[inf, np.nan], [np.nan, np.nan]]) + a = np.ma.masked_array(a, mask=np.isnan(a)) + assert_equal(np.ma.median(a, axis=0), [inf, np.nan]) + assert_equal(np.ma.median(a, axis=1), [inf, np.nan]) + assert_equal(np.ma.median(a), inf) + + a = np.array([[np.nan, np.nan, inf], [np.nan, np.nan, inf]]) + a = np.ma.masked_array(a, mask=np.isnan(a)) + assert_array_equal(np.ma.median(a, axis=1), inf) + assert_array_equal(np.ma.median(a, axis=1).mask, False) + assert_array_equal(np.ma.median(a, axis=0), a[0]) + assert_array_equal(np.ma.median(a), inf) + + # no mask + a = np.array([[inf, inf], [inf, inf]]) + assert_equal(np.ma.median(a), inf) + assert_equal(np.ma.median(a, axis=0), inf) + assert_equal(np.ma.median(a, axis=1), inf) + + a = np.array([[inf, 7, -inf, -9], + [-10, np.nan, np.nan, 5], + [4, np.nan, np.nan, inf]], + dtype=np.float32) + a = np.ma.masked_array(a, mask=np.isnan(a)) + if inf > 0: + assert_equal(np.ma.median(a, axis=0), [4., 7., -inf, 5.]) + assert_equal(np.ma.median(a), 4.5) + else: + assert_equal(np.ma.median(a, axis=0), [-10., 7., -inf, -9.]) + assert_equal(np.ma.median(a), -2.5) + assert_equal(np.ma.median(a, axis=1), [-1., -2.5, inf]) + + for i in range(10): + for j in range(1, 10): + a = np.array([([np.nan] * i) + ([inf] * j)] * 2) + a = np.ma.masked_array(a, mask=np.isnan(a)) + assert_equal(np.ma.median(a), inf) + assert_equal(np.ma.median(a, axis=1), inf) + assert_equal(np.ma.median(a, axis=0), + ([np.nan] * i) + [inf] * j) + + def test_empty(self): + # empty arrays + a = np.ma.masked_array(np.array([], dtype=float)) + with suppress_warnings() as w: + w.record(RuntimeWarning) + assert_array_equal(np.ma.median(a), np.nan) + assert_(w.log[0].category is RuntimeWarning) + + # multiple dimensions + a = np.ma.masked_array(np.array([], dtype=float, ndmin=3)) + # no axis + with suppress_warnings() as w: + w.record(RuntimeWarning) + warnings.filterwarnings('always', '', RuntimeWarning) + assert_array_equal(np.ma.median(a), np.nan) + assert_(w.log[0].category is RuntimeWarning) + + # axis 0 and 1 + b = np.ma.masked_array(np.array([], dtype=float, ndmin=2)) + assert_equal(np.ma.median(a, axis=0), b) + assert_equal(np.ma.median(a, axis=1), b) + + # axis 2 + b = np.ma.masked_array(np.array(np.nan, dtype=float, ndmin=2)) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.ma.median(a, axis=2), b) + assert_(w[0].category is RuntimeWarning) + + def test_object(self): + o = np.ma.masked_array(np.arange(7.)) + assert_(type(np.ma.median(o.astype(object))), float) + o[2] = np.nan + assert_(type(np.ma.median(o.astype(object))), float) + + +class TestCov: + + def setup_method(self): + self.data = array(np.random.rand(12)) + + def test_covhelper(self): + x = self.data + # Test not mask output type is a float. 
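+        # Editor's note: element [1] of _covhelper's return value is the
+        # valid-data weight matrix (1 where unmasked, 0 where masked), which
+        # the bool casts below rely on.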
+ assert_(_covhelper(x, rowvar=True)[1].dtype, np.float32) + assert_(_covhelper(x, y=x, rowvar=False)[1].dtype, np.float32) + # Test not mask output is equal after casting to float. + mask = x > 0.5 + assert_array_equal( + _covhelper( + np.ma.masked_array(x, mask), rowvar=True + )[1].astype(bool), + ~mask.reshape(1, -1), + ) + assert_array_equal( + _covhelper( + np.ma.masked_array(x, mask), y=x, rowvar=False + )[1].astype(bool), + np.vstack((~mask, ~mask)), + ) + + def test_1d_without_missing(self): + # Test cov on 1D variable w/o missing values + x = self.data + assert_almost_equal(np.cov(x), cov(x)) + assert_almost_equal(np.cov(x, rowvar=False), cov(x, rowvar=False)) + assert_almost_equal(np.cov(x, rowvar=False, bias=True), + cov(x, rowvar=False, bias=True)) + + def test_2d_without_missing(self): + # Test cov on 1 2D variable w/o missing values + x = self.data.reshape(3, 4) + assert_almost_equal(np.cov(x), cov(x)) + assert_almost_equal(np.cov(x, rowvar=False), cov(x, rowvar=False)) + assert_almost_equal(np.cov(x, rowvar=False, bias=True), + cov(x, rowvar=False, bias=True)) + + def test_1d_with_missing(self): + # Test cov 1 1D variable w/missing values + x = self.data + x[-1] = masked + x -= x.mean() + nx = x.compressed() + assert_almost_equal(np.cov(nx), cov(x)) + assert_almost_equal(np.cov(nx, rowvar=False), cov(x, rowvar=False)) + assert_almost_equal(np.cov(nx, rowvar=False, bias=True), + cov(x, rowvar=False, bias=True)) + # + try: + cov(x, allow_masked=False) + except ValueError: + pass + # + # 2 1D variables w/ missing values + nx = x[1:-1] + assert_almost_equal(np.cov(nx, nx[::-1]), cov(x, x[::-1])) + assert_almost_equal(np.cov(nx, nx[::-1], rowvar=False), + cov(x, x[::-1], rowvar=False)) + assert_almost_equal(np.cov(nx, nx[::-1], rowvar=False, bias=True), + cov(x, x[::-1], rowvar=False, bias=True)) + + def test_2d_with_missing(self): + # Test cov on 2D variable w/ missing value + x = self.data + x[-1] = masked + x = x.reshape(3, 4) + valid = np.logical_not(getmaskarray(x)).astype(int) + frac = np.dot(valid, valid.T) + xf = (x - x.mean(1)[:, None]).filled(0) + assert_almost_equal(cov(x), + np.cov(xf) * (x.shape[1] - 1) / (frac - 1.)) + assert_almost_equal(cov(x, bias=True), + np.cov(xf, bias=True) * x.shape[1] / frac) + frac = np.dot(valid.T, valid) + xf = (x - x.mean(0)).filled(0) + assert_almost_equal(cov(x, rowvar=False), + (np.cov(xf, rowvar=False) * + (x.shape[0] - 1) / (frac - 1.))) + assert_almost_equal(cov(x, rowvar=False, bias=True), + (np.cov(xf, rowvar=False, bias=True) * + x.shape[0] / frac)) + + +class TestCorrcoef: + + def setup_method(self): + self.data = array(np.random.rand(12)) + self.data2 = array(np.random.rand(12)) + + def test_ddof(self): + # ddof raises DeprecationWarning + x, y = self.data, self.data2 + expected = np.corrcoef(x) + expected2 = np.corrcoef(x, y) + with suppress_warnings() as sup: + warnings.simplefilter("always") + assert_warns(DeprecationWarning, corrcoef, x, ddof=-1) + sup.filter(DeprecationWarning, "bias and ddof have no effect") + # ddof has no or negligible effect on the function + assert_almost_equal(np.corrcoef(x, ddof=0), corrcoef(x, ddof=0)) + assert_almost_equal(corrcoef(x, ddof=-1), expected) + assert_almost_equal(corrcoef(x, y, ddof=-1), expected2) + assert_almost_equal(corrcoef(x, ddof=3), expected) + assert_almost_equal(corrcoef(x, y, ddof=3), expected2) + + def test_bias(self): + x, y = self.data, self.data2 + expected = np.corrcoef(x) + # bias raises DeprecationWarning + with suppress_warnings() as sup: + 
warnings.simplefilter("always") + assert_warns(DeprecationWarning, corrcoef, x, y, True, False) + assert_warns(DeprecationWarning, corrcoef, x, y, True, True) + assert_warns(DeprecationWarning, corrcoef, x, bias=False) + sup.filter(DeprecationWarning, "bias and ddof have no effect") + # bias has no or negligible effect on the function + assert_almost_equal(corrcoef(x, bias=1), expected) + + def test_1d_without_missing(self): + # Test cov on 1D variable w/o missing values + x = self.data + assert_almost_equal(np.corrcoef(x), corrcoef(x)) + assert_almost_equal(np.corrcoef(x, rowvar=False), + corrcoef(x, rowvar=False)) + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + assert_almost_equal(np.corrcoef(x, rowvar=False, bias=True), + corrcoef(x, rowvar=False, bias=True)) + + def test_2d_without_missing(self): + # Test corrcoef on 1 2D variable w/o missing values + x = self.data.reshape(3, 4) + assert_almost_equal(np.corrcoef(x), corrcoef(x)) + assert_almost_equal(np.corrcoef(x, rowvar=False), + corrcoef(x, rowvar=False)) + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + assert_almost_equal(np.corrcoef(x, rowvar=False, bias=True), + corrcoef(x, rowvar=False, bias=True)) + + def test_1d_with_missing(self): + # Test corrcoef 1 1D variable w/missing values + x = self.data + x[-1] = masked + x -= x.mean() + nx = x.compressed() + assert_almost_equal(np.corrcoef(nx), corrcoef(x)) + assert_almost_equal(np.corrcoef(nx, rowvar=False), + corrcoef(x, rowvar=False)) + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + assert_almost_equal(np.corrcoef(nx, rowvar=False, bias=True), + corrcoef(x, rowvar=False, bias=True)) + try: + corrcoef(x, allow_masked=False) + except ValueError: + pass + # 2 1D variables w/ missing values + nx = x[1:-1] + assert_almost_equal(np.corrcoef(nx, nx[::-1]), corrcoef(x, x[::-1])) + assert_almost_equal(np.corrcoef(nx, nx[::-1], rowvar=False), + corrcoef(x, x[::-1], rowvar=False)) + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + # ddof and bias have no or negligible effect on the function + assert_almost_equal(np.corrcoef(nx, nx[::-1]), + corrcoef(x, x[::-1], bias=1)) + assert_almost_equal(np.corrcoef(nx, nx[::-1]), + corrcoef(x, x[::-1], ddof=2)) + + def test_2d_with_missing(self): + # Test corrcoef on 2D variable w/ missing value + x = self.data + x[-1] = masked + x = x.reshape(3, 4) + + test = corrcoef(x) + control = np.corrcoef(x) + assert_almost_equal(test[:-1, :-1], control[:-1, :-1]) + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + # ddof and bias have no or negligible effect on the function + assert_almost_equal(corrcoef(x, ddof=-2)[:-1, :-1], + control[:-1, :-1]) + assert_almost_equal(corrcoef(x, ddof=3)[:-1, :-1], + control[:-1, :-1]) + assert_almost_equal(corrcoef(x, bias=1)[:-1, :-1], + control[:-1, :-1]) + + +class TestPolynomial: + # + def test_polyfit(self): + # Tests polyfit + # On ndarrays + x = np.random.rand(10) + y = np.random.rand(20).reshape(-1, 2) + assert_almost_equal(polyfit(x, y, 3), np.polyfit(x, y, 3)) + # ON 1D maskedarrays + x = x.view(MaskedArray) + x[0] = masked + y = y.view(MaskedArray) + y[0, 0] = y[-1, -1] = masked + # + (C, R, K, S, D) = polyfit(x, y[:, 0], 3, full=True) + (c, r, k, s, d) = np.polyfit(x[1:], y[1:, 0].compressed(), 3, + full=True) + for (a, a_) in zip((C, R, K, S, D), (c, r, k, 
s, d)): + assert_almost_equal(a, a_) + # + (C, R, K, S, D) = polyfit(x, y[:, -1], 3, full=True) + (c, r, k, s, d) = np.polyfit(x[1:-1], y[1:-1, -1], 3, full=True) + for (a, a_) in zip((C, R, K, S, D), (c, r, k, s, d)): + assert_almost_equal(a, a_) + # + (C, R, K, S, D) = polyfit(x, y, 3, full=True) + (c, r, k, s, d) = np.polyfit(x[1:-1], y[1:-1, :], 3, full=True) + for (a, a_) in zip((C, R, K, S, D), (c, r, k, s, d)): + assert_almost_equal(a, a_) + # + w = np.random.rand(10) + 1 + wo = w.copy() + xs = x[1:-1] + ys = y[1:-1] + ws = w[1:-1] + (C, R, K, S, D) = polyfit(x, y, 3, full=True, w=w) + (c, r, k, s, d) = np.polyfit(xs, ys, 3, full=True, w=ws) + assert_equal(w, wo) + for (a, a_) in zip((C, R, K, S, D), (c, r, k, s, d)): + assert_almost_equal(a, a_) + + def test_polyfit_with_masked_NaNs(self): + x = np.random.rand(10) + y = np.random.rand(20).reshape(-1, 2) + + x[0] = np.nan + y[-1, -1] = np.nan + x = x.view(MaskedArray) + y = y.view(MaskedArray) + x[0] = masked + y[-1, -1] = masked + + (C, R, K, S, D) = polyfit(x, y, 3, full=True) + (c, r, k, s, d) = np.polyfit(x[1:-1], y[1:-1, :], 3, full=True) + for (a, a_) in zip((C, R, K, S, D), (c, r, k, s, d)): + assert_almost_equal(a, a_) + + +class TestArraySetOps: + + def test_unique_onlist(self): + # Test unique on list + data = [1, 1, 1, 2, 2, 3] + test = unique(data, return_index=True, return_inverse=True) + assert_(isinstance(test[0], MaskedArray)) + assert_equal(test[0], masked_array([1, 2, 3], mask=[0, 0, 0])) + assert_equal(test[1], [0, 3, 5]) + assert_equal(test[2], [0, 0, 0, 1, 1, 2]) + + def test_unique_onmaskedarray(self): + # Test unique on masked data w/use_mask=True + data = masked_array([1, 1, 1, 2, 2, 3], mask=[0, 0, 1, 0, 1, 0]) + test = unique(data, return_index=True, return_inverse=True) + assert_equal(test[0], masked_array([1, 2, 3, -1], mask=[0, 0, 0, 1])) + assert_equal(test[1], [0, 3, 5, 2]) + assert_equal(test[2], [0, 0, 3, 1, 3, 2]) + # + data.fill_value = 3 + data = masked_array(data=[1, 1, 1, 2, 2, 3], + mask=[0, 0, 1, 0, 1, 0], fill_value=3) + test = unique(data, return_index=True, return_inverse=True) + assert_equal(test[0], masked_array([1, 2, 3, -1], mask=[0, 0, 0, 1])) + assert_equal(test[1], [0, 3, 5, 2]) + assert_equal(test[2], [0, 0, 3, 1, 3, 2]) + + def test_unique_allmasked(self): + # Test all masked + data = masked_array([1, 1, 1], mask=True) + test = unique(data, return_index=True, return_inverse=True) + assert_equal(test[0], masked_array([1, ], mask=[True])) + assert_equal(test[1], [0]) + assert_equal(test[2], [0, 0, 0]) + # + # Test masked + data = masked + test = unique(data, return_index=True, return_inverse=True) + assert_equal(test[0], masked_array(masked)) + assert_equal(test[1], [0]) + assert_equal(test[2], [0]) + + def test_ediff1d(self): + # Tests mediff1d + x = masked_array(np.arange(5), mask=[1, 0, 0, 0, 1]) + control = array([1, 1, 1, 4], mask=[1, 0, 0, 1]) + test = ediff1d(x) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + + def test_ediff1d_tobegin(self): + # Test ediff1d w/ to_begin + x = masked_array(np.arange(5), mask=[1, 0, 0, 0, 1]) + test = ediff1d(x, to_begin=masked) + control = array([0, 1, 1, 1, 4], mask=[1, 1, 0, 0, 1]) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + # + test = ediff1d(x, to_begin=[1, 2, 3]) + control = array([1, 2, 3, 1, 1, 1, 4], mask=[0, 0, 0, 1, 0, 0, 1]) + assert_equal(test, control) + 
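+        # Editor's note: ediff1d masks any difference that touches a masked
+        # element, while to_begin/to_end entries are attached verbatim; the
+        # `masked` constant may itself be passed to prepend a masked value.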
assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + + def test_ediff1d_toend(self): + # Test ediff1d w/ to_end + x = masked_array(np.arange(5), mask=[1, 0, 0, 0, 1]) + test = ediff1d(x, to_end=masked) + control = array([1, 1, 1, 4, 0], mask=[1, 0, 0, 1, 1]) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + # + test = ediff1d(x, to_end=[1, 2, 3]) + control = array([1, 1, 1, 4, 1, 2, 3], mask=[1, 0, 0, 1, 0, 0, 0]) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + + def test_ediff1d_tobegin_toend(self): + # Test ediff1d w/ to_begin and to_end + x = masked_array(np.arange(5), mask=[1, 0, 0, 0, 1]) + test = ediff1d(x, to_end=masked, to_begin=masked) + control = array([0, 1, 1, 1, 4, 0], mask=[1, 1, 0, 0, 1, 1]) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + # + test = ediff1d(x, to_end=[1, 2, 3], to_begin=masked) + control = array([0, 1, 1, 1, 4, 1, 2, 3], + mask=[1, 1, 0, 0, 1, 0, 0, 0]) + assert_equal(test, control) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + + def test_ediff1d_ndarray(self): + # Test ediff1d w/ a ndarray + x = np.arange(5) + test = ediff1d(x) + control = array([1, 1, 1, 1], mask=[0, 0, 0, 0]) + assert_equal(test, control) + assert_(isinstance(test, MaskedArray)) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + # + test = ediff1d(x, to_end=masked, to_begin=masked) + control = array([0, 1, 1, 1, 1, 0], mask=[1, 0, 0, 0, 0, 1]) + assert_(isinstance(test, MaskedArray)) + assert_equal(test.filled(0), control.filled(0)) + assert_equal(test.mask, control.mask) + + def test_intersect1d(self): + # Test intersect1d + x = array([1, 3, 3, 3], mask=[0, 0, 0, 1]) + y = array([3, 1, 1, 1], mask=[0, 0, 0, 1]) + test = intersect1d(x, y) + control = array([1, 3, -1], mask=[0, 0, 1]) + assert_equal(test, control) + + def test_setxor1d(self): + # Test setxor1d + a = array([1, 2, 5, 7, -1], mask=[0, 0, 0, 0, 1]) + b = array([1, 2, 3, 4, 5, -1], mask=[0, 0, 0, 0, 0, 1]) + test = setxor1d(a, b) + assert_equal(test, array([3, 4, 7])) + # + a = array([1, 2, 5, 7, -1], mask=[0, 0, 0, 0, 1]) + b = [1, 2, 3, 4, 5] + test = setxor1d(a, b) + assert_equal(test, array([3, 4, 7, -1], mask=[0, 0, 0, 1])) + # + a = array([1, 2, 3]) + b = array([6, 5, 4]) + test = setxor1d(a, b) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, [1, 2, 3, 4, 5, 6]) + # + a = array([1, 8, 2, 3], mask=[0, 1, 0, 0]) + b = array([6, 5, 4, 8], mask=[0, 0, 0, 1]) + test = setxor1d(a, b) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, [1, 2, 3, 4, 5, 6]) + # + assert_array_equal([], setxor1d([], [])) + + def test_setxor1d_unique(self): + # Test setxor1d with assume_unique=True + a = array([1, 2, 5, 7, -1], mask=[0, 0, 0, 0, 1]) + b = [1, 2, 3, 4, 5] + test = setxor1d(a, b, assume_unique=True) + assert_equal(test, array([3, 4, 7, -1], mask=[0, 0, 0, 1])) + # + a = array([1, 8, 2, 3], mask=[0, 1, 0, 0]) + b = array([6, 5, 4, 8], mask=[0, 0, 0, 1]) + test = setxor1d(a, b, assume_unique=True) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, [1, 2, 3, 4, 5, 6]) + # + a = array([[1], [8], [2], [3]]) + b = array([[6, 5], [4, 8]]) + test = setxor1d(a, b, assume_unique=True) + assert_(isinstance(test, MaskedArray)) + assert_equal(test, [1, 2, 3, 4, 5, 6]) + + def 
test_isin(self): + # the tests for in1d cover most of isin's behavior + # if in1d is removed, would need to change those tests to test + # isin instead. + a = np.arange(24).reshape([2, 3, 4]) + mask = np.zeros([2, 3, 4]) + mask[1, 2, 0] = 1 + a = array(a, mask=mask) + b = array(data=[0, 10, 20, 30, 1, 3, 11, 22, 33], + mask=[0, 1, 0, 1, 0, 1, 0, 1, 0]) + ec = zeros((2, 3, 4), dtype=bool) + ec[0, 0, 0] = True + ec[0, 0, 1] = True + ec[0, 2, 3] = True + c = isin(a, b) + assert_(isinstance(c, MaskedArray)) + assert_array_equal(c, ec) + # compare results of np.isin to ma.isin + d = np.isin(a, b[~b.mask]) & ~a.mask + assert_array_equal(c, d) + + def test_in1d(self): + # Test in1d + a = array([1, 2, 5, 7, -1], mask=[0, 0, 0, 0, 1]) + b = array([1, 2, 3, 4, 5, -1], mask=[0, 0, 0, 0, 0, 1]) + test = in1d(a, b) + assert_equal(test, [True, True, True, False, True]) + # + a = array([5, 5, 2, 1, -1], mask=[0, 0, 0, 0, 1]) + b = array([1, 5, -1], mask=[0, 0, 1]) + test = in1d(a, b) + assert_equal(test, [True, True, False, True, True]) + # + assert_array_equal([], in1d([], [])) + + def test_in1d_invert(self): + # Test in1d's invert parameter + a = array([1, 2, 5, 7, -1], mask=[0, 0, 0, 0, 1]) + b = array([1, 2, 3, 4, 5, -1], mask=[0, 0, 0, 0, 0, 1]) + assert_equal(np.invert(in1d(a, b)), in1d(a, b, invert=True)) + + a = array([5, 5, 2, 1, -1], mask=[0, 0, 0, 0, 1]) + b = array([1, 5, -1], mask=[0, 0, 1]) + assert_equal(np.invert(in1d(a, b)), in1d(a, b, invert=True)) + + assert_array_equal([], in1d([], [], invert=True)) + + def test_union1d(self): + # Test union1d + a = array([1, 2, 5, 7, 5, -1], mask=[0, 0, 0, 0, 0, 1]) + b = array([1, 2, 3, 4, 5, -1], mask=[0, 0, 0, 0, 0, 1]) + test = union1d(a, b) + control = array([1, 2, 3, 4, 5, 7, -1], mask=[0, 0, 0, 0, 0, 0, 1]) + assert_equal(test, control) + + # Tests gh-10340, arguments to union1d should be + # flattened if they are not already 1D + x = array([[0, 1, 2], [3, 4, 5]], mask=[[0, 0, 0], [0, 0, 1]]) + y = array([0, 1, 2, 3, 4], mask=[0, 0, 0, 0, 1]) + ez = array([0, 1, 2, 3, 4, 5], mask=[0, 0, 0, 0, 0, 1]) + z = union1d(x, y) + assert_equal(z, ez) + # + assert_array_equal([], union1d([], [])) + + def test_setdiff1d(self): + # Test setdiff1d + a = array([6, 5, 4, 7, 7, 1, 2, 1], mask=[0, 0, 0, 0, 0, 0, 0, 1]) + b = array([2, 4, 3, 3, 2, 1, 5]) + test = setdiff1d(a, b) + assert_equal(test, array([6, 7, -1], mask=[0, 0, 1])) + # + a = arange(10) + b = arange(8) + assert_equal(setdiff1d(a, b), array([8, 9])) + a = array([], np.uint32, mask=[]) + assert_equal(setdiff1d(a, []).dtype, np.uint32) + + def test_setdiff1d_char_array(self): + # Test setdiff1d_charray + a = np.array(['a', 'b', 'c']) + b = np.array(['a', 'b', 's']) + assert_array_equal(setdiff1d(a, b), np.array(['c'])) + + +class TestShapeBase: + + def test_atleast_2d(self): + # Test atleast_2d + a = masked_array([0, 1, 2], mask=[0, 1, 0]) + b = atleast_2d(a) + assert_equal(b.shape, (1, 3)) + assert_equal(b.mask.shape, b.data.shape) + assert_equal(a.shape, (3,)) + assert_equal(a.mask.shape, a.data.shape) + assert_equal(b.mask.shape, b.data.shape) + + def test_shape_scalar(self): + # the atleast and diagflat function should work with scalars + # GitHub issue #3367 + # Additionally, the atleast functions should accept multiple scalars + # correctly + b = atleast_1d(1.0) + assert_equal(b.shape, (1,)) + assert_equal(b.mask.shape, b.shape) + assert_equal(b.data.shape, b.shape) + + b = atleast_1d(1.0, 2.0) + for a in b: + assert_equal(a.shape, (1,)) + assert_equal(a.mask.shape, a.shape) + 
assert_equal(a.data.shape, a.shape) + + b = atleast_2d(1.0) + assert_equal(b.shape, (1, 1)) + assert_equal(b.mask.shape, b.shape) + assert_equal(b.data.shape, b.shape) + + b = atleast_2d(1.0, 2.0) + for a in b: + assert_equal(a.shape, (1, 1)) + assert_equal(a.mask.shape, a.shape) + assert_equal(a.data.shape, a.shape) + + b = atleast_3d(1.0) + assert_equal(b.shape, (1, 1, 1)) + assert_equal(b.mask.shape, b.shape) + assert_equal(b.data.shape, b.shape) + + b = atleast_3d(1.0, 2.0) + for a in b: + assert_equal(a.shape, (1, 1, 1)) + assert_equal(a.mask.shape, a.shape) + assert_equal(a.data.shape, a.shape) + + b = diagflat(1.0) + assert_equal(b.shape, (1, 1)) + assert_equal(b.mask.shape, b.data.shape) + + +class TestNDEnumerate: + + def test_ndenumerate_nomasked(self): + ordinary = np.arange(6.).reshape((1, 3, 2)) + empty_mask = np.zeros_like(ordinary, dtype=bool) + with_mask = masked_array(ordinary, mask=empty_mask) + assert_equal(list(np.ndenumerate(ordinary)), + list(ndenumerate(ordinary))) + assert_equal(list(ndenumerate(ordinary)), + list(ndenumerate(with_mask))) + assert_equal(list(ndenumerate(with_mask)), + list(ndenumerate(with_mask, compressed=False))) + + def test_ndenumerate_allmasked(self): + a = masked_all(()) + b = masked_all((100,)) + c = masked_all((2, 3, 4)) + assert_equal(list(ndenumerate(a)), []) + assert_equal(list(ndenumerate(b)), []) + assert_equal(list(ndenumerate(b, compressed=False)), + list(zip(np.ndindex((100,)), 100 * [masked]))) + assert_equal(list(ndenumerate(c)), []) + assert_equal(list(ndenumerate(c, compressed=False)), + list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked]))) + + def test_ndenumerate_mixedmasked(self): + a = masked_array(np.arange(12).reshape((3, 4)), + mask=[[1, 1, 1, 1], + [1, 1, 0, 1], + [0, 0, 0, 0]]) + items = [((1, 2), 6), + ((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)] + assert_equal(list(ndenumerate(a)), items) + assert_equal(len(list(ndenumerate(a, compressed=False))), a.size) + for coordinate, value in ndenumerate(a, compressed=False): + assert_equal(a[coordinate], value) + + +class TestStack: + + def test_stack_1d(self): + a = masked_array([0, 1, 2], mask=[0, 1, 0]) + b = masked_array([9, 8, 7], mask=[1, 0, 0]) + + c = stack([a, b], axis=0) + assert_equal(c.shape, (2, 3)) + assert_array_equal(a.mask, c[0].mask) + assert_array_equal(b.mask, c[1].mask) + + d = vstack([a, b]) + assert_array_equal(c.data, d.data) + assert_array_equal(c.mask, d.mask) + + c = stack([a, b], axis=1) + assert_equal(c.shape, (3, 2)) + assert_array_equal(a.mask, c[:, 0].mask) + assert_array_equal(b.mask, c[:, 1].mask) + + def test_stack_masks(self): + a = masked_array([0, 1, 2], mask=True) + b = masked_array([9, 8, 7], mask=False) + + c = stack([a, b], axis=0) + assert_equal(c.shape, (2, 3)) + assert_array_equal(a.mask, c[0].mask) + assert_array_equal(b.mask, c[1].mask) + + d = vstack([a, b]) + assert_array_equal(c.data, d.data) + assert_array_equal(c.mask, d.mask) + + c = stack([a, b], axis=1) + assert_equal(c.shape, (3, 2)) + assert_array_equal(a.mask, c[:, 0].mask) + assert_array_equal(b.mask, c[:, 1].mask) + + def test_stack_nd(self): + # 2D + shp = (3, 2) + d1 = np.random.randint(0, 10, shp) + d2 = np.random.randint(0, 10, shp) + m1 = np.random.randint(0, 2, shp).astype(bool) + m2 = np.random.randint(0, 2, shp).astype(bool) + a1 = masked_array(d1, mask=m1) + a2 = masked_array(d2, mask=m2) + + c = stack([a1, a2], axis=0) + c_shp = (2,) + shp + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[0].mask) + assert_array_equal(a2.mask, 
c[1].mask) + + c = stack([a1, a2], axis=-1) + c_shp = shp + (2,) + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[..., 0].mask) + assert_array_equal(a2.mask, c[..., 1].mask) + + # 4D + shp = (3, 2, 4, 5,) + d1 = np.random.randint(0, 10, shp) + d2 = np.random.randint(0, 10, shp) + m1 = np.random.randint(0, 2, shp).astype(bool) + m2 = np.random.randint(0, 2, shp).astype(bool) + a1 = masked_array(d1, mask=m1) + a2 = masked_array(d2, mask=m2) + + c = stack([a1, a2], axis=0) + c_shp = (2,) + shp + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[0].mask) + assert_array_equal(a2.mask, c[1].mask) + + c = stack([a1, a2], axis=-1) + c_shp = shp + (2,) + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[..., 0].mask) + assert_array_equal(a2.mask, c[..., 1].mask) diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_mrecords.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_mrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..0da915101511aac28d913ce0e74d27a828d18a52 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_mrecords.py @@ -0,0 +1,497 @@ +"""Tests suite for mrecords. + +:author: Pierre Gerard-Marchant +:contact: pierregm_at_uga_dot_edu + +""" +import pickle + +import numpy as np +import numpy.ma as ma +from numpy._core.records import fromarrays as recfromarrays +from numpy._core.records import fromrecords as recfromrecords +from numpy._core.records import recarray +from numpy.ma import masked, nomask +from numpy.ma.mrecords import ( + MaskedRecords, + addfield, + fromarrays, + fromrecords, + fromtextfile, + mrecarray, +) +from numpy.ma.testutils import ( + assert_, + assert_equal, + assert_equal_records, +) +from numpy.testing import temppath + + +class TestMRecords: + + ilist = [1, 2, 3, 4, 5] + flist = [1.1, 2.2, 3.3, 4.4, 5.5] + slist = [b'one', b'two', b'three', b'four', b'five'] + ddtype = [('a', int), ('b', float), ('c', '|S8')] + mask = [0, 1, 0, 0, 1] + base = ma.array(list(zip(ilist, flist, slist)), mask=mask, dtype=ddtype) + + def test_byview(self): + # Test creation by view + base = self.base + mbase = base.view(mrecarray) + assert_equal(mbase.recordmask, base.recordmask) + assert_equal_records(mbase._mask, base._mask) + assert_(isinstance(mbase._data, recarray)) + assert_equal_records(mbase._data, base._data.view(recarray)) + for field in ('a', 'b', 'c'): + assert_equal(base[field], mbase[field]) + assert_equal_records(mbase.view(mrecarray), mbase) + + def test_get(self): + # Tests fields retrieval + base = self.base.copy() + mbase = base.view(mrecarray) + # As fields.......... + for field in ('a', 'b', 'c'): + assert_equal(getattr(mbase, field), mbase[field]) + assert_equal(base[field], mbase[field]) + # as elements ....... 
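+        # Scalar indexing returns a record-like, zero-dimensional
+        # mrecarray: `recordmask` collapses the per-field mask to one
+        # boolean per record, while `_mask.item()` still exposes the
+        # field-by-field flags.  The assertions below exercise both an
+        # unmasked and a fully masked record.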
+        mbase_first = mbase[0]
+        assert_(isinstance(mbase_first, mrecarray))
+        assert_equal(mbase_first.dtype, mbase.dtype)
+        assert_equal(mbase_first.tolist(), (1, 1.1, b'one'))
+        # Used to be mask, now it's recordmask
+        assert_equal(mbase_first.recordmask, nomask)
+        assert_equal(mbase_first._mask.item(), (False, False, False))
+        assert_equal(mbase_first['a'], mbase['a'][0])
+        mbase_last = mbase[-1]
+        assert_(isinstance(mbase_last, mrecarray))
+        assert_equal(mbase_last.dtype, mbase.dtype)
+        assert_equal(mbase_last.tolist(), (None, None, None))
+        # Used to be mask, now it's recordmask
+        assert_equal(mbase_last.recordmask, True)
+        assert_equal(mbase_last._mask.item(), (True, True, True))
+        assert_equal(mbase_last['a'], mbase['a'][-1])
+        assert_(mbase_last['a'] is masked)
+        # as slice ..........
+        mbase_sl = mbase[:2]
+        assert_(isinstance(mbase_sl, mrecarray))
+        assert_equal(mbase_sl.dtype, mbase.dtype)
+        # Used to be mask, now it's recordmask
+        assert_equal(mbase_sl.recordmask, [0, 1])
+        assert_equal_records(mbase_sl.mask,
+                             np.array([(False, False, False),
+                                       (True, True, True)],
+                                      dtype=mbase._mask.dtype))
+        assert_equal_records(mbase_sl, base[:2].view(mrecarray))
+        for field in ('a', 'b', 'c'):
+            assert_equal(getattr(mbase_sl, field), base[:2][field])
+
+    def test_set_fields(self):
+        # Tests setting fields.
+        base = self.base.copy()
+        mbase = base.view(mrecarray)
+        mbase = mbase.copy()
+        mbase.fill_value = (999999, 1e20, 'N/A')
+        # Change the data, the mask should be conserved
+        mbase.a._data[:] = 5
+        assert_equal(mbase['a']._data, [5, 5, 5, 5, 5])
+        assert_equal(mbase['a']._mask, [0, 1, 0, 0, 1])
+        # Change the elements, and the mask will follow
+        mbase.a = 1
+        assert_equal(mbase['a']._data, [1] * 5)
+        assert_equal(ma.getmaskarray(mbase['a']), [0] * 5)
+        # Used to be _mask, now it's recordmask
+        assert_equal(mbase.recordmask, [False] * 5)
+        assert_equal(mbase._mask.tolist(),
+                     np.array([(0, 0, 0),
+                               (0, 1, 1),
+                               (0, 0, 0),
+                               (0, 0, 0),
+                               (0, 1, 1)],
+                              dtype=bool))
+        # Set a field to mask ........................
+        mbase.c = masked
+        # Used to be mask, and now it's still mask!
+        assert_equal(mbase.c.mask, [1] * 5)
+        assert_equal(mbase.c.recordmask, [1] * 5)
+        assert_equal(ma.getmaskarray(mbase['c']), [1] * 5)
+        assert_equal(ma.getdata(mbase['c']), [b'N/A'] * 5)
+        assert_equal(mbase._mask.tolist(),
+                     np.array([(0, 0, 1),
+                               (0, 1, 1),
+                               (0, 0, 1),
+                               (0, 0, 1),
+                               (0, 1, 1)],
+                              dtype=bool))
+        # Set fields by slices .......................
+        mbase = base.view(mrecarray).copy()
+        mbase.a[3:] = 5
+        assert_equal(mbase.a, [1, 2, 3, 5, 5])
+        assert_equal(mbase.a._mask, [0, 1, 0, 0, 0])
+        mbase.b[3:] = masked
+        assert_equal(mbase.b, base['b'])
+        assert_equal(mbase.b._mask, [0, 1, 0, 1, 1])
+        # Set fields globally..........................
+        ndtype = [('alpha', '|S1'), ('num', int)]
+        data = ma.array([('a', 1), ('b', 2), ('c', 3)], dtype=ndtype)
+        rdata = data.view(MaskedRecords)
+        val = ma.array([10, 20, 30], mask=[1, 0, 0])
+
+        rdata['num'] = val
+        assert_equal(rdata.num, val)
+        assert_equal(rdata.num.mask, [1, 0, 0])
+
+    def test_set_fields_mask(self):
+        # Tests setting the mask of a field.
+        base = self.base.copy()
+        # This one already has a mask....
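+        # (Masking a single element of field 'a' should flip only that
+        # slot in the 'a' column of the record mask; the masks of 'b'
+        # and 'c' must stay untouched, as checked below.)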
+ mbase = base.view(mrecarray) + mbase['a'][-2] = masked + assert_equal(mbase.a, [1, 2, 3, 4, 5]) + assert_equal(mbase.a._mask, [0, 1, 0, 1, 1]) + # This one has not yet + mbase = fromarrays([np.arange(5), np.random.rand(5)], + dtype=[('a', int), ('b', float)]) + mbase['a'][-2] = masked + assert_equal(mbase.a, [0, 1, 2, 3, 4]) + assert_equal(mbase.a._mask, [0, 0, 0, 1, 0]) + + def test_set_mask(self): + base = self.base.copy() + mbase = base.view(mrecarray) + # Set the mask to True ....................... + mbase.mask = masked + assert_equal(ma.getmaskarray(mbase['b']), [1] * 5) + assert_equal(mbase['a']._mask, mbase['b']._mask) + assert_equal(mbase['a']._mask, mbase['c']._mask) + assert_equal(mbase._mask.tolist(), + np.array([(1, 1, 1)] * 5, dtype=bool)) + # Delete the mask ............................ + mbase.mask = nomask + assert_equal(ma.getmaskarray(mbase['c']), [0] * 5) + assert_equal(mbase._mask.tolist(), + np.array([(0, 0, 0)] * 5, dtype=bool)) + + def test_set_mask_fromarray(self): + base = self.base.copy() + mbase = base.view(mrecarray) + # Sets the mask w/ an array + mbase.mask = [1, 0, 0, 0, 1] + assert_equal(mbase.a.mask, [1, 0, 0, 0, 1]) + assert_equal(mbase.b.mask, [1, 0, 0, 0, 1]) + assert_equal(mbase.c.mask, [1, 0, 0, 0, 1]) + # Yay, once more ! + mbase.mask = [0, 0, 0, 0, 1] + assert_equal(mbase.a.mask, [0, 0, 0, 0, 1]) + assert_equal(mbase.b.mask, [0, 0, 0, 0, 1]) + assert_equal(mbase.c.mask, [0, 0, 0, 0, 1]) + + def test_set_mask_fromfields(self): + mbase = self.base.copy().view(mrecarray) + + nmask = np.array( + [(0, 1, 0), (0, 1, 0), (1, 0, 1), (1, 0, 1), (0, 0, 0)], + dtype=[('a', bool), ('b', bool), ('c', bool)]) + mbase.mask = nmask + assert_equal(mbase.a.mask, [0, 0, 1, 1, 0]) + assert_equal(mbase.b.mask, [1, 1, 0, 0, 0]) + assert_equal(mbase.c.mask, [0, 0, 1, 1, 0]) + # Reinitialize and redo + mbase.mask = False + mbase.fieldmask = nmask + assert_equal(mbase.a.mask, [0, 0, 1, 1, 0]) + assert_equal(mbase.b.mask, [1, 1, 0, 0, 0]) + assert_equal(mbase.c.mask, [0, 0, 1, 1, 0]) + + def test_set_elements(self): + base = self.base.copy() + # Set an element to mask ..................... + mbase = base.view(mrecarray).copy() + mbase[-2] = masked + assert_equal( + mbase._mask.tolist(), + np.array([(0, 0, 0), (1, 1, 1), (0, 0, 0), (1, 1, 1), (1, 1, 1)], + dtype=bool)) + # Used to be mask, now it's recordmask! + assert_equal(mbase.recordmask, [0, 1, 0, 1, 1]) + # Set slices ................................. + mbase = base.view(mrecarray).copy() + mbase[:2] = (5, 5, 5) + assert_equal(mbase.a._data, [5, 5, 3, 4, 5]) + assert_equal(mbase.a._mask, [0, 0, 0, 0, 1]) + assert_equal(mbase.b._data, [5., 5., 3.3, 4.4, 5.5]) + assert_equal(mbase.b._mask, [0, 0, 0, 0, 1]) + assert_equal(mbase.c._data, + [b'5', b'5', b'three', b'four', b'five']) + assert_equal(mbase.b._mask, [0, 0, 0, 0, 1]) + + mbase = base.view(mrecarray).copy() + mbase[:2] = masked + assert_equal(mbase.a._data, [1, 2, 3, 4, 5]) + assert_equal(mbase.a._mask, [1, 1, 0, 0, 1]) + assert_equal(mbase.b._data, [1.1, 2.2, 3.3, 4.4, 5.5]) + assert_equal(mbase.b._mask, [1, 1, 0, 0, 1]) + assert_equal(mbase.c._data, + [b'one', b'two', b'three', b'four', b'five']) + assert_equal(mbase.b._mask, [1, 1, 0, 0, 1]) + + def test_setslices_hardmask(self): + # Tests setting slices w/ hardmask. 
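+        # With a hardened mask, assignment must not unmask entries:
+        # after `mbase[-2:] = (5, 5, 5)` below, the mask pattern
+        # inherited from `base` is expected to survive unchanged.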
+ base = self.base.copy() + mbase = base.view(mrecarray) + mbase.harden_mask() + try: + mbase[-2:] = (5, 5, 5) + assert_equal(mbase.a._data, [1, 2, 3, 5, 5]) + assert_equal(mbase.b._data, [1.1, 2.2, 3.3, 5, 5.5]) + assert_equal(mbase.c._data, + [b'one', b'two', b'three', b'5', b'five']) + assert_equal(mbase.a._mask, [0, 1, 0, 0, 1]) + assert_equal(mbase.b._mask, mbase.a._mask) + assert_equal(mbase.b._mask, mbase.c._mask) + except NotImplementedError: + # OK, not implemented yet... + pass + except AssertionError: + raise + else: + raise Exception("Flexible hard masks should be supported !") + # Not using a tuple should crash + try: + mbase[-2:] = 3 + except (NotImplementedError, TypeError): + pass + else: + raise TypeError("Should have expected a readable buffer object!") + + def test_hardmask(self): + # Test hardmask + base = self.base.copy() + mbase = base.view(mrecarray) + mbase.harden_mask() + assert_(mbase._hardmask) + mbase.mask = nomask + assert_equal_records(mbase._mask, base._mask) + mbase.soften_mask() + assert_(not mbase._hardmask) + mbase.mask = nomask + # So, the mask of a field is no longer set to nomask... + assert_equal_records(mbase._mask, + ma.make_mask_none(base.shape, base.dtype)) + assert_(ma.make_mask(mbase['b']._mask) is nomask) + assert_equal(mbase['a']._mask, mbase['b']._mask) + + def test_pickling(self): + # Test pickling + base = self.base.copy() + mrec = base.view(mrecarray) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + _ = pickle.dumps(mrec, protocol=proto) + mrec_ = pickle.loads(_) + assert_equal(mrec_.dtype, mrec.dtype) + assert_equal_records(mrec_._data, mrec._data) + assert_equal(mrec_._mask, mrec._mask) + assert_equal_records(mrec_._mask, mrec._mask) + + def test_filled(self): + # Test filling the array + _a = ma.array([1, 2, 3], mask=[0, 0, 1], dtype=int) + _b = ma.array([1.1, 2.2, 3.3], mask=[0, 0, 1], dtype=float) + _c = ma.array(['one', 'two', 'three'], mask=[0, 0, 1], dtype='|S8') + ddtype = [('a', int), ('b', float), ('c', '|S8')] + mrec = fromarrays([_a, _b, _c], dtype=ddtype, + fill_value=(99999, 99999., 'N/A')) + mrecfilled = mrec.filled() + assert_equal(mrecfilled['a'], np.array((1, 2, 99999), dtype=int)) + assert_equal(mrecfilled['b'], np.array((1.1, 2.2, 99999.), + dtype=float)) + assert_equal(mrecfilled['c'], np.array(('one', 'two', 'N/A'), + dtype='|S8')) + + def test_tolist(self): + # Test tolist. 
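+        # MaskedRecords.tolist() converts each record to a tuple and
+        # replaces masked fields with None, e.g. a record whose 'a' and
+        # 'b' fields are masked comes back as (None, None, b'three').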
+        _a = ma.array([1, 2, 3], mask=[0, 0, 1], dtype=int)
+        _b = ma.array([1.1, 2.2, 3.3], mask=[0, 0, 1], dtype=float)
+        _c = ma.array(['one', 'two', 'three'], mask=[1, 0, 0], dtype='|S8')
+        ddtype = [('a', int), ('b', float), ('c', '|S8')]
+        mrec = fromarrays([_a, _b, _c], dtype=ddtype,
+                          fill_value=(99999, 99999., 'N/A'))
+
+        assert_equal(mrec.tolist(),
+                     [(1, 1.1, None), (2, 2.2, b'two'),
+                      (None, None, b'three')])
+
+    def test_withnames(self):
+        # Test the creation w/ format and names
+        x = mrecarray(1, formats=float, names='base')
+        x[0]['base'] = 10
+        assert_equal(x['base'][0], 10)
+
+    def test_exotic_formats(self):
+        # Test that 'exotic' formats are processed properly
+        easy = mrecarray(1, dtype=[('i', int), ('s', '|S8'), ('f', float)])
+        easy[0] = masked
+        assert_equal(easy.filled(1).item(), (1, b'1', 1.))
+
+        solo = mrecarray(1, dtype=[('f0', '<f8', (2, 2))])
+
+        if len(s) > 1:
+            assert_(eq(np.concatenate((x, y), 1),
+                       concatenate((xm, ym), 1)))
+            assert_(eq(np.add.reduce(x, 1), add.reduce(x, 1)))
+            assert_(eq(np.sum(x, 1), sum(x, 1)))
+            assert_(eq(np.prod(x, 1), product(x, 1)))
+
+    def test_testCI(self):
+        # Test of conversions and indexing
+        x1 = np.array([1, 2, 4, 3])
+        x2 = array(x1, mask=[1, 0, 0, 0])
+        x3 = array(x1, mask=[0, 1, 0, 1])
+        x4 = array(x1)
+        # test conversion to strings
+        str(x2)  # raises?
+        repr(x2)  # raises?
+        assert_(eq(np.sort(x1), sort(x2, fill_value=0)))
+        # tests of indexing
+        assert_(type(x2[1]) is type(x1[1]))
+        assert_(x1[1] == x2[1])
+        assert_(x2[0] is masked)
+        assert_(eq(x1[2], x2[2]))
+        assert_(eq(x1[2:5], x2[2:5]))
+        assert_(eq(x1[:], x2[:]))
+        assert_(eq(x1[1:], x3[1:]))
+        x1[2] = 9
+        x2[2] = 9
+        assert_(eq(x1, x2))
+        x1[1:3] = 99
+        x2[1:3] = 99
+        assert_(eq(x1, x2))
+        x2[1] = masked
+        assert_(eq(x1, x2))
+        x2[1:3] = masked
+        assert_(eq(x1, x2))
+        x2[:] = x1
+        x2[1] = masked
+        assert_(allequal(getmask(x2), array([0, 1, 0, 0])))
+        x3[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0])
+        assert_(allequal(getmask(x3), array([0, 1, 1, 0])))
+        x4[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0])
+        assert_(allequal(getmask(x4), array([0, 1, 1, 0])))
+        assert_(allequal(x4, array([1, 2, 3, 4])))
+        x1 = np.arange(5) * 1.0
+        x2 = masked_values(x1, 3.0)
+        assert_(eq(x1, x2))
+        assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask))
+        assert_(eq(3.0, x2.fill_value))
+        x1 = array([1, 'hello', 2, 3], object)
+        x2 = np.array([1, 'hello', 2, 3], object)
+        s1 = x1[1]
+        s2 = x2[1]
+        assert_equal(type(s2), str)
+        assert_equal(type(s1), str)
+        assert_equal(s1, s2)
+        assert_(x1[1:1].shape == (0,))
+
+    def test_testCopySize(self):
+        # Tests of some subtle points of copying and sizing.
+        n = [0, 0, 1, 0, 0]
+        m = make_mask(n)
+        m2 = make_mask(m)
+        assert_(m is m2)
+        m3 = make_mask(m, copy=True)
+        assert_(m is not m3)
+
+        x1 = np.arange(5)
+        y1 = array(x1, mask=m)
+        assert_(y1._data is not x1)
+        assert_(allequal(x1, y1._data))
+        assert_(y1._mask is m)
+
+        y1a = array(y1, copy=0)
+        # For copy=False, one might expect that the array would just
+        # be passed on, i.e., that it would be "is" instead of "==".
+        # See gh-4043 for discussion.
+ assert_(y1a._mask.__array_interface__ == + y1._mask.__array_interface__) + + y2 = array(x1, mask=m3, copy=0) + assert_(y2._mask is m3) + assert_(y2[2] is masked) + y2[2] = 9 + assert_(y2[2] is not masked) + assert_(y2._mask is m3) + assert_(allequal(y2.mask, 0)) + + y2a = array(x1, mask=m, copy=1) + assert_(y2a._mask is not m) + assert_(y2a[2] is masked) + y2a[2] = 9 + assert_(y2a[2] is not masked) + assert_(y2a._mask is not m) + assert_(allequal(y2a.mask, 0)) + + y3 = array(x1 * 1.0, mask=m) + assert_(filled(y3).dtype is (x1 * 1.0).dtype) + + x4 = arange(4) + x4[2] = masked + y4 = resize(x4, (8,)) + assert_(eq(concatenate([x4, x4]), y4)) + assert_(eq(getmask(y4), [0, 0, 1, 0, 0, 0, 1, 0])) + y5 = repeat(x4, (2, 2, 2, 2), axis=0) + assert_(eq(y5, [0, 0, 1, 1, 2, 2, 3, 3])) + y6 = repeat(x4, 2, axis=0) + assert_(eq(y5, y6)) + + def test_testPut(self): + # Test of put + d = arange(5) + n = [0, 0, 0, 1, 1] + m = make_mask(n) + m2 = m.copy() + x = array(d, mask=m) + assert_(x[3] is masked) + assert_(x[4] is masked) + x[[1, 4]] = [10, 40] + assert_(x._mask is m) + assert_(x[3] is masked) + assert_(x[4] is not masked) + assert_(eq(x, [0, 10, 2, -1, 40])) + + x = array(d, mask=m2, copy=True) + x.put([0, 1, 2], [-1, 100, 200]) + assert_(x._mask is not m2) + assert_(x[3] is masked) + assert_(x[4] is masked) + assert_(eq(x, [-1, 100, 200, 0, 0])) + + def test_testPut2(self): + # Test of put + d = arange(5) + x = array(d, mask=[0, 0, 0, 0, 0]) + z = array([10, 40], mask=[1, 0]) + assert_(x[2] is not masked) + assert_(x[3] is not masked) + x[2:4] = z + assert_(x[2] is masked) + assert_(x[3] is not masked) + assert_(eq(x, [0, 1, 10, 40, 4])) + + d = arange(5) + x = array(d, mask=[0, 0, 0, 0, 0]) + y = x[2:4] + z = array([10, 40], mask=[1, 0]) + assert_(x[2] is not masked) + assert_(x[3] is not masked) + y[:] = z + assert_(y[0] is masked) + assert_(y[1] is not masked) + assert_(eq(y, [10, 40])) + assert_(x[2] is masked) + assert_(x[3] is not masked) + assert_(eq(x, [0, 1, 10, 40, 4])) + + def test_testMaPut(self): + (x, y, a10, m1, m2, xm, ym, z, zm, xf, s) = self.d + m = [1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1] + i = np.nonzero(m)[0] + put(ym, i, zm) + assert_(all(take(ym, i, axis=0) == zm)) + + def test_testOddFeatures(self): + # Test of other odd features + x = arange(20) + x = x.reshape(4, 5) + x.flat[5] = 12 + assert_(x[1, 0] == 12) + z = x + 10j * x + assert_(eq(z.real, x)) + assert_(eq(z.imag, 10 * x)) + assert_(eq((z * conjugate(z)).real, 101 * x * x)) + z.imag[...] 
= 0.0 + + x = arange(10) + x[3] = masked + assert_(str(x[3]) == str(masked)) + c = x >= 8 + assert_(count(where(c, masked, masked)) == 0) + assert_(shape(where(c, masked, masked)) == c.shape) + z = where(c, x, masked) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is masked) + assert_(z[7] is masked) + assert_(z[8] is not masked) + assert_(z[9] is not masked) + assert_(eq(x, z)) + z = where(c, masked, x) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is not masked) + assert_(z[7] is not masked) + assert_(z[8] is masked) + assert_(z[9] is masked) + z = masked_where(c, x) + assert_(z.dtype is x.dtype) + assert_(z[3] is masked) + assert_(z[4] is not masked) + assert_(z[7] is not masked) + assert_(z[8] is masked) + assert_(z[9] is masked) + assert_(eq(x, z)) + x = array([1., 2., 3., 4., 5.]) + c = array([1, 1, 1, 0, 0]) + x[2] = masked + z = where(c, x, -x) + assert_(eq(z, [1., 2., 0., -4., -5])) + c[0] = masked + z = where(c, x, -x) + assert_(eq(z, [1., 2., 0., -4., -5])) + assert_(z[0] is masked) + assert_(z[1] is not masked) + assert_(z[2] is masked) + assert_(eq(masked_where(greater(x, 2), x), masked_greater(x, 2))) + assert_(eq(masked_where(greater_equal(x, 2), x), + masked_greater_equal(x, 2))) + assert_(eq(masked_where(less(x, 2), x), masked_less(x, 2))) + assert_(eq(masked_where(less_equal(x, 2), x), masked_less_equal(x, 2))) + assert_(eq(masked_where(not_equal(x, 2), x), masked_not_equal(x, 2))) + assert_(eq(masked_where(equal(x, 2), x), masked_equal(x, 2))) + assert_(eq(masked_where(not_equal(x, 2), x), masked_not_equal(x, 2))) + assert_(eq(masked_inside(list(range(5)), 1, 3), [0, 199, 199, 199, 4])) + assert_(eq(masked_outside(list(range(5)), 1, 3), [199, 1, 2, 3, 199])) + assert_(eq(masked_inside(array(list(range(5)), + mask=[1, 0, 0, 0, 0]), 1, 3).mask, + [1, 1, 1, 1, 0])) + assert_(eq(masked_outside(array(list(range(5)), + mask=[0, 1, 0, 0, 0]), 1, 3).mask, + [1, 1, 0, 0, 1])) + assert_(eq(masked_equal(array(list(range(5)), + mask=[1, 0, 0, 0, 0]), 2).mask, + [1, 0, 1, 0, 0])) + assert_(eq(masked_not_equal(array([2, 2, 1, 2, 1], + mask=[1, 0, 0, 0, 0]), 2).mask, + [1, 0, 1, 0, 1])) + assert_(eq(masked_where([1, 1, 0, 0, 0], [1, 2, 3, 4, 5]), + [99, 99, 3, 4, 5])) + atest = ones((10, 10, 10), dtype=np.float32) + btest = zeros(atest.shape, MaskType) + ctest = masked_where(btest, atest) + assert_(eq(atest, ctest)) + z = choose(c, (-x, x)) + assert_(eq(z, [1., 2., 0., -4., -5])) + assert_(z[0] is masked) + assert_(z[1] is not masked) + assert_(z[2] is masked) + x = arange(6) + x[5] = masked + y = arange(6) * 10 + y[2] = masked + c = array([1, 1, 1, 0, 0, 0], mask=[1, 0, 0, 0, 0, 0]) + cm = c.filled(1) + z = where(c, x, y) + zm = where(cm, x, y) + assert_(eq(z, zm)) + assert_(getmask(zm) is nomask) + assert_(eq(zm, [0, 1, 2, 30, 40, 50])) + z = where(c, masked, 1) + assert_(eq(z, [99, 99, 99, 1, 1, 1])) + z = where(c, 1, masked) + assert_(eq(z, [99, 1, 1, 99, 99, 99])) + + def test_testMinMax2(self): + # Test of minimum, maximum. 
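+        # ma.minimum/ma.maximum compare elementwise and are mask-aware:
+        # a pair is masked when either operand is masked, and the
+        # .reduce forms skip masked entries, which is why, with x[3]
+        # masked below, minimum.reduce(x) is 0 and maximum.reduce(x)
+        # is 4.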
+ assert_(eq(minimum([1, 2, 3], [4, 0, 9]), [1, 0, 3])) + assert_(eq(maximum([1, 2, 3], [4, 0, 9]), [4, 2, 9])) + x = arange(5) + y = arange(5) - 2 + x[3] = masked + y[0] = masked + assert_(eq(minimum(x, y), where(less(x, y), x, y))) + assert_(eq(maximum(x, y), where(greater(x, y), x, y))) + assert_(minimum.reduce(x) == 0) + assert_(maximum.reduce(x) == 4) + + def test_testTakeTransposeInnerOuter(self): + # Test of take, transpose, inner, outer products + x = arange(24) + y = np.arange(24) + x[5:6] = masked + x = x.reshape(2, 3, 4) + y = y.reshape(2, 3, 4) + assert_(eq(np.transpose(y, (2, 0, 1)), transpose(x, (2, 0, 1)))) + assert_(eq(np.take(y, (2, 0, 1), 1), take(x, (2, 0, 1), 1))) + assert_(eq(np.inner(filled(x, 0), filled(y, 0)), + inner(x, y))) + assert_(eq(np.outer(filled(x, 0), filled(y, 0)), + outer(x, y))) + y = array(['abc', 1, 'def', 2, 3], object) + y[2] = masked + t = take(y, [0, 3, 4]) + assert_(t[0] == 'abc') + assert_(t[1] == 2) + assert_(t[2] == 3) + + def test_testInplace(self): + # Test of inplace operations and rich comparisons + y = arange(10) + + x = arange(10) + xm = arange(10) + xm[2] = masked + x += 1 + assert_(eq(x, y + 1)) + xm += 1 + assert_(eq(x, y + 1)) + + x = arange(10) + xm = arange(10) + xm[2] = masked + x -= 1 + assert_(eq(x, y - 1)) + xm -= 1 + assert_(eq(xm, y - 1)) + + x = arange(10) * 1.0 + xm = arange(10) * 1.0 + xm[2] = masked + x *= 2.0 + assert_(eq(x, y * 2)) + xm *= 2.0 + assert_(eq(xm, y * 2)) + + x = arange(10) * 2 + xm = arange(10) + xm[2] = masked + x //= 2 + assert_(eq(x, y)) + xm //= 2 + assert_(eq(x, y)) + + x = arange(10) * 1.0 + xm = arange(10) * 1.0 + xm[2] = masked + x /= 2.0 + assert_(eq(x, y / 2.0)) + xm /= arange(10) + assert_(eq(xm, ones((10,)))) + + x = arange(10).astype(np.float32) + xm = arange(10) + xm[2] = masked + x += 1. + assert_(eq(x, y + 1.)) + + def test_testPickle(self): + # Test of pickling + x = arange(12) + x[4:10:2] = masked + x = x.reshape(4, 3) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + s = pickle.dumps(x, protocol=proto) + y = pickle.loads(s) + assert_(eq(x, y)) + + def test_testMasked(self): + # Test of masked element + xx = arange(6) + xx[1] = masked + assert_(str(masked) == '--') + assert_(xx[1] is masked) + assert_equal(filled(xx[1], 0), 0) + + def test_testAverage1(self): + # Test of average. + ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0]) + assert_(eq(2.0, average(ott, axis=0))) + assert_(eq(2.0, average(ott, weights=[1., 1., 2., 1.]))) + result, wts = average(ott, weights=[1., 1., 2., 1.], returned=True) + assert_(eq(2.0, result)) + assert_(wts == 4.0) + ott[:] = masked + assert_(average(ott, axis=0) is masked) + ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0]) + ott = ott.reshape(2, 2) + ott[:, 1] = masked + assert_(eq(average(ott, axis=0), [2.0, 0.0])) + assert_(average(ott, axis=1)[0] is masked) + assert_(eq([2., 0.], average(ott, axis=0))) + result, wts = average(ott, axis=0, returned=True) + assert_(eq(wts, [1., 0.])) + + def test_testAverage2(self): + # More tests of average. + w1 = [0, 1, 1, 1, 1, 0] + w2 = [[0, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 1]] + x = arange(6) + assert_(allclose(average(x, axis=0), 2.5)) + assert_(allclose(average(x, axis=0, weights=w1), 2.5)) + y = array([arange(6), 2.0 * arange(6)]) + assert_(allclose(average(y, None), + np.add.reduce(np.arange(6)) * 3. / 12.)) + assert_(allclose(average(y, axis=0), np.arange(6) * 3. 
/ 2.)) + assert_(allclose(average(y, axis=1), + [average(x, axis=0), average(x, axis=0) * 2.0])) + assert_(allclose(average(y, None, weights=w2), 20. / 6.)) + assert_(allclose(average(y, axis=0, weights=w2), + [0., 1., 2., 3., 4., 10.])) + assert_(allclose(average(y, axis=1), + [average(x, axis=0), average(x, axis=0) * 2.0])) + m1 = zeros(6) + m2 = [0, 0, 1, 1, 0, 0] + m3 = [[0, 0, 1, 1, 0, 0], [0, 1, 1, 1, 1, 0]] + m4 = ones(6) + m5 = [0, 1, 1, 1, 1, 1] + assert_(allclose(average(masked_array(x, m1), axis=0), 2.5)) + assert_(allclose(average(masked_array(x, m2), axis=0), 2.5)) + assert_(average(masked_array(x, m4), axis=0) is masked) + assert_equal(average(masked_array(x, m5), axis=0), 0.0) + assert_equal(count(average(masked_array(x, m4), axis=0)), 0) + z = masked_array(y, m3) + assert_(allclose(average(z, None), 20. / 6.)) + assert_(allclose(average(z, axis=0), + [0., 1., 99., 99., 4.0, 7.5])) + assert_(allclose(average(z, axis=1), [2.5, 5.0])) + assert_(allclose(average(z, axis=0, weights=w2), + [0., 1., 99., 99., 4.0, 10.0])) + + a = arange(6) + b = arange(6) * 3 + r1, w1 = average([[a, b], [b, a]], axis=1, returned=True) + assert_equal(shape(r1), shape(w1)) + assert_equal(r1.shape, w1.shape) + r2, w2 = average(ones((2, 2, 3)), axis=0, weights=[3, 1], returned=True) + assert_equal(shape(w2), shape(r2)) + r2, w2 = average(ones((2, 2, 3)), returned=True) + assert_equal(shape(w2), shape(r2)) + r2, w2 = average(ones((2, 2, 3)), weights=ones((2, 2, 3)), returned=True) + assert_(shape(w2) == shape(r2)) + a2d = array([[1, 2], [0, 4]], float) + a2dm = masked_array(a2d, [[0, 0], [1, 0]]) + a2da = average(a2d, axis=0) + assert_(eq(a2da, [0.5, 3.0])) + a2dma = average(a2dm, axis=0) + assert_(eq(a2dma, [1.0, 3.0])) + a2dma = average(a2dm, axis=None) + assert_(eq(a2dma, 7. 
/ 3.)) + a2dma = average(a2dm, axis=1) + assert_(eq(a2dma, [1.5, 4.0])) + + def test_testToPython(self): + assert_equal(1, int(array(1))) + assert_equal(1.0, float(array(1))) + assert_equal(1, int(array([[[1]]]))) + assert_equal(1.0, float(array([[1]]))) + assert_raises(TypeError, float, array([1, 1])) + assert_raises(ValueError, bool, array([0, 1])) + assert_raises(ValueError, bool, array([0, 0], mask=[0, 1])) + + def test_testScalarArithmetic(self): + xm = array(0, mask=1) + # TODO FIXME: Find out what the following raises a warning in r8247 + with np.errstate(divide='ignore'): + assert_((1 / array(0)).mask) + assert_((1 + xm).mask) + assert_((-xm).mask) + assert_((-xm).mask) + assert_(maximum(xm, xm).mask) + assert_(minimum(xm, xm).mask) + assert_(xm.filled().dtype is xm._data.dtype) + x = array(0, mask=0) + assert_(x.filled() == x._data) + assert_equal(str(xm), str(masked_print_option)) + + def test_testArrayMethods(self): + a = array([1, 3, 2]) + assert_(eq(a.any(), a._data.any())) + assert_(eq(a.all(), a._data.all())) + assert_(eq(a.argmax(), a._data.argmax())) + assert_(eq(a.argmin(), a._data.argmin())) + assert_(eq(a.choose(0, 1, 2, 3, 4), + a._data.choose(0, 1, 2, 3, 4))) + assert_(eq(a.compress([1, 0, 1]), a._data.compress([1, 0, 1]))) + assert_(eq(a.conj(), a._data.conj())) + assert_(eq(a.conjugate(), a._data.conjugate())) + m = array([[1, 2], [3, 4]]) + assert_(eq(m.diagonal(), m._data.diagonal())) + assert_(eq(a.sum(), a._data.sum())) + assert_(eq(a.take([1, 2]), a._data.take([1, 2]))) + assert_(eq(m.transpose(), m._data.transpose())) + + def test_testArrayAttributes(self): + a = array([1, 3, 2]) + assert_equal(a.ndim, 1) + + def test_testAPI(self): + assert_(not [m for m in dir(np.ndarray) + if m not in dir(MaskedArray) and + not m.startswith('_')]) + + def test_testSingleElementSubscript(self): + a = array([1, 3, 2]) + b = array([1, 3, 2], mask=[1, 0, 1]) + assert_equal(a[0].shape, ()) + assert_equal(b[0].shape, ()) + assert_equal(b[1].shape, ()) + + def test_assignment_by_condition(self): + # Test for gh-18951 + a = array([1, 2, 3, 4], mask=[1, 0, 1, 0]) + c = a >= 3 + a[c] = 5 + assert_(a[2] is masked) + + def test_assignment_by_condition_2(self): + # gh-19721 + a = masked_array([0, 1], mask=[False, False]) + b = masked_array([0, 1], mask=[True, True]) + mask = a < 1 + b[mask] = a[mask] + expected_mask = [False, True] + assert_equal(b.mask, expected_mask) + + +class TestUfuncs: + def setup_method(self): + self.d = (array([1.0, 0, -1, pi / 2] * 2, mask=[0, 1] + [0] * 6), + array([1.0, 0, -1, pi / 2] * 2, mask=[1, 0] + [0] * 6),) + + def test_testUfuncRegression(self): + f_invalid_ignore = [ + 'sqrt', 'arctanh', 'arcsin', 'arccos', + 'arccosh', 'arctanh', 'log', 'log10', 'divide', + 'true_divide', 'floor_divide', 'remainder', 'fmod'] + for f in ['sqrt', 'log', 'log10', 'exp', 'conjugate', + 'sin', 'cos', 'tan', + 'arcsin', 'arccos', 'arctan', + 'sinh', 'cosh', 'tanh', + 'arcsinh', + 'arccosh', + 'arctanh', + 'absolute', 'fabs', 'negative', + 'floor', 'ceil', + 'logical_not', + 'add', 'subtract', 'multiply', + 'divide', 'true_divide', 'floor_divide', + 'remainder', 'fmod', 'hypot', 'arctan2', + 'equal', 'not_equal', 'less_equal', 'greater_equal', + 'less', 'greater', + 'logical_and', 'logical_or', 'logical_xor']: + try: + uf = getattr(umath, f) + except AttributeError: + uf = getattr(fromnumeric, f) + mf = getattr(np.ma, f) + args = self.d[:uf.nin] + with np.errstate(): + if f in f_invalid_ignore: + np.seterr(invalid='ignore') + if f in ['arctanh', 'log', 'log10']: + 
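+                    # np.seterr here is scoped by the enclosing bare
+                    # np.errstate() context manager, which snapshots the
+                    # error state on entry and restores it on exit, so
+                    # the ignored warnings do not leak out of this
+                    # iteration.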
np.seterr(divide='ignore') + ur = uf(*args) + mr = mf(*args) + assert_(eq(ur.filled(0), mr.filled(0), f)) + assert_(eqmask(ur.mask, mr.mask)) + + def test_reduce(self): + a = self.d[0] + assert_(not alltrue(a, axis=0)) + assert_(sometrue(a, axis=0)) + assert_equal(sum(a[:3], axis=0), 0) + assert_equal(product(a, axis=0), 0) + + def test_minmax(self): + a = arange(1, 13).reshape(3, 4) + amask = masked_where(a < 5, a) + assert_equal(amask.max(), a.max()) + assert_equal(amask.min(), 5) + assert_((amask.max(0) == a.max(0)).all()) + assert_((amask.min(0) == [5, 6, 7, 8]).all()) + assert_(amask.max(1)[0].mask) + assert_(amask.min(1)[0].mask) + + def test_nonzero(self): + for t in "?bhilqpBHILQPfdgFDGO": + x = array([1, 0, 2, 0], mask=[0, 0, 1, 1]) + assert_(eq(nonzero(x), [0])) + + +class TestArrayMethods: + + def setup_method(self): + x = np.array([8.375, 7.545, 8.828, 8.5, 1.757, 5.928, + 8.43, 7.78, 9.865, 5.878, 8.979, 4.732, + 3.012, 6.022, 5.095, 3.116, 5.238, 3.957, + 6.04, 9.63, 7.712, 3.382, 4.489, 6.479, + 7.189, 9.645, 5.395, 4.961, 9.894, 2.893, + 7.357, 9.828, 6.272, 3.758, 6.693, 0.993]) + X = x.reshape(6, 6) + XX = x.reshape(3, 2, 2, 3) + + m = np.array([0, 1, 0, 1, 0, 0, + 1, 0, 1, 1, 0, 1, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0]) + mx = array(data=x, mask=m) + mX = array(data=X, mask=m.reshape(X.shape)) + mXX = array(data=XX, mask=m.reshape(XX.shape)) + + self.d = (x, X, XX, m, mx, mX, mXX) + + def test_trace(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + mXdiag = mX.diagonal() + assert_equal(mX.trace(), mX.diagonal().compressed().sum()) + assert_(eq(mX.trace(), + X.trace() - sum(mXdiag.mask * X.diagonal(), + axis=0))) + + def test_clip(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + clipped = mx.clip(2, 8) + assert_(eq(clipped.mask, mx.mask)) + assert_(eq(clipped._data, x.clip(2, 8))) + assert_(eq(clipped._data, mx._data.clip(2, 8))) + + def test_ptp(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + (n, m) = X.shape + # print(type(mx), mx.compressed()) + # raise Exception() + assert_equal(mx.ptp(), np.ptp(mx.compressed())) + rows = np.zeros(n, np.float64) + cols = np.zeros(m, np.float64) + for k in range(m): + cols[k] = np.ptp(mX[:, k].compressed()) + for k in range(n): + rows[k] = np.ptp(mX[k].compressed()) + assert_(eq(mX.ptp(0), cols)) + assert_(eq(mX.ptp(1), rows)) + + def test_swapaxes(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + mXswapped = mX.swapaxes(0, 1) + assert_(eq(mXswapped[-1], mX[:, -1])) + mXXswapped = mXX.swapaxes(0, 2) + assert_equal(mXXswapped.shape, (2, 2, 3, 3)) + + def test_cumprod(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + mXcp = mX.cumprod(0) + assert_(eq(mXcp._data, mX.filled(1).cumprod(0))) + mXcp = mX.cumprod(1) + assert_(eq(mXcp._data, mX.filled(1).cumprod(1))) + + def test_cumsum(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + mXcp = mX.cumsum(0) + assert_(eq(mXcp._data, mX.filled(0).cumsum(0))) + mXcp = mX.cumsum(1) + assert_(eq(mXcp._data, mX.filled(0).cumsum(1))) + + def test_varstd(self): + (x, X, XX, m, mx, mX, mXX,) = self.d + assert_(eq(mX.var(axis=None), mX.compressed().var())) + assert_(eq(mX.std(axis=None), mX.compressed().std())) + assert_(eq(mXX.var(axis=3).shape, XX.var(axis=3).shape)) + assert_(eq(mX.var().shape, X.var().shape)) + (mXvar0, mXvar1) = (mX.var(axis=0), mX.var(axis=1)) + for k in range(6): + assert_(eq(mXvar1[k], mX[k].compressed().var())) + assert_(eq(mXvar0[k], mX[:, k].compressed().var())) + assert_(eq(np.sqrt(mXvar0[k]), + mX[:, k].compressed().std())) + + +def eqmask(m1, 
m2): + if m1 is nomask: + return m2 is nomask + if m2 is nomask: + return m1 is nomask + return (m1 == m2).all() diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_regression.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..025387ba454c480df5f7a8380e2f317b37a7e75d --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_regression.py @@ -0,0 +1,100 @@ +import numpy as np +from numpy.testing import ( + assert_, + assert_allclose, + assert_array_equal, + suppress_warnings, +) + + +class TestRegression: + def test_masked_array_create(self): + # Ticket #17 + x = np.ma.masked_array([0, 1, 2, 3, 0, 4, 5, 6], + mask=[0, 0, 0, 1, 1, 1, 0, 0]) + assert_array_equal(np.ma.nonzero(x), [[1, 2, 6, 7]]) + + def test_masked_array(self): + # Ticket #61 + np.ma.array(1, mask=[1]) + + def test_mem_masked_where(self): + # Ticket #62 + from numpy.ma import MaskType, masked_where + a = np.zeros((1, 1)) + b = np.zeros(a.shape, MaskType) + c = masked_where(b, a) + a - c + + def test_masked_array_multiply(self): + # Ticket #254 + a = np.ma.zeros((4, 1)) + a[2, 0] = np.ma.masked + b = np.zeros((4, 2)) + a * b + b * a + + def test_masked_array_repeat(self): + # Ticket #271 + np.ma.array([1], mask=False).repeat(10) + + def test_masked_array_repr_unicode(self): + # Ticket #1256 + repr(np.ma.array("Unicode")) + + def test_atleast_2d(self): + # Ticket #1559 + a = np.ma.masked_array([0.0, 1.2, 3.5], mask=[False, True, False]) + b = np.atleast_2d(a) + assert_(a.mask.ndim == 1) + assert_(b.mask.ndim == 2) + + def test_set_fill_value_unicode_py3(self): + # Ticket #2733 + a = np.ma.masked_array(['a', 'b', 'c'], mask=[1, 0, 0]) + a.fill_value = 'X' + assert_(a.fill_value == 'X') + + def test_var_sets_maskedarray_scalar(self): + # Issue gh-2757 + a = np.ma.array(np.arange(5), mask=True) + mout = np.ma.array(-1, dtype=float) + a.var(out=mout) + assert_(mout._data == 0) + + def test_ddof_corrcoef(self): + # See gh-3336 + x = np.ma.masked_equal([1, 2, 3, 4, 5], 4) + y = np.array([2, 2.5, 3.1, 3, 5]) + # this test can be removed after deprecation. + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, "bias and ddof have no effect") + r0 = np.ma.corrcoef(x, y, ddof=0) + r1 = np.ma.corrcoef(x, y, ddof=1) + # ddof should not have an effect (it gets cancelled out) + assert_allclose(r0.data, r1.data) + + def test_mask_not_backmangled(self): + # See gh-10314. Test case taken from gh-3140. + a = np.ma.MaskedArray([1., 2.], mask=[False, False]) + assert_(a.mask.shape == (2,)) + b = np.tile(a, (2, 1)) + # Check that the above no longer changes a.shape to (1, 2) + assert_(a.mask.shape == (2,)) + assert_(b.shape == (2, 2)) + assert_(b.mask.shape == (2, 2)) + + def test_empty_list_on_structured(self): + # See gh-12464. Indexing with empty list should give empty result. 
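+        # e.g. ma[[]] is expected to be an empty masked array with the
+        # same structured 'i4,f4' dtype, identical to the empty slice
+        # ma[:0], rather than an error.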
+ ma = np.ma.MaskedArray([(1, 1.), (2, 2.), (3, 3.)], dtype='i4,f4') + assert_array_equal(ma[[]], ma[:0]) + + def test_masked_array_tobytes_fortran(self): + ma = np.ma.arange(4).reshape((2, 2)) + assert_array_equal(ma.tobytes(order='F'), ma.T.tobytes()) + + def test_structured_array(self): + # see gh-22041 + np.ma.array((1, (b"", b"")), + dtype=[("x", np.int_), + ("y", [("i", np.void), ("j", np.void)])]) diff --git a/venv/lib/python3.13/site-packages/numpy/ma/tests/test_subclassing.py b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_subclassing.py new file mode 100644 index 0000000000000000000000000000000000000000..3364e563097e94a197438b4f2ba258b94ee87102 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/ma/tests/test_subclassing.py @@ -0,0 +1,469 @@ +"""Tests suite for MaskedArray & subclassing. + +:author: Pierre Gerard-Marchant +:contact: pierregm_at_uga_dot_edu + +""" +import numpy as np +from numpy.lib.mixins import NDArrayOperatorsMixin +from numpy.ma.core import ( + MaskedArray, + add, + arange, + array, + asanyarray, + asarray, + divide, + hypot, + log, + masked, + masked_array, + nomask, +) +from numpy.ma.testutils import assert_equal +from numpy.testing import assert_, assert_raises + +# from numpy.ma.core import ( + +def assert_startswith(a, b): + # produces a better error message than assert_(a.startswith(b)) + assert_equal(a[:len(b)], b) + +class SubArray(np.ndarray): + # Defines a generic np.ndarray subclass, that stores some metadata + # in the dictionary `info`. + def __new__(cls, arr, info={}): + x = np.asanyarray(arr).view(cls) + x.info = info.copy() + return x + + def __array_finalize__(self, obj): + super().__array_finalize__(obj) + self.info = getattr(obj, 'info', {}).copy() + + def __add__(self, other): + result = super().__add__(other) + result.info['added'] = result.info.get('added', 0) + 1 + return result + + def __iadd__(self, other): + result = super().__iadd__(other) + result.info['iadded'] = result.info.get('iadded', 0) + 1 + return result + + +subarray = SubArray + + +class SubMaskedArray(MaskedArray): + """Pure subclass of MaskedArray, keeping some info on subclass.""" + def __new__(cls, info=None, **kwargs): + obj = super().__new__(cls, **kwargs) + obj._optinfo['info'] = info + return obj + + +class MSubArray(SubArray, MaskedArray): + + def __new__(cls, data, info={}, mask=nomask): + subarr = SubArray(data, info) + _data = MaskedArray.__new__(cls, data=subarr, mask=mask) + _data.info = subarr.info + return _data + + @property + def _series(self): + _view = self.view(MaskedArray) + _view._sharedmask = False + return _view + + +msubarray = MSubArray + + +# Also a subclass that overrides __str__, __repr__ and __setitem__, disallowing +# setting to non-class values (and thus np.ma.core.masked_print_option) +# and overrides __array_wrap__, updating the info dict, to check that this +# doesn't get destroyed by MaskedArray._update_from. But this one also needs +# its own iterator... 
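+# (A plain ndarray.flat iterator bypasses the subclass __getitem__ and
+# __setitem__, see the gh-4564 link in the docstring below, so an
+# assignment through `.flat` would skip _validate_input entirely
+# without the custom CSAIterator that follows.)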
+class CSAIterator: + """ + Flat iterator object that uses its own setter/getter + (works around ndarray.flat not propagating subclass setters/getters + see https://github.com/numpy/numpy/issues/4564) + roughly following MaskedIterator + """ + def __init__(self, a): + self._original = a + self._dataiter = a.view(np.ndarray).flat + + def __iter__(self): + return self + + def __getitem__(self, indx): + out = self._dataiter.__getitem__(indx) + if not isinstance(out, np.ndarray): + out = out.__array__() + out = out.view(type(self._original)) + return out + + def __setitem__(self, index, value): + self._dataiter[index] = self._original._validate_input(value) + + def __next__(self): + return next(self._dataiter).__array__().view(type(self._original)) + + +class ComplicatedSubArray(SubArray): + + def __str__(self): + return f'myprefix {self.view(SubArray)} mypostfix' + + def __repr__(self): + # Return a repr that does not start with 'name(' + return f'<{self.__class__.__name__} {self}>' + + def _validate_input(self, value): + if not isinstance(value, ComplicatedSubArray): + raise ValueError("Can only set to MySubArray values") + return value + + def __setitem__(self, item, value): + # validation ensures direct assignment with ndarray or + # masked_print_option will fail + super().__setitem__(item, self._validate_input(value)) + + def __getitem__(self, item): + # ensure getter returns our own class also for scalars + value = super().__getitem__(item) + if not isinstance(value, np.ndarray): # scalar + value = value.__array__().view(ComplicatedSubArray) + return value + + @property + def flat(self): + return CSAIterator(self) + + @flat.setter + def flat(self, value): + y = self.ravel() + y[:] = value + + def __array_wrap__(self, obj, context=None, return_scalar=False): + obj = super().__array_wrap__(obj, context, return_scalar) + if context is not None and context[0] is np.multiply: + obj.info['multiplied'] = obj.info.get('multiplied', 0) + 1 + + return obj + + +class WrappedArray(NDArrayOperatorsMixin): + """ + Wrapping a MaskedArray rather than subclassing to test that + ufunc deferrals are commutative. + See: https://github.com/numpy/numpy/issues/15200) + """ + __slots__ = ('_array', 'attrs') + __array_priority__ = 20 + + def __init__(self, array, **attrs): + self._array = array + self.attrs = attrs + + def __repr__(self): + return f"{self.__class__.__name__}(\n{self._array}\n{self.attrs}\n)" + + def __array__(self, dtype=None, copy=None): + return np.asarray(self._array) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == '__call__': + inputs = [arg._array if isinstance(arg, self.__class__) else arg + for arg in inputs] + return self.__class__(ufunc(*inputs, **kwargs), **self.attrs) + else: + return NotImplemented + + +class TestSubclassing: + # Test suite for masked subclasses of ndarray. + + def setup_method(self): + x = np.arange(5, dtype='float') + mx = msubarray(x, mask=[0, 1, 0, 0, 0]) + self.data = (x, mx) + + def test_data_subclassing(self): + # Tests whether the subclass is kept. 
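+        # masked_array(xsub) should wrap rather than discard the
+        # subclass: the result is a MaskedArray, but its ._data keeps
+        # the SubArray type, so subclass attributes stay reachable.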
+ x = np.arange(5) + m = [0, 0, 1, 0, 0] + xsub = SubArray(x) + xmsub = masked_array(xsub, mask=m) + assert_(isinstance(xmsub, MaskedArray)) + assert_equal(xmsub._data, xsub) + assert_(isinstance(xmsub._data, SubArray)) + + def test_maskedarray_subclassing(self): + # Tests subclassing MaskedArray + (x, mx) = self.data + assert_(isinstance(mx._data, subarray)) + + def test_masked_unary_operations(self): + # Tests masked_unary_operation + (x, mx) = self.data + with np.errstate(divide='ignore'): + assert_(isinstance(log(mx), msubarray)) + assert_equal(log(x), np.log(x)) + + def test_masked_binary_operations(self): + # Tests masked_binary_operation + (x, mx) = self.data + # Result should be a msubarray + assert_(isinstance(add(mx, mx), msubarray)) + assert_(isinstance(add(mx, x), msubarray)) + # Result should work + assert_equal(add(mx, x), mx + x) + assert_(isinstance(add(mx, mx)._data, subarray)) + assert_(isinstance(add.outer(mx, mx), msubarray)) + assert_(isinstance(hypot(mx, mx), msubarray)) + assert_(isinstance(hypot(mx, x), msubarray)) + + def test_masked_binary_operations2(self): + # Tests domained_masked_binary_operation + (x, mx) = self.data + xmx = masked_array(mx.data.__array__(), mask=mx.mask) + assert_(isinstance(divide(mx, mx), msubarray)) + assert_(isinstance(divide(mx, x), msubarray)) + assert_equal(divide(mx, mx), divide(xmx, xmx)) + + def test_attributepropagation(self): + x = array(arange(5), mask=[0] + [1] * 4) + my = masked_array(subarray(x)) + ym = msubarray(x) + # + z = (my + 1) + assert_(isinstance(z, MaskedArray)) + assert_(not isinstance(z, MSubArray)) + assert_(isinstance(z._data, SubArray)) + assert_equal(z._data.info, {}) + # + z = (ym + 1) + assert_(isinstance(z, MaskedArray)) + assert_(isinstance(z, MSubArray)) + assert_(isinstance(z._data, SubArray)) + assert_(z._data.info['added'] > 0) + # Test that inplace methods from data get used (gh-4617) + ym += 1 + assert_(isinstance(ym, MaskedArray)) + assert_(isinstance(ym, MSubArray)) + assert_(isinstance(ym._data, SubArray)) + assert_(ym._data.info['iadded'] > 0) + # + ym._set_mask([1, 0, 0, 0, 1]) + assert_equal(ym._mask, [1, 0, 0, 0, 1]) + ym._series._set_mask([0, 0, 0, 0, 1]) + assert_equal(ym._mask, [0, 0, 0, 0, 1]) + # + xsub = subarray(x, info={'name': 'x'}) + mxsub = masked_array(xsub) + assert_(hasattr(mxsub, 'info')) + assert_equal(mxsub.info, xsub.info) + + def test_subclasspreservation(self): + # Checks that masked_array(...,subok=True) preserves the class. 
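+        # subok=False (like ma.asarray) downcasts to a plain
+        # MaskedArray, while subok=True (like ma.asanyarray) preserves
+        # MSubArray together with extra attributes such as `info`;
+        # both paths are checked below.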
+ x = np.arange(5) + m = [0, 0, 1, 0, 0] + xinfo = list(zip(x, m)) + xsub = MSubArray(x, mask=m, info={'xsub': xinfo}) + # + mxsub = masked_array(xsub, subok=False) + assert_(not isinstance(mxsub, MSubArray)) + assert_(isinstance(mxsub, MaskedArray)) + assert_equal(mxsub._mask, m) + # + mxsub = asarray(xsub) + assert_(not isinstance(mxsub, MSubArray)) + assert_(isinstance(mxsub, MaskedArray)) + assert_equal(mxsub._mask, m) + # + mxsub = masked_array(xsub, subok=True) + assert_(isinstance(mxsub, MSubArray)) + assert_equal(mxsub.info, xsub.info) + assert_equal(mxsub._mask, xsub._mask) + # + mxsub = asanyarray(xsub) + assert_(isinstance(mxsub, MSubArray)) + assert_equal(mxsub.info, xsub.info) + assert_equal(mxsub._mask, m) + + def test_subclass_items(self): + """test that getter and setter go via baseclass""" + x = np.arange(5) + xcsub = ComplicatedSubArray(x) + mxcsub = masked_array(xcsub, mask=[True, False, True, False, False]) + # getter should return a ComplicatedSubArray, even for single item + # first check we wrote ComplicatedSubArray correctly + assert_(isinstance(xcsub[1], ComplicatedSubArray)) + assert_(isinstance(xcsub[1, ...], ComplicatedSubArray)) + assert_(isinstance(xcsub[1:4], ComplicatedSubArray)) + + # now that it propagates inside the MaskedArray + assert_(isinstance(mxcsub[1], ComplicatedSubArray)) + assert_(isinstance(mxcsub[1, ...].data, ComplicatedSubArray)) + assert_(mxcsub[0] is masked) + assert_(isinstance(mxcsub[0, ...].data, ComplicatedSubArray)) + assert_(isinstance(mxcsub[1:4].data, ComplicatedSubArray)) + + # also for flattened version (which goes via MaskedIterator) + assert_(isinstance(mxcsub.flat[1].data, ComplicatedSubArray)) + assert_(mxcsub.flat[0] is masked) + assert_(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray)) + + # setter should only work with ComplicatedSubArray input + # first check we wrote ComplicatedSubArray correctly + assert_raises(ValueError, xcsub.__setitem__, 1, x[4]) + # now that it propagates inside the MaskedArray + assert_raises(ValueError, mxcsub.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4]) + mxcsub[1] = xcsub[4] + mxcsub[1:4] = xcsub[1:4] + # also for flattened version (which goes via MaskedIterator) + assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4]) + mxcsub.flat[1] = xcsub[4] + mxcsub.flat[1:4] = xcsub[1:4] + + def test_subclass_nomask_items(self): + x = np.arange(5) + xcsub = ComplicatedSubArray(x) + mxcsub_nomask = masked_array(xcsub) + + assert_(isinstance(mxcsub_nomask[1, ...].data, ComplicatedSubArray)) + assert_(isinstance(mxcsub_nomask[0, ...].data, ComplicatedSubArray)) + + assert_(isinstance(mxcsub_nomask[1], ComplicatedSubArray)) + assert_(isinstance(mxcsub_nomask[0], ComplicatedSubArray)) + + def test_subclass_repr(self): + """test that repr uses the name of the subclass + and 'array' for np.ndarray""" + x = np.arange(5) + mx = masked_array(x, mask=[True, False, True, False, False]) + assert_startswith(repr(mx), 'masked_array') + xsub = SubArray(x) + mxsub = masked_array(xsub, mask=[True, False, True, False, False]) + assert_startswith(repr(mxsub), + f'masked_{SubArray.__name__}(data=[--, 1, --, 3, 4]') + + def test_subclass_str(self): + """test str with subclass that has overridden str, setitem""" + # first without override + x = np.arange(5) + xsub = SubArray(x) + mxsub = masked_array(xsub, mask=[True, False, True, False, False]) + assert_equal(str(mxsub), '[-- 1 -- 3 4]') + + xcsub 
= ComplicatedSubArray(x) + assert_raises(ValueError, xcsub.__setitem__, 0, + np.ma.core.masked_print_option) + mxcsub = masked_array(xcsub, mask=[True, False, True, False, False]) + assert_equal(str(mxcsub), 'myprefix [-- 1 -- 3 4] mypostfix') + + def test_pure_subclass_info_preservation(self): + # Test that ufuncs and methods conserve extra information consistently; + # see gh-7122. + arr1 = SubMaskedArray('test', data=[1, 2, 3, 4, 5, 6]) + arr2 = SubMaskedArray(data=[0, 1, 2, 3, 4, 5]) + diff1 = np.subtract(arr1, arr2) + assert_('info' in diff1._optinfo) + assert_(diff1._optinfo['info'] == 'test') + diff2 = arr1 - arr2 + assert_('info' in diff2._optinfo) + assert_(diff2._optinfo['info'] == 'test') + + +class ArrayNoInheritance: + """Quantity-like class that does not inherit from ndarray""" + def __init__(self, data, units): + self.magnitude = data + self.units = units + + def __getattr__(self, attr): + return getattr(self.magnitude, attr) + + +def test_array_no_inheritance(): + data_masked = np.ma.array([1, 2, 3], mask=[True, False, True]) + data_masked_units = ArrayNoInheritance(data_masked, 'meters') + + # Get the masked representation of the Quantity-like class + new_array = np.ma.array(data_masked_units) + assert_equal(data_masked.data, new_array.data) + assert_equal(data_masked.mask, new_array.mask) + # Test sharing the mask + data_masked.mask = [True, False, False] + assert_equal(data_masked.mask, new_array.mask) + assert_(new_array.sharedmask) + + # Get the masked representation of the Quantity-like class + new_array = np.ma.array(data_masked_units, copy=True) + assert_equal(data_masked.data, new_array.data) + assert_equal(data_masked.mask, new_array.mask) + # Test that the mask is not shared when copy=True + data_masked.mask = [True, False, True] + assert_equal([True, False, False], new_array.mask) + assert_(not new_array.sharedmask) + + # Get the masked representation of the Quantity-like class + new_array = np.ma.array(data_masked_units, keep_mask=False) + assert_equal(data_masked.data, new_array.data) + # The change did not affect the original mask + assert_equal(data_masked.mask, [True, False, True]) + # Test that the mask is False and not shared when keep_mask=False + assert_(not new_array.mask) + assert_(not new_array.sharedmask) + + +class TestClassWrapping: + # Test suite for classes that wrap MaskedArrays + + def setup_method(self): + m = np.ma.masked_array([1, 3, 5], mask=[False, True, False]) + wm = WrappedArray(m) + self.data = (m, wm) + + def test_masked_unary_operations(self): + # Tests masked_unary_operation + (m, wm) = self.data + with np.errstate(divide='ignore'): + assert_(isinstance(np.log(wm), WrappedArray)) + + def test_masked_binary_operations(self): + # Tests masked_binary_operation + (m, wm) = self.data + # Result should be a WrappedArray + assert_(isinstance(np.add(wm, wm), WrappedArray)) + assert_(isinstance(np.add(m, wm), WrappedArray)) + assert_(isinstance(np.add(wm, m), WrappedArray)) + # add and '+' should call the same ufunc + assert_equal(np.add(m, wm), m + wm) + assert_(isinstance(np.hypot(m, wm), WrappedArray)) + assert_(isinstance(np.hypot(wm, m), WrappedArray)) + # Test domained binary operations + assert_(isinstance(np.divide(wm, m), WrappedArray)) + assert_(isinstance(np.divide(m, wm), WrappedArray)) + assert_equal(np.divide(wm, m) * m, np.divide(m, m) * wm) + # Test broadcasting + m2 = np.stack([m, m]) + assert_(isinstance(np.divide(wm, m2), WrappedArray)) + assert_(isinstance(np.divide(m2, wm), WrappedArray)) + assert_equal(np.divide(m2, 
wm), np.divide(wm, m2)) + + def test_mixins_have_slots(self): + mixin = NDArrayOperatorsMixin() + # Should raise an error + assert_raises(AttributeError, mixin.__setattr__, "not_a_real_attr", 1) + + m = np.ma.masked_array([1, 3, 5], mask=[False, True, False]) + wm = WrappedArray(m) + assert_raises(AttributeError, wm.__setattr__, "not_an_attr", 2) diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89ebfe55e44185d38c711887b53d1a66f4eded29 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/defmatrix.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/defmatrix.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3cf7930427d7ecc51718ce67f2d1a5eed9d0b2 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/__pycache__/defmatrix.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__init__.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bd8be35fb79596fdf5c475a3d01d51cea62228e Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_defmatrix.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_defmatrix.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6deb1226920ad94c64a0753f70f1d3a3dea08521 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_defmatrix.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_interaction.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_interaction.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9278c3a37aa587a276a5de24c601dc7cda64bb4e Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_interaction.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_masked_matrix.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_masked_matrix.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f17e2c0c76a7b9c61578b0c79f04196dccc3c86 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_masked_matrix.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_matrix_linalg.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_matrix_linalg.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f8c55e0681452411fef622acf48fa175f544534d Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_matrix_linalg.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_multiarray.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_multiarray.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab387ab99e03543a3b4cd1179b8cd0679d8b42f Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_multiarray.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_numeric.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_numeric.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c8a5c6bdc865a436551ab35b48196341ae89c8 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_numeric.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_regression.cpython-313.pyc b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_regression.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80ee2373dabb92ea10e8f1a780212c327237c692 Binary files /dev/null and b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/__pycache__/test_regression.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_defmatrix.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_defmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..ce23933ab7f72ece4510c96bc39bb69e87de2a92 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_defmatrix.py @@ -0,0 +1,455 @@ +import collections.abc + +import numpy as np +from numpy import asmatrix, bmat, matrix +from numpy.linalg import matrix_power +from numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises, +) + + +class TestCtor: + def test_basic(self): + A = np.array([[1, 2], [3, 4]]) + mA = matrix(A) + assert_(np.all(mA.A == A)) + + B = bmat("A,A;A,A") + C = bmat([[A, A], [A, A]]) + D = np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + assert_(np.all(B.A == D)) + assert_(np.all(C.A == D)) + + E = np.array([[5, 6], [7, 8]]) + AEresult = matrix([[1, 2, 5, 6], [3, 4, 7, 8]]) + assert_(np.all(bmat([A, E]) == AEresult)) + + vec = np.arange(5) + mvec = matrix(vec) + assert_(mvec.shape == (1, 5)) + + def test_exceptions(self): + # Check for ValueError when called with invalid string data. 
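+ # (editorial sketch, not upstream code) the string form is MATLAB-like: + # rows split on ';' and entries on whitespace or ','; for instance + assert_array_equal(matrix("1 2; 3 4"), np.array([[1, 2], [3, 4]])) + # whereas a token with no numeric reading, as below, must raise. 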
+ assert_raises(ValueError, matrix, "invalid") + + def test_bmat_nondefault_str(self): + A = np.array([[1, 2], [3, 4]]) + B = np.array([[5, 6], [7, 8]]) + Aresult = np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + mixresult = np.array([[1, 2, 5, 6], + [3, 4, 7, 8], + [5, 6, 1, 2], + [7, 8, 3, 4]]) + assert_(np.all(bmat("A,A;A,A") == Aresult)) + assert_(np.all(bmat("A,A;A,A", ldict={'A': B}) == Aresult)) + assert_raises(TypeError, bmat, "A,A;A,A", gdict={'A': B}) + assert_( + np.all(bmat("A,A;A,A", ldict={'A': A}, gdict={'A': B}) == Aresult)) + b2 = bmat("A,B;C,D", ldict={'A': A, 'B': B}, gdict={'C': B, 'D': A}) + assert_(np.all(b2 == mixresult)) + + +class TestProperties: + def test_sum(self): + """Test whether matrix.sum(axis=1) preserves orientation. + Fails in NumPy <= 0.9.6.2127. + """ + M = matrix([[1, 2, 0, 0], + [3, 4, 0, 0], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + sum0 = matrix([8, 12, 4, 6]) + sum1 = matrix([3, 7, 6, 14]).T + sumall = 30 + assert_array_equal(sum0, M.sum(axis=0)) + assert_array_equal(sum1, M.sum(axis=1)) + assert_equal(sumall, M.sum()) + + assert_array_equal(sum0, np.sum(M, axis=0)) + assert_array_equal(sum1, np.sum(M, axis=1)) + assert_equal(sumall, np.sum(M)) + + def test_prod(self): + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(x.prod(), 720) + assert_equal(x.prod(0), matrix([[4, 10, 18]])) + assert_equal(x.prod(1), matrix([[6], [120]])) + + assert_equal(np.prod(x), 720) + assert_equal(np.prod(x, axis=0), matrix([[4, 10, 18]])) + assert_equal(np.prod(x, axis=1), matrix([[6], [120]])) + + y = matrix([0, 1, 3]) + assert_(y.prod() == 0) + + def test_max(self): + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(x.max(), 6) + assert_equal(x.max(0), matrix([[4, 5, 6]])) + assert_equal(x.max(1), matrix([[3], [6]])) + + assert_equal(np.max(x), 6) + assert_equal(np.max(x, axis=0), matrix([[4, 5, 6]])) + assert_equal(np.max(x, axis=1), matrix([[3], [6]])) + + def test_min(self): + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(x.min(), 1) + assert_equal(x.min(0), matrix([[1, 2, 3]])) + assert_equal(x.min(1), matrix([[1], [4]])) + + assert_equal(np.min(x), 1) + assert_equal(np.min(x, axis=0), matrix([[1, 2, 3]])) + assert_equal(np.min(x, axis=1), matrix([[1], [4]])) + + def test_ptp(self): + x = np.arange(4).reshape((2, 2)) + mx = x.view(np.matrix) + assert_(mx.ptp() == 3) + assert_(np.all(mx.ptp(0) == np.array([2, 2]))) + assert_(np.all(mx.ptp(1) == np.array([1, 1]))) + + def test_var(self): + x = np.arange(9).reshape((3, 3)) + mx = x.view(np.matrix) + assert_equal(x.var(ddof=0), mx.var(ddof=0)) + assert_equal(x.var(ddof=1), mx.var(ddof=1)) + + def test_basic(self): + import numpy.linalg as linalg + + A = np.array([[1., 2.], + [3., 4.]]) + mA = matrix(A) + assert_(np.allclose(linalg.inv(A), mA.I)) + assert_(np.all(np.array(np.transpose(A) == mA.T))) + assert_(np.all(np.array(np.transpose(A) == mA.H))) + assert_(np.all(A == mA.A)) + + B = A + 2j * A + mB = matrix(B) + assert_(np.allclose(linalg.inv(B), mB.I)) + assert_(np.all(np.array(np.transpose(B) == mB.T))) + assert_(np.all(np.array(np.transpose(B).conj() == mB.H))) + + def test_pinv(self): + x = matrix(np.arange(6).reshape(2, 3)) + xpinv = matrix([[-0.77777778, 0.27777778], + [-0.11111111, 0.11111111], + [ 0.55555556, -0.05555556]]) + assert_almost_equal(x.I, xpinv) + + def test_comparisons(self): + A = np.arange(100).reshape(10, 10) + mA = matrix(A) + mB = matrix(A) + 0.1 + assert_(np.all(mB == A + 0.1)) + assert_(np.all(mB == matrix(A + 0.1))) + assert_(not np.any(mB == matrix(A - 
0.1))) + assert_(np.all(mA < mB)) + assert_(np.all(mA <= mB)) + assert_(np.all(mA <= mA)) + assert_(not np.any(mA < mA)) + + assert_(not np.any(mB < mA)) + assert_(np.all(mB >= mA)) + assert_(np.all(mB >= mB)) + assert_(not np.any(mB > mB)) + + assert_(np.all(mA == mA)) + assert_(not np.any(mA == mB)) + assert_(np.all(mB != mA)) + + assert_(not np.all(abs(mA) > 0)) + assert_(np.all(abs(mB > 0))) + + def test_asmatrix(self): + A = np.arange(100).reshape(10, 10) + mA = asmatrix(A) + A[0, 0] = -10 + assert_(A[0, 0] == mA[0, 0]) + + def test_noaxis(self): + A = matrix([[1, 0], [0, 1]]) + assert_(A.sum() == matrix(2)) + assert_(A.mean() == matrix(0.5)) + + def test_repr(self): + A = matrix([[1, 0], [0, 1]]) + assert_(repr(A) == "matrix([[1, 0],\n [0, 1]])") + + def test_make_bool_matrix_from_str(self): + A = matrix('True; True; False') + B = matrix([[True], [True], [False]]) + assert_array_equal(A, B) + +class TestCasting: + def test_basic(self): + A = np.arange(100).reshape(10, 10) + mA = matrix(A) + + mB = mA.copy() + O = np.ones((10, 10), np.float64) * 0.1 + mB = mB + O + assert_(mB.dtype.type == np.float64) + assert_(np.all(mA != mB)) + assert_(np.all(mB == mA + 0.1)) + + mC = mA.copy() + O = np.ones((10, 10), np.complex128) + mC = mC * O + assert_(mC.dtype.type == np.complex128) + assert_(np.all(mA != mB)) + + +class TestAlgebra: + def test_basic(self): + import numpy.linalg as linalg + + A = np.array([[1., 2.], [3., 4.]]) + mA = matrix(A) + + B = np.identity(2) + for i in range(6): + assert_(np.allclose((mA ** i).A, B)) + B = np.dot(B, A) + + Ainv = linalg.inv(A) + B = np.identity(2) + for i in range(6): + assert_(np.allclose((mA ** -i).A, B)) + B = np.dot(B, Ainv) + + assert_(np.allclose((mA * mA).A, np.dot(A, A))) + assert_(np.allclose((mA + mA).A, (A + A))) + assert_(np.allclose((3 * mA).A, (3 * A))) + + mA2 = matrix(A) + mA2 *= 3 + assert_(np.allclose(mA2.A, 3 * A)) + + def test_pow(self): + """Test raising a matrix to an integer power works as expected.""" + m = matrix("1. 2.; 3. 
4.") + m2 = m.copy() + m2 **= 2 + mi = m.copy() + mi **= -1 + m4 = m2.copy() + m4 **= 2 + assert_array_almost_equal(m2, m**2) + assert_array_almost_equal(m4, np.dot(m2, m2)) + assert_array_almost_equal(np.dot(mi, m), np.eye(2)) + + def test_scalar_type_pow(self): + m = matrix([[1, 2], [3, 4]]) + for scalar_t in [np.int8, np.uint8]: + two = scalar_t(2) + assert_array_almost_equal(m ** 2, m ** two) + + def test_notimplemented(self): + '''Check that 'not implemented' operations produce a failure.''' + A = matrix([[1., 2.], + [3., 4.]]) + + # __rpow__ + with assert_raises(TypeError): + 1.0**A + + # __mul__ with something not a list, ndarray, tuple, or scalar + with assert_raises(TypeError): + A * object() + + +class TestMatrixReturn: + def test_instance_methods(self): + a = matrix([1.0], dtype='f8') + methodargs = { + 'astype': ('intc',), + 'clip': (0.0, 1.0), + 'compress': ([1],), + 'repeat': (1,), + 'reshape': (1,), + 'swapaxes': (0, 0), + 'dot': np.array([1.0]), + } + excluded_methods = [ + 'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield', + 'getA', 'getA1', 'item', 'nonzero', 'put', 'putmask', 'resize', + 'searchsorted', 'setflags', 'setfield', 'sort', + 'partition', 'argpartition', 'newbyteorder', 'to_device', + 'take', 'tofile', 'tolist', 'tobytes', 'all', 'any', + 'sum', 'argmax', 'argmin', 'min', 'max', 'mean', 'var', 'ptp', + 'prod', 'std', 'ctypes', 'itemset', 'bitwise_count', + ] + for attrib in dir(a): + if attrib.startswith('_') or attrib in excluded_methods: + continue + f = getattr(a, attrib) + if isinstance(f, collections.abc.Callable): + # reset contents of a + a.astype('f8') + a.fill(1.0) + args = methodargs.get(attrib, ()) + b = f(*args) + assert_(type(b) is matrix, f"{attrib}") + assert_(type(a.real) is matrix) + assert_(type(a.imag) is matrix) + c, d = matrix([0.0]).nonzero() + assert_(type(c) is np.ndarray) + assert_(type(d) is np.ndarray) + + +class TestIndexing: + def test_basic(self): + x = asmatrix(np.zeros((3, 2), float)) + y = np.zeros((3, 1), float) + y[:, 0] = [0.8, 0.2, 0.3] + x[:, 1] = y > 0.5 + assert_equal(x, [[0, 1], [0, 0], [0, 0]]) + + +class TestNewScalarIndexing: + a = matrix([[1, 2], [3, 4]]) + + def test_dimesions(self): + a = self.a + x = a[0] + assert_equal(x.ndim, 2) + + def test_array_from_matrix_list(self): + a = self.a + x = np.array([a, a]) + assert_equal(x.shape, [2, 2, 2]) + + def test_array_to_list(self): + a = self.a + assert_equal(a.tolist(), [[1, 2], [3, 4]]) + + def test_fancy_indexing(self): + a = self.a + x = a[1, [0, 1, 0]] + assert_(isinstance(x, matrix)) + assert_equal(x, matrix([[3, 4, 3]])) + x = a[[1, 0]] + assert_(isinstance(x, matrix)) + assert_equal(x, matrix([[3, 4], [1, 2]])) + x = a[[[1], [0]], [[1, 0], [0, 1]]] + assert_(isinstance(x, matrix)) + assert_equal(x, matrix([[4, 3], [1, 2]])) + + def test_matrix_element(self): + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(x[0][0], matrix([[1, 2, 3]])) + assert_equal(x[0][0].shape, (1, 3)) + assert_equal(x[0].shape, (1, 3)) + assert_equal(x[:, 0].shape, (2, 1)) + + x = matrix(0) + assert_equal(x[0, 0], 0) + assert_equal(x[0], 0) + assert_equal(x[:, 0].shape, x.shape) + + def test_scalar_indexing(self): + x = asmatrix(np.zeros((3, 2), float)) + assert_equal(x[0, 0], x[0][0]) + + def test_row_column_indexing(self): + x = asmatrix(np.eye(2)) + assert_array_equal(x[0, :], [[1, 0]]) + assert_array_equal(x[1, :], [[0, 1]]) + assert_array_equal(x[:, 0], [[1], [0]]) + assert_array_equal(x[:, 1], [[0], [1]]) + + def test_boolean_indexing(self): + A = np.arange(6) + A.shape = 
(3, 2) + x = asmatrix(A) + assert_array_equal(x[:, np.array([True, False])], x[:, 0]) + assert_array_equal(x[np.array([True, False, False]), :], x[0, :]) + + def test_list_indexing(self): + A = np.arange(6) + A.shape = (3, 2) + x = asmatrix(A) + assert_array_equal(x[:, [1, 0]], x[:, ::-1]) + assert_array_equal(x[[2, 1, 0], :], x[::-1, :]) + + +class TestPower: + def test_returntype(self): + a = np.array([[0, 1], [0, 0]]) + assert_(type(matrix_power(a, 2)) is np.ndarray) + a = asmatrix(a) + assert_(type(matrix_power(a, 2)) is matrix) + + def test_list(self): + assert_array_equal(matrix_power([[0, 1], [0, 0]], 2), [[0, 0], [0, 0]]) + + +class TestShape: + + a = np.array([[1], [2]]) + m = matrix([[1], [2]]) + + def test_shape(self): + assert_equal(self.a.shape, (2, 1)) + assert_equal(self.m.shape, (2, 1)) + + def test_numpy_ravel(self): + assert_equal(np.ravel(self.a).shape, (2,)) + assert_equal(np.ravel(self.m).shape, (2,)) + + def test_member_ravel(self): + assert_equal(self.a.ravel().shape, (2,)) + assert_equal(self.m.ravel().shape, (1, 2)) + + def test_member_flatten(self): + assert_equal(self.a.flatten().shape, (2,)) + assert_equal(self.m.flatten().shape, (1, 2)) + + def test_numpy_ravel_order(self): + x = np.array([[1, 2, 3], [4, 5, 6]]) + assert_equal(np.ravel(x), [1, 2, 3, 4, 5, 6]) + assert_equal(np.ravel(x, order='F'), [1, 4, 2, 5, 3, 6]) + assert_equal(np.ravel(x.T), [1, 4, 2, 5, 3, 6]) + assert_equal(np.ravel(x.T, order='A'), [1, 2, 3, 4, 5, 6]) + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(np.ravel(x), [1, 2, 3, 4, 5, 6]) + assert_equal(np.ravel(x, order='F'), [1, 4, 2, 5, 3, 6]) + assert_equal(np.ravel(x.T), [1, 4, 2, 5, 3, 6]) + assert_equal(np.ravel(x.T, order='A'), [1, 2, 3, 4, 5, 6]) + + def test_matrix_ravel_order(self): + x = matrix([[1, 2, 3], [4, 5, 6]]) + assert_equal(x.ravel(), [[1, 2, 3, 4, 5, 6]]) + assert_equal(x.ravel(order='F'), [[1, 4, 2, 5, 3, 6]]) + assert_equal(x.T.ravel(), [[1, 4, 2, 5, 3, 6]]) + assert_equal(x.T.ravel(order='A'), [[1, 2, 3, 4, 5, 6]]) + + def test_array_memory_sharing(self): + assert_(np.may_share_memory(self.a, self.a.ravel())) + assert_(not np.may_share_memory(self.a, self.a.flatten())) + + def test_matrix_memory_sharing(self): + assert_(np.may_share_memory(self.m, self.m.ravel())) + assert_(not np.may_share_memory(self.m, self.m.flatten())) + + def test_expand_dims_matrix(self): + # matrices are always 2d - so expand_dims only makes sense when the + # type is changed away from matrix. + a = np.arange(10).reshape((2, 5)).view(np.matrix) + expanded = np.expand_dims(a, axis=1) + assert_equal(expanded.ndim, 3) + assert_(not isinstance(expanded, np.matrix)) diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_interaction.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..87d133a2c58616f98a6f0eedc95b8ac925019a49 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_interaction.py @@ -0,0 +1,360 @@ +"""Tests of interaction of matrix with other parts of numpy. + +Note that tests with MaskedArray and linalg are done in separate files. +""" +import textwrap +import warnings + +import pytest + +import numpy as np +from numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises, + assert_raises_regex, +) + + +def test_fancy_indexing(): + # The matrix class messes with the shape. 
While this is always + # weird (getitem is not used, it does not have setitem nor knows + # about fancy indexing), this tests gh-3110 + # 2018-04-29: moved here from core.tests.test_index. + m = np.matrix([[1, 2], [3, 4]]) + + assert_(isinstance(m[[0, 1, 0], :], np.matrix)) + + # gh-3110. Note the transpose currently because matrices do *not* + # support dimension fixing for fancy indexing correctly. + x = np.asmatrix(np.arange(50).reshape(5, 10)) + assert_equal(x[:2, np.array(-1)], x[:2, -1].T) + + +def test_polynomial_mapdomain(): + # test that polynomial preserved matrix subtype. + # 2018-04-29: moved here from polynomial.tests.polyutils. + dom1 = [0, 4] + dom2 = [1, 3] + x = np.matrix([dom1, dom1]) + res = np.polynomial.polyutils.mapdomain(x, dom1, dom2) + assert_(isinstance(res, np.matrix)) + + +def test_sort_matrix_none(): + # 2018-04-29: moved here from core.tests.test_multiarray + a = np.matrix([[2, 1, 0]]) + actual = np.sort(a, axis=None) + expected = np.matrix([[0, 1, 2]]) + assert_equal(actual, expected) + assert_(type(expected) is np.matrix) + + +def test_partition_matrix_none(): + # gh-4301 + # 2018-04-29: moved here from core.tests.test_multiarray + a = np.matrix([[2, 1, 0]]) + actual = np.partition(a, 1, axis=None) + expected = np.matrix([[0, 1, 2]]) + assert_equal(actual, expected) + assert_(type(expected) is np.matrix) + + +def test_dot_scalar_and_matrix_of_objects(): + # Ticket #2469 + # 2018-04-29: moved here from core.tests.test_multiarray + arr = np.matrix([1, 2], dtype=object) + desired = np.matrix([[3, 6]], dtype=object) + assert_equal(np.dot(arr, 3), desired) + assert_equal(np.dot(3, arr), desired) + + +def test_inner_scalar_and_matrix(): + # 2018-04-29: moved here from core.tests.test_multiarray + for dt in np.typecodes['AllInteger'] + np.typecodes['AllFloat'] + '?': + sca = np.array(3, dtype=dt)[()] + arr = np.matrix([[1, 2], [3, 4]], dtype=dt) + desired = np.matrix([[3, 6], [9, 12]], dtype=dt) + assert_equal(np.inner(arr, sca), desired) + assert_equal(np.inner(sca, arr), desired) + + +def test_inner_scalar_and_matrix_of_objects(): + # Ticket #4482 + # 2018-04-29: moved here from core.tests.test_multiarray + arr = np.matrix([1, 2], dtype=object) + desired = np.matrix([[3, 6]], dtype=object) + assert_equal(np.inner(arr, 3), desired) + assert_equal(np.inner(3, arr), desired) + + +def test_iter_allocate_output_subtype(): + # Make sure that the subtype with priority wins + # 2018-04-29: moved here from core.tests.test_nditer, given the + # matrix specific shape test. 
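+ # (editorial note) the tie-break is __array_priority__: np.matrix sets it + # to 10.0 while a plain ndarray reports 0.0, so nditer allocates the + # output as the higher-priority subtype; a minimal check of that rule: + assert_(np.matrix([[1]]).__array_priority__ > np.zeros(1).__array_priority__) 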
+ + # matrix vs ndarray + a = np.matrix([[1, 2], [3, 4]]) + b = np.arange(4).reshape(2, 2).T + i = np.nditer([a, b, None], [], + [['readonly'], ['readonly'], ['writeonly', 'allocate']]) + assert_(type(i.operands[2]) is np.matrix) + assert_(type(i.operands[2]) is not np.ndarray) + assert_equal(i.operands[2].shape, (2, 2)) + + # matrix always wants things to be 2D + b = np.arange(4).reshape(1, 2, 2) + assert_raises(RuntimeError, np.nditer, [a, b, None], [], + [['readonly'], ['readonly'], ['writeonly', 'allocate']]) + # but if subtypes are disabled, the result can still work + i = np.nditer([a, b, None], [], + [['readonly'], ['readonly'], + ['writeonly', 'allocate', 'no_subtype']]) + assert_(type(i.operands[2]) is np.ndarray) + assert_(type(i.operands[2]) is not np.matrix) + assert_equal(i.operands[2].shape, (1, 2, 2)) + + +def test_like_function(): + # 2018-04-29: moved here from core.tests.test_numeric + a = np.matrix([[1, 2], [3, 4]]) + for like_function in np.zeros_like, np.ones_like, np.empty_like: + b = like_function(a) + assert_(type(b) is np.matrix) + + c = like_function(a, subok=False) + assert_(type(c) is not np.matrix) + + +def test_array_astype(): + # 2018-04-29: copied here from core.tests.test_api + # subok=True passes through a matrix + a = np.matrix([[0, 1, 2], [3, 4, 5]], dtype='f4') + b = a.astype('f4', subok=True, copy=False) + assert_(a is b) + + # subok=True is default, and creates a subtype on a cast + b = a.astype('i4', copy=False) + assert_equal(a, b) + assert_equal(type(b), np.matrix) + + # subok=False never returns a matrix + b = a.astype('f4', subok=False, copy=False) + assert_equal(a, b) + assert_(not (a is b)) + assert_(type(b) is not np.matrix) + + +def test_stack(): + # 2018-04-29: copied here from core.tests.test_shape_base + # check np.matrix cannot be stacked + m = np.matrix([[1, 2], [3, 4]]) + assert_raises_regex(ValueError, 'shape too large to be a matrix', + np.stack, [m, m]) + + +def test_object_scalar_multiply(): + # Tickets #2469 and #4482 + # 2018-04-29: moved here from core.tests.test_ufunc + arr = np.matrix([1, 2], dtype=object) + desired = np.matrix([[3, 6]], dtype=object) + assert_equal(np.multiply(arr, 3), desired) + assert_equal(np.multiply(3, arr), desired) + + +def test_nanfunctions_matrices(): + # Check that it works and that type and + # shape are preserved + # 2018-04-29: moved here from core.tests.test_nanfunctions + mat = np.matrix(np.eye(3)) + for f in [np.nanmin, np.nanmax]: + res = f(mat, axis=0) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (1, 3)) + res = f(mat, axis=1) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (3, 1)) + res = f(mat) + assert_(np.isscalar(res)) + # check that rows of nan are dealt with for subclasses (#4628) + mat[1] = np.nan + for f in [np.nanmin, np.nanmax]: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + res = f(mat, axis=0) + assert_(isinstance(res, np.matrix)) + assert_(not np.any(np.isnan(res))) + assert_(len(w) == 0) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + res = f(mat, axis=1) + assert_(isinstance(res, np.matrix)) + assert_(np.isnan(res[1, 0]) and not np.isnan(res[0, 0]) + and not np.isnan(res[2, 0])) + assert_(len(w) == 1, 'no warning raised') + assert_(issubclass(w[0].category, RuntimeWarning)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + res = f(mat) + assert_(np.isscalar(res)) + assert_(not np.isnan(res)) + assert_(len(w) == 0) + + +def 
test_nanfunctions_matrices_general(): + # Check that it works and that type and + # shape are preserved + # 2018-04-29: moved here from core.tests.test_nanfunctions + mat = np.matrix(np.eye(3)) + for f in (np.nanargmin, np.nanargmax, np.nansum, np.nanprod, + np.nanmean, np.nanvar, np.nanstd): + res = f(mat, axis=0) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (1, 3)) + res = f(mat, axis=1) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (3, 1)) + res = f(mat) + assert_(np.isscalar(res)) + + for f in np.nancumsum, np.nancumprod: + res = f(mat, axis=0) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (3, 3)) + res = f(mat, axis=1) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (3, 3)) + res = f(mat) + assert_(isinstance(res, np.matrix)) + assert_(res.shape == (1, 3 * 3)) + + +def test_average_matrix(): + # 2018-04-29: moved here from core.tests.test_function_base. + y = np.matrix(np.random.rand(5, 5)) + assert_array_equal(y.mean(0), np.average(y, 0)) + + a = np.matrix([[1, 2], [3, 4]]) + w = np.matrix([[1, 2], [3, 4]]) + + r = np.average(a, axis=0, weights=w) + assert_equal(type(r), np.matrix) + assert_equal(r, [[2.5, 10.0 / 3]]) + + +def test_dot_matrix(): + # Test to make sure matrices give the same answer as ndarrays + # 2018-04-29: moved here from core.tests.test_function_base. + x = np.linspace(0, 5) + y = np.linspace(-5, 0) + mx = np.matrix(x) + my = np.matrix(y) + r = np.dot(x, y) + mr = np.dot(mx, my.T) + assert_almost_equal(mr, r) + + +def test_ediff1d_matrix(): + # 2018-04-29: moved here from core.tests.test_arraysetops. + assert isinstance(np.ediff1d(np.matrix(1)), np.matrix) + assert isinstance(np.ediff1d(np.matrix(1), to_begin=1), np.matrix) + + +def test_apply_along_axis_matrix(): + # this test is particularly malicious because matrix + # refuses to become 1d + # 2018-04-29: moved here from core.tests.test_shape_base. + def double(row): + return row * 2 + + m = np.matrix([[0, 1], [2, 3]]) + expected = np.matrix([[0, 2], [4, 6]]) + + result = np.apply_along_axis(double, 0, m) + assert_(isinstance(result, np.matrix)) + assert_array_equal(result, expected) + + result = np.apply_along_axis(double, 1, m) + assert_(isinstance(result, np.matrix)) + assert_array_equal(result, expected) + + +def test_kron_matrix(): + # 2018-04-29: moved here from core.tests.test_shape_base. + a = np.ones([2, 2]) + m = np.asmatrix(a) + assert_equal(type(np.kron(a, a)), np.ndarray) + assert_equal(type(np.kron(m, m)), np.matrix) + assert_equal(type(np.kron(a, m)), np.matrix) + assert_equal(type(np.kron(m, a)), np.matrix) + + +class TestConcatenatorMatrix: + # 2018-04-29: moved here from core.tests.test_index_tricks. + def test_matrix(self): + a = [1, 2] + b = [3, 4] + + ab_r = np.r_['r', a, b] + ab_c = np.r_['c', a, b] + + assert_equal(type(ab_r), np.matrix) + assert_equal(type(ab_c), np.matrix) + + assert_equal(np.array(ab_r), [[1, 2, 3, 4]]) + assert_equal(np.array(ab_c), [[1], [2], [3], [4]]) + + assert_raises(ValueError, lambda: np.r_['rc', a, b]) + + def test_matrix_scalar(self): + r = np.r_['r', [1, 2], 3] + assert_equal(type(r), np.matrix) + assert_equal(np.array(r), [[1, 2, 3]]) + + def test_matrix_builder(self): + a = np.array([1]) + b = np.array([2]) + c = np.array([3]) + d = np.array([4]) + actual = np.r_['a, b; c, d'] + expected = np.bmat([[a, b], [c, d]]) + + assert_equal(actual, expected) + assert_equal(type(actual), type(expected)) + + +def test_array_equal_error_message_matrix(): + # 2018-04-29: moved here from testing.tests.test_utils. 
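+ # (editorial sketch) np.matrix promotes 1-D input to 2-D, so the shapes + # (2,) and (1, 2) below can never match: + assert_equal(np.matrix([1, 2]).shape, (1, 2)) 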
+ with pytest.raises(AssertionError) as exc_info: + assert_equal(np.array([1, 2]), np.matrix([1, 2])) + msg = str(exc_info.value) + msg_reference = textwrap.dedent("""\ + + Arrays are not equal + + (shapes (2,), (1, 2) mismatch) + ACTUAL: array([1, 2]) + DESIRED: matrix([[1, 2]])""") + assert_equal(msg, msg_reference) + + +def test_array_almost_equal_matrix(): + # Matrix slicing keeps things 2-D, while array does not necessarily. + # See gh-8452. + # 2018-04-29: moved here from testing.tests.test_utils. + m1 = np.matrix([[1., 2.]]) + m2 = np.matrix([[1., np.nan]]) + m3 = np.matrix([[1., -np.inf]]) + m4 = np.matrix([[np.nan, np.inf]]) + m5 = np.matrix([[1., 2.], [np.nan, np.inf]]) + for assert_func in assert_array_almost_equal, assert_almost_equal: + for m in m1, m2, m3, m4, m5: + assert_func(m, m) + a = np.array(m) + assert_func(a, m) + assert_func(m, a) diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_masked_matrix.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_masked_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..e6df047ee6ca23f7e92fbba7b17d4cbcba70ed6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_masked_matrix.py @@ -0,0 +1,240 @@ +import pickle + +import numpy as np +from numpy.ma.core import ( + MaskedArray, + MaskType, + add, + allequal, + divide, + getmask, + hypot, + log, + masked, + masked_array, + masked_values, + nomask, +) +from numpy.ma.extras import mr_ +from numpy.ma.testutils import assert_, assert_array_equal, assert_equal, assert_raises + + +class MMatrix(MaskedArray, np.matrix,): + + def __new__(cls, data, mask=nomask): + mat = np.matrix(data) + _data = MaskedArray.__new__(cls, data=mat, mask=mask) + return _data + + def __array_finalize__(self, obj): + np.matrix.__array_finalize__(self, obj) + MaskedArray.__array_finalize__(self, obj) + + @property + def _series(self): + _view = self.view(MaskedArray) + _view._sharedmask = False + return _view + + +class TestMaskedMatrix: + def test_matrix_indexing(self): + # Tests conversions and indexing + x1 = np.matrix([[1, 2, 3], [4, 3, 2]]) + x2 = masked_array(x1, mask=[[1, 0, 0], [0, 1, 0]]) + x3 = masked_array(x1, mask=[[0, 1, 0], [1, 0, 0]]) + x4 = masked_array(x1) + # test conversion to strings + str(x2) # raises? + repr(x2) # raises? 
+ # tests of indexing + assert_(type(x2[1, 0]) is type(x1[1, 0])) + assert_(x1[1, 0] == x2[1, 0]) + assert_(x2[1, 1] is masked) + assert_equal(x1[0, 2], x2[0, 2]) + assert_equal(x1[0, 1:], x2[0, 1:]) + assert_equal(x1[:, 2], x2[:, 2]) + assert_equal(x1[:], x2[:]) + assert_equal(x1[1:], x3[1:]) + x1[0, 2] = 9 + x2[0, 2] = 9 + assert_equal(x1, x2) + x1[0, 1:] = 99 + x2[0, 1:] = 99 + assert_equal(x1, x2) + x2[0, 1] = masked + assert_equal(x1, x2) + x2[0, 1:] = masked + assert_equal(x1, x2) + x2[0, :] = x1[0, :] + x2[0, 1] = masked + assert_(allequal(getmask(x2), np.array([[0, 1, 0], [0, 1, 0]]))) + x3[1, :] = masked_array([1, 2, 3], [1, 1, 0]) + assert_(allequal(getmask(x3)[1], masked_array([1, 1, 0]))) + assert_(allequal(getmask(x3[1]), masked_array([1, 1, 0]))) + x4[1, :] = masked_array([1, 2, 3], [1, 1, 0]) + assert_(allequal(getmask(x4[1]), masked_array([1, 1, 0]))) + assert_(allequal(x4[1], masked_array([1, 2, 3]))) + x1 = np.matrix(np.arange(5) * 1.0) + x2 = masked_values(x1, 3.0) + assert_equal(x1, x2) + assert_(allequal(masked_array([0, 0, 0, 1, 0], dtype=MaskType), + x2.mask)) + assert_equal(3.0, x2.fill_value) + + def test_pickling_subbaseclass(self): + # Test pickling w/ a subclass of ndarray + a = masked_array(np.matrix(list(range(10))), mask=[1, 0, 1, 0, 0] * 2) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + a_pickled = pickle.loads(pickle.dumps(a, protocol=proto)) + assert_equal(a_pickled._mask, a._mask) + assert_equal(a_pickled, a) + assert_(isinstance(a_pickled._data, np.matrix)) + + def test_count_mean_with_matrix(self): + m = masked_array(np.matrix([[1, 2], [3, 4]]), mask=np.zeros((2, 2))) + + assert_equal(m.count(axis=0).shape, (1, 2)) + assert_equal(m.count(axis=1).shape, (2, 1)) + + # Make sure broadcasting inside mean and var work + assert_equal(m.mean(axis=0), [[2., 3.]]) + assert_equal(m.mean(axis=1), [[1.5], [3.5]]) + + def test_flat(self): + # Test that flat can return items even for matrices [#4585, #4615] + # test simple access + test = masked_array(np.matrix([[1, 2, 3]]), mask=[0, 0, 1]) + assert_equal(test.flat[1], 2) + assert_equal(test.flat[2], masked) + assert_(np.all(test.flat[0:2] == test[0, 0:2])) + # Test flat on masked_matrices + test = masked_array(np.matrix([[1, 2, 3]]), mask=[0, 0, 1]) + test.flat = masked_array([3, 2, 1], mask=[1, 0, 0]) + control = masked_array(np.matrix([[3, 2, 1]]), mask=[1, 0, 0]) + assert_equal(test, control) + # Test setting + test = masked_array(np.matrix([[1, 2, 3]]), mask=[0, 0, 1]) + testflat = test.flat + testflat[:] = testflat[[2, 1, 0]] + assert_equal(test, control) + testflat[0] = 9 + # test that matrices keep the correct shape (#4615) + a = masked_array(np.matrix(np.eye(2)), mask=0) + b = a.flat + b01 = b[:2] + assert_equal(b01.data, np.array([[1., 0.]])) + assert_equal(b01.mask, np.array([[False, False]])) + + def test_allany_onmatrices(self): + x = np.array([[0.13, 0.26, 0.90], + [0.28, 0.33, 0.63], + [0.31, 0.87, 0.70]]) + X = np.matrix(x) + m = np.array([[True, False, False], + [False, False, False], + [True, True, False]], dtype=np.bool) + mX = masked_array(X, mask=m) + mXbig = (mX > 0.5) + mXsmall = (mX < 0.5) + + assert_(not mXbig.all()) + assert_(mXbig.any()) + assert_equal(mXbig.all(0), np.matrix([False, False, True])) + assert_equal(mXbig.all(1), np.matrix([False, False, True]).T) + assert_equal(mXbig.any(0), np.matrix([False, False, True])) + assert_equal(mXbig.any(1), np.matrix([True, True, True]).T) + + assert_(not mXsmall.all()) + assert_(mXsmall.any()) + assert_equal(mXsmall.all(0), 
np.matrix([True, True, False])) + assert_equal(mXsmall.all(1), np.matrix([False, False, False]).T) + assert_equal(mXsmall.any(0), np.matrix([True, True, False])) + assert_equal(mXsmall.any(1), np.matrix([True, True, False]).T) + + def test_compressed(self): + a = masked_array(np.matrix([1, 2, 3, 4]), mask=[0, 0, 0, 0]) + b = a.compressed() + assert_equal(b, a) + assert_(isinstance(b, np.matrix)) + a[0, 0] = masked + b = a.compressed() + assert_equal(b, [[2, 3, 4]]) + + def test_ravel(self): + a = masked_array(np.matrix([1, 2, 3, 4, 5]), mask=[[0, 1, 0, 0, 0]]) + aravel = a.ravel() + assert_equal(aravel.shape, (1, 5)) + assert_equal(aravel._mask.shape, a.shape) + + def test_view(self): + # Test view w/ flexible dtype + iterator = list(zip(np.arange(10), np.random.rand(10))) + data = np.array(iterator) + a = masked_array(iterator, dtype=[('a', float), ('b', float)]) + a.mask[0] = (1, 0) + test = a.view((float, 2), np.matrix) + assert_equal(test, data) + assert_(isinstance(test, np.matrix)) + assert_(not isinstance(test, MaskedArray)) + + +class TestSubclassing: + # Test suite for masked subclasses of ndarray. + + def setup_method(self): + x = np.arange(5, dtype='float') + mx = MMatrix(x, mask=[0, 1, 0, 0, 0]) + self.data = (x, mx) + + def test_maskedarray_subclassing(self): + # Tests subclassing MaskedArray + (x, mx) = self.data + assert_(isinstance(mx._data, np.matrix)) + + def test_masked_unary_operations(self): + # Tests masked_unary_operation + (x, mx) = self.data + with np.errstate(divide='ignore'): + assert_(isinstance(log(mx), MMatrix)) + assert_equal(log(x), np.log(x)) + + def test_masked_binary_operations(self): + # Tests masked_binary_operation + (x, mx) = self.data + # Result should be a MMatrix + assert_(isinstance(add(mx, mx), MMatrix)) + assert_(isinstance(add(mx, x), MMatrix)) + # Result should work + assert_equal(add(mx, x), mx + x) + assert_(isinstance(add(mx, mx)._data, np.matrix)) + with assert_raises(TypeError): + add.outer(mx, mx) + assert_(isinstance(hypot(mx, mx), MMatrix)) + assert_(isinstance(hypot(mx, x), MMatrix)) + + def test_masked_binary_operations2(self): + # Tests domained_masked_binary_operation + (x, mx) = self.data + xmx = masked_array(mx.data.__array__(), mask=mx.mask) + assert_(isinstance(divide(mx, mx), MMatrix)) + assert_(isinstance(divide(mx, x), MMatrix)) + assert_equal(divide(mx, mx), divide(xmx, xmx)) + + +class TestConcatenator: + # Tests for mr_, the equivalent of r_ for masked arrays. + + def test_matrix_builder(self): + assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4']) + + def test_matrix(self): + # Test consistency with unmasked version. If we ever deprecate + # matrix, this test should either still pass, or both actual and + # expected should fail to be built. 
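+ # (editorial note) np.r_['r', 1, 2, 3] yields the 1x3 np.matrix + # [[1, 2, 3]]; mr_ should mirror that, wrapping the same matrix in a + # MaskedArray, which the type assertions below pin down. 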
+ actual = mr_['r', 1, 2, 3] + expected = np.ma.array(np.r_['r', 1, 2, 3]) + assert_array_equal(actual, expected) + + # outer type is masked array, inner type is matrix + assert_equal(type(actual), type(expected)) + assert_equal(type(actual.data), type(expected.data)) diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_matrix_linalg.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_matrix_linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..4e639653bda446b7c208f4352ce6e2d9b1e2ee17 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_matrix_linalg.py @@ -0,0 +1,105 @@ +""" Test functions for linalg module using the matrix class.""" +import numpy as np +from numpy.linalg.tests.test_linalg import ( + CondCases, + DetCases, + EigCases, + EigvalsCases, + InvCases, + LinalgCase, + LinalgTestCase, + LstsqCases, + PinvCases, + SolveCases, + SVDCases, + _TestNorm2D, + _TestNormDoubleBase, + _TestNormInt64Base, + _TestNormSingleBase, + apply_tag, +) +from numpy.linalg.tests.test_linalg import TestQR as _TestQR + +CASES = [] + +# square test cases +CASES += apply_tag('square', [ + LinalgCase("0x0_matrix", + np.empty((0, 0), dtype=np.double).view(np.matrix), + np.empty((0, 1), dtype=np.double).view(np.matrix), + tags={'size-0'}), + LinalgCase("matrix_b_only", + np.array([[1., 2.], [3., 4.]]), + np.matrix([2., 1.]).T), + LinalgCase("matrix_a_and_b", + np.matrix([[1., 2.], [3., 4.]]), + np.matrix([2., 1.]).T), +]) + +# hermitian test-cases +CASES += apply_tag('hermitian', [ + LinalgCase("hmatrix_a_and_b", + np.matrix([[1., 2.], [2., 1.]]), + None), +]) +# No need to make generalized or strided cases for matrices. + + +class MatrixTestCase(LinalgTestCase): + TEST_CASES = CASES + + +class TestSolveMatrix(SolveCases, MatrixTestCase): + pass + + +class TestInvMatrix(InvCases, MatrixTestCase): + pass + + +class TestEigvalsMatrix(EigvalsCases, MatrixTestCase): + pass + + +class TestEigMatrix(EigCases, MatrixTestCase): + pass + + +class TestSVDMatrix(SVDCases, MatrixTestCase): + pass + + +class TestCondMatrix(CondCases, MatrixTestCase): + pass + + +class TestPinvMatrix(PinvCases, MatrixTestCase): + pass + + +class TestDetMatrix(DetCases, MatrixTestCase): + pass + + +class TestLstsqMatrix(LstsqCases, MatrixTestCase): + pass + + +class _TestNorm2DMatrix(_TestNorm2D): + array = np.matrix + + +class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase): + pass + + +class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase): + pass + + +class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base): + pass + + +class TestQRMatrix(_TestQR): + array = np.matrix diff --git a/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_multiarray.py b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_multiarray.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9d1f8efe4175e36120bb51a8106b23ca02154e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/matrixlib/tests/test_multiarray.py @@ -0,0 +1,17 @@ +import numpy as np +from numpy.testing import assert_, assert_array_equal, assert_equal + + +class TestView: + def test_type(self): + x = np.array([1, 2, 3]) + assert_(isinstance(x.view(np.matrix), np.matrix)) + + def test_keywords(self): + x = np.array([(1, 2)], dtype=[('a', np.int8), ('b', np.int8)]) + # We must be specific about the endianness here: + y = x.view(dtype='> f8 # type: ignore[operator] +i8 << f8 # type: ignore[operator] +i | f8 # type: ignore[operator] +i8 
^ f8 # type: ignore[operator] +u8 & f8 # type: ignore[operator] +~f8 # type: ignore[operator] +# TODO: Certain mixes like i4 << u8 go to float and thus should fail diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/chararray.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/chararray.pyi new file mode 100644 index 0000000000000000000000000000000000000000..fb52f7349dd1e0827324e93d3c05d4582b49c83e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/chararray.pyi @@ -0,0 +1,62 @@ +from typing import Any +import numpy as np + +AR_U: np.char.chararray[tuple[Any, ...], np.dtype[np.str_]] +AR_S: np.char.chararray[tuple[Any, ...], np.dtype[np.bytes_]] + +AR_S.encode() # type: ignore[misc] +AR_U.decode() # type: ignore[misc] + +AR_U.join(b"_") # type: ignore[arg-type] +AR_S.join("_") # type: ignore[arg-type] + +AR_U.ljust(5, fillchar=b"a") # type: ignore[arg-type] +AR_S.ljust(5, fillchar="a") # type: ignore[arg-type] +AR_U.rjust(5, fillchar=b"a") # type: ignore[arg-type] +AR_S.rjust(5, fillchar="a") # type: ignore[arg-type] + +AR_U.lstrip(chars=b"a") # type: ignore[arg-type] +AR_S.lstrip(chars="a") # type: ignore[arg-type] +AR_U.strip(chars=b"a") # type: ignore[arg-type] +AR_S.strip(chars="a") # type: ignore[arg-type] +AR_U.rstrip(chars=b"a") # type: ignore[arg-type] +AR_S.rstrip(chars="a") # type: ignore[arg-type] + +AR_U.partition(b"a") # type: ignore[arg-type] +AR_S.partition("a") # type: ignore[arg-type] +AR_U.rpartition(b"a") # type: ignore[arg-type] +AR_S.rpartition("a") # type: ignore[arg-type] + +AR_U.replace(b"_", b"-") # type: ignore[arg-type] +AR_S.replace("_", "-") # type: ignore[arg-type] + +AR_U.split(b"_") # type: ignore[arg-type] +AR_S.split("_") # type: ignore[arg-type] +AR_S.split(1) # type: ignore[arg-type] +AR_U.rsplit(b"_") # type: ignore[arg-type] +AR_S.rsplit("_") # type: ignore[arg-type] + +AR_U.count(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.count("a", end=9) # type: ignore[arg-type] + +AR_U.endswith(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.endswith("a", end=9) # type: ignore[arg-type] +AR_U.startswith(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.startswith("a", end=9) # type: ignore[arg-type] + +AR_U.find(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.find("a", end=9) # type: ignore[arg-type] +AR_U.rfind(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.rfind("a", end=9) # type: ignore[arg-type] + +AR_U.index(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.index("a", end=9) # type: ignore[arg-type] +AR_U.rindex(b"a", start=[1, 2, 3]) # type: ignore[arg-type] +AR_S.rindex("a", end=9) # type: ignore[arg-type] + +AR_U == AR_S # type: ignore[operator] +AR_U != AR_S # type: ignore[operator] +AR_U >= AR_S # type: ignore[operator] +AR_U <= AR_S # type: ignore[operator] +AR_U > AR_S # type: ignore[operator] +AR_U < AR_S # type: ignore[operator] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/comparisons.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/comparisons.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3c8a94bff240374c590d43b01031aff5337d6008 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/comparisons.pyi @@ -0,0 +1,27 @@ +import numpy as np +import numpy.typing as npt + +AR_i: npt.NDArray[np.int64] +AR_f: npt.NDArray[np.float64] +AR_c: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] + +AR_f > 
AR_m # type: ignore[operator] +AR_c > AR_m # type: ignore[operator] + +AR_m > AR_f # type: ignore[operator] +AR_m > AR_c # type: ignore[operator] + +AR_i > AR_M # type: ignore[operator] +AR_f > AR_M # type: ignore[operator] +AR_m > AR_M # type: ignore[operator] + +AR_M > AR_i # type: ignore[operator] +AR_M > AR_f # type: ignore[operator] +AR_M > AR_m # type: ignore[operator] + +AR_i > str() # type: ignore[operator] +AR_i > bytes() # type: ignore[operator] +str() > AR_M # type: ignore[operator] +bytes() > AR_M # type: ignore[operator] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/constants.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/constants.pyi new file mode 100644 index 0000000000000000000000000000000000000000..10717f664e0a679fdd30e232001948dac6da063d --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/constants.pyi @@ -0,0 +1,3 @@ +import numpy as np + +np.little_endian = np.little_endian # type: ignore[misc] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/dtype.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/dtype.pyi new file mode 100644 index 0000000000000000000000000000000000000000..64a7c3f775e1a14080aafda8ba3ff293232e9c47 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/dtype.pyi @@ -0,0 +1,17 @@ +import numpy as np + +class Test1: + not_dtype = np.dtype(float) + +class Test2: + dtype = float + +np.dtype(Test1()) # type: ignore[call-overload] +np.dtype(Test2()) # type: ignore[arg-type] + +np.dtype( # type: ignore[call-overload] + { + "field1": (float, 1), + "field2": (int, 3), + } +) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/index_tricks.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/index_tricks.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8b7b1ae2b5bfc0705bbbfbdd244c06b90d0fdd4d --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/index_tricks.pyi @@ -0,0 +1,14 @@ +import numpy as np + +AR_LIKE_i: list[int] +AR_LIKE_f: list[float] + +np.ndindex([1, 2, 3]) # type: ignore[call-overload] +np.unravel_index(AR_LIKE_f, (1, 2, 3)) # type: ignore[arg-type] +np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # type: ignore[call-overload] +np.mgrid[1] # type: ignore[index] +np.mgrid[...] # type: ignore[index] +np.ogrid[1] # type: ignore[index] +np.ogrid[...] # type: ignore[index] +np.fill_diagonal(AR_LIKE_f, 2) # type: ignore[arg-type] +np.diag_indices(1.0) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_function_base.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_function_base.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f0bf6347691d9e75a8901b4d12afe4d8baa54173 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_function_base.pyi @@ -0,0 +1,62 @@ +from typing import Any + +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] +AR_O: npt.NDArray[np.object_] +AR_b_list: list[npt.NDArray[np.bool]] + +def fn_none_i(a: None, /) -> npt.NDArray[Any]: ... +def fn_ar_i(a: npt.NDArray[np.float64], posarg: int, /) -> npt.NDArray[Any]: ... 
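+ +# (editorial note) every statement below is meant to be rejected by the type +# checker; the '# type: ignore[code]' suffix names the expected error code, and +# the typing test suite (roughly speaking) flags any ignore that goes unused. 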
+ +np.average(AR_m) # type: ignore[arg-type] +np.select(1, [AR_f8]) # type: ignore[arg-type] +np.angle(AR_m) # type: ignore[arg-type] +np.unwrap(AR_m) # type: ignore[arg-type] +np.unwrap(AR_c16) # type: ignore[arg-type] +np.trim_zeros(1) # type: ignore[arg-type] +np.place(1, [True], 1.5) # type: ignore[arg-type] +np.vectorize(1) # type: ignore[arg-type] +np.place(AR_f8, slice(None), 5) # type: ignore[arg-type] + +np.piecewise(AR_f8, True, [fn_ar_i], 42) # type: ignore[call-overload] +# TODO: enable these once mypy actually supports ParamSpec (released in 2021) +# NOTE: pyright correctly reports errors for these (`reportCallIssue`) +# np.piecewise(AR_f8, AR_b_list, [fn_none_i]) # type: ignore[call-overload]s +# np.piecewise(AR_f8, AR_b_list, [fn_ar_i]) # type: ignore[call-overload] +# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 3.14) # type: ignore[call-overload] +# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 42, None) # type: ignore[call-overload] +# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 42, _=None) # type: ignore[call-overload] + +np.interp(AR_f8, AR_c16, AR_f8) # type: ignore[arg-type] +np.interp(AR_c16, AR_f8, AR_f8) # type: ignore[arg-type] +np.interp(AR_f8, AR_f8, AR_f8, period=AR_c16) # type: ignore[call-overload] +np.interp(AR_f8, AR_f8, AR_O) # type: ignore[arg-type] + +np.cov(AR_m) # type: ignore[arg-type] +np.cov(AR_O) # type: ignore[arg-type] +np.corrcoef(AR_m) # type: ignore[arg-type] +np.corrcoef(AR_O) # type: ignore[arg-type] +np.corrcoef(AR_f8, bias=True) # type: ignore[call-overload] +np.corrcoef(AR_f8, ddof=2) # type: ignore[call-overload] +np.blackman(1j) # type: ignore[arg-type] +np.bartlett(1j) # type: ignore[arg-type] +np.hanning(1j) # type: ignore[arg-type] +np.hamming(1j) # type: ignore[arg-type] +np.hamming(AR_c16) # type: ignore[arg-type] +np.kaiser(1j, 1) # type: ignore[arg-type] +np.sinc(AR_O) # type: ignore[arg-type] +np.median(AR_M) # type: ignore[arg-type] + +np.percentile(AR_f8, 50j) # type: ignore[call-overload] +np.percentile(AR_f8, 50, interpolation="bob") # type: ignore[call-overload] +np.quantile(AR_f8, 0.5j) # type: ignore[call-overload] +np.quantile(AR_f8, 0.5, interpolation="bob") # type: ignore[call-overload] +np.meshgrid(AR_f8, AR_f8, indexing="bob") # type: ignore[call-overload] +np.delete(AR_f8, AR_f8) # type: ignore[arg-type] +np.insert(AR_f8, AR_f8, 1.5) # type: ignore[arg-type] +np.digitize(AR_f8, 1j) # type: ignore[call-overload] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_utils.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_utils.pyi new file mode 100644 index 0000000000000000000000000000000000000000..25af32b432977e311d8612c74725340d926cf590 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/lib_utils.pyi @@ -0,0 +1,3 @@ +import numpy.lib.array_utils as array_utils + +array_utils.byte_bounds(1) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/linalg.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/linalg.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c4695ee671cd40ce59f2536e75b20df0d7222473 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/linalg.pyi @@ -0,0 +1,48 @@ +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_O: npt.NDArray[np.object_] +AR_M: npt.NDArray[np.datetime64] + +np.linalg.tensorsolve(AR_O, AR_O) # type: ignore[arg-type] + +np.linalg.solve(AR_O, AR_O) # type: 
ignore[arg-type] + +np.linalg.tensorinv(AR_O) # type: ignore[arg-type] + +np.linalg.inv(AR_O) # type: ignore[arg-type] + +np.linalg.matrix_power(AR_M, 5) # type: ignore[arg-type] + +np.linalg.cholesky(AR_O) # type: ignore[arg-type] + +np.linalg.qr(AR_O) # type: ignore[arg-type] +np.linalg.qr(AR_f8, mode="bob") # type: ignore[call-overload] + +np.linalg.eigvals(AR_O) # type: ignore[arg-type] + +np.linalg.eigvalsh(AR_O) # type: ignore[arg-type] +np.linalg.eigvalsh(AR_O, UPLO="bob") # type: ignore[call-overload] + +np.linalg.eig(AR_O) # type: ignore[arg-type] + +np.linalg.eigh(AR_O) # type: ignore[arg-type] +np.linalg.eigh(AR_O, UPLO="bob") # type: ignore[call-overload] + +np.linalg.svd(AR_O) # type: ignore[arg-type] + +np.linalg.cond(AR_O) # type: ignore[arg-type] +np.linalg.cond(AR_f8, p="bob") # type: ignore[arg-type] + +np.linalg.matrix_rank(AR_O) # type: ignore[arg-type] + +np.linalg.pinv(AR_O) # type: ignore[arg-type] + +np.linalg.slogdet(AR_O) # type: ignore[arg-type] + +np.linalg.det(AR_O) # type: ignore[arg-type] + +np.linalg.norm(AR_f8, ord="bob") # type: ignore[call-overload] + +np.linalg.multi_dot([AR_M]) # type: ignore[list-item] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ma.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ma.pyi new file mode 100644 index 0000000000000000000000000000000000000000..41306b23fe781286f6094d8279a114b1bef1baa7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ma.pyi @@ -0,0 +1,143 @@ +from typing import TypeAlias, TypeVar + +import numpy as np +import numpy.typing as npt +from numpy._typing import _Shape + +_ScalarT = TypeVar("_ScalarT", bound=np.generic) +MaskedArray: TypeAlias = np.ma.MaskedArray[_Shape, np.dtype[_ScalarT]] + +MAR_1d_f8: np.ma.MaskedArray[tuple[int], np.dtype[np.float64]] +MAR_b: MaskedArray[np.bool] +MAR_c: MaskedArray[np.complex128] +MAR_td64: MaskedArray[np.timedelta64] + +AR_b: npt.NDArray[np.bool] + +MAR_1d_f8.shape = (3, 1) # type: ignore[assignment] +MAR_1d_f8.dtype = np.bool # type: ignore[assignment] + +np.ma.min(MAR_1d_f8, axis=1.0) # type: ignore[call-overload] +np.ma.min(MAR_1d_f8, keepdims=1.0) # type: ignore[call-overload] +np.ma.min(MAR_1d_f8, out=1.0) # type: ignore[call-overload] +np.ma.min(MAR_1d_f8, fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.min(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.min(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.min(out=1.0) # type: ignore[call-overload] +MAR_1d_f8.min(fill_value=lambda x: 27) # type: ignore[call-overload] + +np.ma.max(MAR_1d_f8, axis=1.0) # type: ignore[call-overload] +np.ma.max(MAR_1d_f8, keepdims=1.0) # type: ignore[call-overload] +np.ma.max(MAR_1d_f8, out=1.0) # type: ignore[call-overload] +np.ma.max(MAR_1d_f8, fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.max(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.max(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.max(out=1.0) # type: ignore[call-overload] +MAR_1d_f8.max(fill_value=lambda x: 27) # type: ignore[call-overload] + +np.ma.ptp(MAR_1d_f8, axis=1.0) # type: ignore[call-overload] +np.ma.ptp(MAR_1d_f8, keepdims=1.0) # type: ignore[call-overload] +np.ma.ptp(MAR_1d_f8, out=1.0) # type: ignore[call-overload] +np.ma.ptp(MAR_1d_f8, fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.ptp(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.ptp(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.ptp(out=1.0) # type: ignore[call-overload] 
+MAR_1d_f8.ptp(fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.argmin(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmin(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmin(out=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] + +np.ma.argmin(MAR_1d_f8, axis=1.0) # type: ignore[call-overload] +np.ma.argmin(MAR_1d_f8, axis=(1,)) # type: ignore[call-overload] +np.ma.argmin(MAR_1d_f8, keepdims=1.0) # type: ignore[call-overload] +np.ma.argmin(MAR_1d_f8, out=1.0) # type: ignore[call-overload] +np.ma.argmin(MAR_1d_f8, fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.argmax(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmax(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmax(out=1.0) # type: ignore[call-overload] +MAR_1d_f8.argmax(fill_value=lambda x: 27) # type: ignore[call-overload] + +np.ma.argmax(MAR_1d_f8, axis=1.0) # type: ignore[call-overload] +np.ma.argmax(MAR_1d_f8, axis=(0,)) # type: ignore[call-overload] +np.ma.argmax(MAR_1d_f8, keepdims=1.0) # type: ignore[call-overload] +np.ma.argmax(MAR_1d_f8, out=1.0) # type: ignore[call-overload] +np.ma.argmax(MAR_1d_f8, fill_value=lambda x: 27) # type: ignore[call-overload] + +MAR_1d_f8.all(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.all(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.all(out=1.0) # type: ignore[call-overload] + +MAR_1d_f8.any(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.any(keepdims=1.0) # type: ignore[call-overload] +MAR_1d_f8.any(out=1.0) # type: ignore[call-overload] + +MAR_1d_f8.sort(axis=(0,1)) # type: ignore[arg-type] +MAR_1d_f8.sort(axis=None) # type: ignore[arg-type] +MAR_1d_f8.sort(kind='cabbage') # type: ignore[arg-type] +MAR_1d_f8.sort(order=lambda: 'cabbage') # type: ignore[arg-type] +MAR_1d_f8.sort(endwith='cabbage') # type: ignore[arg-type] +MAR_1d_f8.sort(fill_value=lambda: 'cabbage') # type: ignore[arg-type] +MAR_1d_f8.sort(stable='cabbage') # type: ignore[arg-type] +MAR_1d_f8.sort(stable=True) # type: ignore[arg-type] + +MAR_1d_f8.take(axis=1.0) # type: ignore[call-overload] +MAR_1d_f8.take(out=1) # type: ignore[call-overload] +MAR_1d_f8.take(mode="bob") # type: ignore[call-overload] + +np.ma.take(None) # type: ignore[call-overload] +np.ma.take(axis=1.0) # type: ignore[call-overload] +np.ma.take(out=1) # type: ignore[call-overload] +np.ma.take(mode="bob") # type: ignore[call-overload] + +MAR_1d_f8.partition(['cabbage']) # type: ignore[arg-type] +MAR_1d_f8.partition(axis=(0,1)) # type: ignore[arg-type, call-arg] +MAR_1d_f8.partition(kind='cabbage') # type: ignore[arg-type, call-arg] +MAR_1d_f8.partition(order=lambda: 'cabbage') # type: ignore[arg-type, call-arg] +MAR_1d_f8.partition(AR_b) # type: ignore[arg-type] + +MAR_1d_f8.argpartition(['cabbage']) # type: ignore[arg-type] +MAR_1d_f8.argpartition(axis=(0,1)) # type: ignore[arg-type, call-arg] +MAR_1d_f8.argpartition(kind='cabbage') # type: ignore[arg-type, call-arg] +MAR_1d_f8.argpartition(order=lambda: 'cabbage') # type: ignore[arg-type, call-arg] +MAR_1d_f8.argpartition(AR_b) # type: ignore[arg-type] + +np.ma.ndim(lambda: 'lambda') # type: ignore[arg-type] + +np.ma.size(AR_b, axis='0') # type: ignore[arg-type] + +MAR_1d_f8 >= (lambda x: 'mango') # type: ignore[operator] +MAR_1d_f8 > (lambda x: 'mango') # type: ignore[operator] +MAR_1d_f8 <= (lambda x: 'mango') # type: ignore[operator] +MAR_1d_f8 < (lambda x: 'mango') # type: ignore[operator] + +MAR_1d_f8.count(axis=0.) 
# type: ignore[call-overload] + +np.ma.count(MAR_1d_f8, axis=0.) # type: ignore[call-overload] + +MAR_1d_f8.put(4, 999, mode='flip') # type: ignore[arg-type] + +np.ma.put(MAR_1d_f8, 4, 999, mode='flip') # type: ignore[arg-type] + +np.ma.put([1,1,3], 0, 999) # type: ignore[arg-type] + +np.ma.compressed(lambda: 'compress me') # type: ignore[call-overload] + +np.ma.allequal(MAR_1d_f8, [1,2,3], fill_value=1.5) # type: ignore[arg-type] + +np.ma.allclose(MAR_1d_f8, [1,2,3], masked_equal=4.5) # type: ignore[arg-type] +np.ma.allclose(MAR_1d_f8, [1,2,3], rtol='.4') # type: ignore[arg-type] +np.ma.allclose(MAR_1d_f8, [1,2,3], atol='.5') # type: ignore[arg-type] + +MAR_1d_f8.__setmask__('mask') # type: ignore[arg-type] + +MAR_b *= 2 # type: ignore[arg-type] +MAR_c //= 2 # type: ignore[misc] +MAR_td64 **= 2 # type: ignore[misc] + +MAR_1d_f8.swapaxes(axis1=1, axis2=0) # type: ignore[call-arg] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/memmap.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/memmap.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3a4fc7df068961736ee8cf05412d5bcea5ad547a --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/memmap.pyi @@ -0,0 +1,5 @@ +import numpy as np + +with open("file.txt", "r") as f: + np.memmap(f) # type: ignore[call-overload] +np.memmap("test.txt", shape=[10, 5]) # type: ignore[call-overload] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/modules.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/modules.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c12a182807d395fd9d8c94882b826c23de703932 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/modules.pyi @@ -0,0 +1,17 @@ +import numpy as np + +np.testing.bob # type: ignore[attr-defined] +np.bob # type: ignore[attr-defined] + +# Stdlib modules in the namespace by accident +np.warnings # type: ignore[attr-defined] +np.sys # type: ignore[attr-defined] +np.os # type: ignore[attr-defined] +np.math # type: ignore[attr-defined] + +# Public sub-modules that are not imported to their parent module by default; +# e.g. one must first execute `import numpy.lib.recfunctions` +np.lib.recfunctions # type: ignore[attr-defined] + +np.__deprecated_attrs__ # type: ignore[attr-defined] +np.__expired_functions__ # type: ignore[attr-defined] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ndarray.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ndarray.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2aeec0883e3ff1463c36d64ec233af91d83220bc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ndarray.pyi @@ -0,0 +1,11 @@ +import numpy as np + +# Ban setting dtype since mutating the type of the array in place +# makes having ndarray be generic over dtype impossible. Generally +# users should use `ndarray.view` in this situation anyway. See +# +# https://github.com/numpy/numpy-stubs/issues/7 +# +# for more context. 
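A minimal sketch of the `ndarray.view` alternative recommended in the comment above (names here are illustrative, not part of the test file): `view` reinterprets the same buffer under a new dtype without mutating the array in place, so the array's static element type stays intact.

    import numpy as np

    float_array = np.array([1.0])
    # Instead of assigning to `float_array.dtype`, take a reinterpreting
    # view; `float_array` itself keeps its float64 element type.
    bool_view = float_array.view(np.bool)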
+float_array = np.array([1.0]) +float_array.dtype = np.bool # type: ignore[assignment, misc] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nditer.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nditer.pyi new file mode 100644 index 0000000000000000000000000000000000000000..cb64061e45fe5ebab770bdbd9850c027a1984fca --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nditer.pyi @@ -0,0 +1,8 @@ +import numpy as np + +class Test(np.nditer): ... # type: ignore[misc] + +np.nditer([0, 1], flags=["test"]) # type: ignore[list-item] +np.nditer([0, 1], op_flags=[["test"]]) # type: ignore[list-item] +np.nditer([0, 1], itershape=(1.0,)) # type: ignore[arg-type] +np.nditer([0, 1], buffersize=1.0) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nested_sequence.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nested_sequence.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a28d3df3c749c4dcdbff10f266dbcd29d44f5bcc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/nested_sequence.pyi @@ -0,0 +1,16 @@ +from collections.abc import Sequence +from numpy._typing import _NestedSequence + +a: Sequence[float] +b: list[complex] +c: tuple[str, ...] +d: int +e: str + +def func(a: _NestedSequence[int]) -> None: ... + +reveal_type(func(a)) # type: ignore[arg-type, misc] +reveal_type(func(b)) # type: ignore[arg-type, misc] +reveal_type(func(c)) # type: ignore[arg-type, misc] +reveal_type(func(d)) # type: ignore[arg-type, misc] +reveal_type(func(e)) # type: ignore[arg-type, misc] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/numerictypes.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/numerictypes.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a1fd47a6f479039639bf188399c8cd001f45d6c0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/numerictypes.pyi @@ -0,0 +1,5 @@ +import numpy as np + +np.isdtype(1, np.int64) # type: ignore[arg-type] + +np.issubdtype(1, np.int64) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/random.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/random.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1abf4b77653c5f49fa095e0785949aef16a34ced --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/random.pyi @@ -0,0 +1,62 @@ +import numpy as np +import numpy.typing as npt + +SEED_FLOAT: float = 457.3 +SEED_ARR_FLOAT: npt.NDArray[np.float64] = np.array([1.0, 2, 3, 4]) +SEED_ARRLIKE_FLOAT: list[float] = [1.0, 2.0, 3.0, 4.0] +SEED_SEED_SEQ: np.random.SeedSequence = np.random.SeedSequence(0) +SEED_STR: str = "String seeding not allowed" + +# default rng +np.random.default_rng(SEED_FLOAT) # type: ignore[arg-type] +np.random.default_rng(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.default_rng(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.default_rng(SEED_STR) # type: ignore[arg-type] + +# Seed Sequence +np.random.SeedSequence(SEED_FLOAT) # type: ignore[arg-type] +np.random.SeedSequence(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.SeedSequence(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.SeedSequence(SEED_SEED_SEQ) # type: ignore[arg-type] +np.random.SeedSequence(SEED_STR) # type: ignore[arg-type] + +seed_seq: 
np.random.bit_generator.SeedSequence = np.random.SeedSequence() +seed_seq.spawn(11.5) # type: ignore[arg-type] +seed_seq.generate_state(3.14) # type: ignore[arg-type] +seed_seq.generate_state(3, np.uint8) # type: ignore[arg-type] +seed_seq.generate_state(3, "uint8") # type: ignore[arg-type] +seed_seq.generate_state(3, "u1") # type: ignore[arg-type] +seed_seq.generate_state(3, np.uint16) # type: ignore[arg-type] +seed_seq.generate_state(3, "uint16") # type: ignore[arg-type] +seed_seq.generate_state(3, "u2") # type: ignore[arg-type] +seed_seq.generate_state(3, np.int32) # type: ignore[arg-type] +seed_seq.generate_state(3, "int32") # type: ignore[arg-type] +seed_seq.generate_state(3, "i4") # type: ignore[arg-type] + +# Bit Generators +np.random.MT19937(SEED_FLOAT) # type: ignore[arg-type] +np.random.MT19937(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.MT19937(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.MT19937(SEED_STR) # type: ignore[arg-type] + +np.random.PCG64(SEED_FLOAT) # type: ignore[arg-type] +np.random.PCG64(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.PCG64(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.PCG64(SEED_STR) # type: ignore[arg-type] + +np.random.Philox(SEED_FLOAT) # type: ignore[arg-type] +np.random.Philox(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.Philox(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.Philox(SEED_STR) # type: ignore[arg-type] + +np.random.SFC64(SEED_FLOAT) # type: ignore[arg-type] +np.random.SFC64(SEED_ARR_FLOAT) # type: ignore[arg-type] +np.random.SFC64(SEED_ARRLIKE_FLOAT) # type: ignore[arg-type] +np.random.SFC64(SEED_STR) # type: ignore[arg-type] + +# Generator +np.random.Generator(None) # type: ignore[arg-type] +np.random.Generator(12333283902830213) # type: ignore[arg-type] +np.random.Generator("OxFEEDF00D") # type: ignore[arg-type] +np.random.Generator([123, 234]) # type: ignore[arg-type] +np.random.Generator(np.array([123, 234], dtype="u4")) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/scalars.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/scalars.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bfbe9125e529dcfd301643e31a1ab4a179148794 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/scalars.pyi @@ -0,0 +1,87 @@ +import sys +import numpy as np + +f2: np.float16 +f8: np.float64 +c8: np.complex64 + +# Construction + +np.float32(3j) # type: ignore[arg-type] + +# Technically the following examples are valid NumPy code. But they +# are not considered a best practice, and people who wish to use the +# stubs should instead do +# +# np.array([1.0, 0.0, 0.0], dtype=np.float32) +# np.array([], dtype=np.complex64) +# +# See e.g. the discussion on the mailing list +# +# https://mail.python.org/pipermail/numpy-discussion/2020-April/080566.html +# +# and the issue +# +# https://github.com/numpy/numpy-stubs/issues/41 +# +# for more context. +np.float32([1.0, 0.0, 0.0]) # type: ignore[arg-type] +np.complex64([]) # type: ignore[call-overload] + +# TODO: protocols (can't check for non-existent protocols w/ __getattr__) + +np.datetime64(0) # type: ignore[call-overload] + +class A: + def __float__(self) -> float: ... 
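For contrast with the rejections just below, a minimal sketch (illustrative only) of what the stubs do accept for this class: `A` implements `__float__` and nothing else, so the floating-point scalar constructors take it while the integer constructors are flagged.

    a = A()
    np.float64(a)  # accepted: an object supporting __float__ is a valid input
    np.float32(a)  # likewise accepted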
+ +np.int8(A()) # type: ignore[arg-type] +np.int16(A()) # type: ignore[arg-type] +np.int32(A()) # type: ignore[arg-type] +np.int64(A()) # type: ignore[arg-type] +np.uint8(A()) # type: ignore[arg-type] +np.uint16(A()) # type: ignore[arg-type] +np.uint32(A()) # type: ignore[arg-type] +np.uint64(A()) # type: ignore[arg-type] + +np.void("test") # type: ignore[call-overload] +np.void("test", dtype=None) # type: ignore[call-overload] + +np.generic(1) # type: ignore[abstract] +np.number(1) # type: ignore[abstract] +np.integer(1) # type: ignore[abstract] +np.inexact(1) # type: ignore[abstract] +np.character("test") # type: ignore[abstract] +np.flexible(b"test") # type: ignore[abstract] + +np.float64(value=0.0) # type: ignore[call-arg] +np.int64(value=0) # type: ignore[call-arg] +np.uint64(value=0) # type: ignore[call-arg] +np.complex128(value=0.0j) # type: ignore[call-overload] +np.str_(value='bob') # type: ignore[call-overload] +np.bytes_(value=b'test') # type: ignore[call-overload] +np.void(value=b'test') # type: ignore[call-overload] +np.bool(value=True) # type: ignore[call-overload] +np.datetime64(value="2019") # type: ignore[call-overload] +np.timedelta64(value=0) # type: ignore[call-overload] + +np.bytes_(b"hello", encoding='utf-8') # type: ignore[call-overload] +np.str_("hello", encoding='utf-8') # type: ignore[call-overload] + +f8.item(1) # type: ignore[call-overload] +f8.item((0, 1)) # type: ignore[arg-type] +f8.squeeze(axis=1) # type: ignore[arg-type] +f8.squeeze(axis=(0, 1)) # type: ignore[arg-type] +f8.transpose(1) # type: ignore[arg-type] + +def func(a: np.float32) -> None: ... + +func(f2) # type: ignore[arg-type] +func(f8) # type: ignore[arg-type] + +c8.__getnewargs__() # type: ignore[attr-defined] +f2.__getnewargs__() # type: ignore[attr-defined] +f2.hex() # type: ignore[attr-defined] +np.float16.fromhex("0x0.0p+0") # type: ignore[attr-defined] +f2.__trunc__() # type: ignore[attr-defined] +f2.__getformat__("float") # type: ignore[attr-defined] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/shape.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/shape.pyi new file mode 100644 index 0000000000000000000000000000000000000000..fea055583073d7e4be3972eb48951e255ceb4484 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/shape.pyi @@ -0,0 +1,6 @@ +from typing import Any +import numpy as np + +# test bounds of _ShapeT_co + +np.ndarray[tuple[str, str], Any] # type: ignore[type-var] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/stride_tricks.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/stride_tricks.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7f9a26b9692438eef717972fcc0024d2ab4dbcb9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/stride_tricks.pyi @@ -0,0 +1,9 @@ +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] + +np.lib.stride_tricks.as_strided(AR_f8, shape=8) # type: ignore[call-overload] +np.lib.stride_tricks.as_strided(AR_f8, strides=8) # type: ignore[call-overload] + +np.lib.stride_tricks.sliding_window_view(AR_f8, axis=(1,)) # type: ignore[call-overload] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/twodim_base.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/twodim_base.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d0f2b7ad8322c06072ae10b87ad60d7b6cb3035e --- /dev/null +++ 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/twodim_base.pyi @@ -0,0 +1,32 @@ +from typing import Any, TypeVar + +import numpy as np +import numpy.typing as npt + +def func1(ar: npt.NDArray[Any], a: int) -> npt.NDArray[np.str_]: ... + +def func2(ar: npt.NDArray[Any], a: float) -> float: ... + +AR_b: npt.NDArray[np.bool] +AR_m: npt.NDArray[np.timedelta64] + +AR_LIKE_b: list[bool] + +np.eye(10, M=20.0) # type: ignore[call-overload] +np.eye(10, k=2.5, dtype=int) # type: ignore[call-overload] + +np.diag(AR_b, k=0.5) # type: ignore[call-overload] +np.diagflat(AR_b, k=0.5) # type: ignore[call-overload] + +np.tri(10, M=20.0) # type: ignore[call-overload] +np.tri(10, k=2.5, dtype=int) # type: ignore[call-overload] + +np.tril(AR_b, k=0.5) # type: ignore[call-overload] +np.triu(AR_b, k=0.5) # type: ignore[call-overload] + +np.vander(AR_m) # type: ignore[arg-type] + +np.histogram2d(AR_m) # type: ignore[call-overload] + +np.mask_indices(10, func1) # type: ignore[arg-type] +np.mask_indices(10, func2, 10.5) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufunclike.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufunclike.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e556e409ebbc90f1e57c096d53f656850c041b0f --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufunclike.pyi @@ -0,0 +1,21 @@ +import numpy as np +import numpy.typing as npt + +AR_c: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] +AR_O: npt.NDArray[np.object_] + +np.fix(AR_c) # type: ignore[arg-type] +np.fix(AR_m) # type: ignore[arg-type] +np.fix(AR_M) # type: ignore[arg-type] + +np.isposinf(AR_c) # type: ignore[arg-type] +np.isposinf(AR_m) # type: ignore[arg-type] +np.isposinf(AR_M) # type: ignore[arg-type] +np.isposinf(AR_O) # type: ignore[arg-type] + +np.isneginf(AR_c) # type: ignore[arg-type] +np.isneginf(AR_m) # type: ignore[arg-type] +np.isneginf(AR_M) # type: ignore[arg-type] +np.isneginf(AR_O) # type: ignore[arg-type] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufuncs.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufuncs.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1b1628d7da4431c5fc30c051a413052cbe6969ac --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/fail/ufuncs.pyi @@ -0,0 +1,17 @@ +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] + +np.sin.nin + "foo" # type: ignore[operator] +np.sin(1, foo="bar") # type: ignore[call-overload] + +np.abs(None) # type: ignore[call-overload] + +np.add(1, 1, 1) # type: ignore[call-overload] +np.add(1, 1, axis=0) # type: ignore[call-overload] + +np.matmul(AR_f8, AR_f8, where=True) # type: ignore[call-overload] + +np.frexp(AR_f8, out=None) # type: ignore[call-overload] +np.frexp(AR_f8, out=AR_f8) # type: ignore[call-overload] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/misc/extended_precision.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/misc/extended_precision.pyi new file mode 100644 index 0000000000000000000000000000000000000000..84b5f516bdde6435aaab6088f098fc62af76320e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/misc/extended_precision.pyi @@ -0,0 +1,9 @@ +import numpy as np +from numpy._typing import _96Bit, _128Bit + +from typing import assert_type + +assert_type(np.float96(), 
np.floating[_96Bit]) +assert_type(np.float128(), np.floating[_128Bit]) +assert_type(np.complex192(), np.complexfloating[_96Bit, _96Bit]) +assert_type(np.complex256(), np.complexfloating[_128Bit, _128Bit]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arithmetic.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2901cf2b51b01d89702a01360e607fc8e4f338 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arithmetic.py @@ -0,0 +1,612 @@ +from __future__ import annotations + +from typing import Any, cast +import numpy as np +import numpy.typing as npt +import pytest +
+c16 = np.complex128(1) +f8 = np.float64(1) +i8 = np.int64(1) +u8 = np.uint64(1) + +c8 = np.complex64(1) +f4 = np.float32(1) +i4 = np.int32(1) +u4 = np.uint32(1) + +dt = np.datetime64(1, "D") +td = np.timedelta64(1, "D") + +b_ = np.bool(1) + +b = bool(1) +c = complex(1) +f = float(1) +i = int(1) + + +class Object: + def __array__(self, dtype: np.typing.DTypeLike = None, + copy: bool | None = None) -> np.ndarray[Any, np.dtype[np.object_]]: + ret = np.empty((), dtype=object) + ret[()] = self + return ret + + def __sub__(self, value: Any) -> Object: + return self + + def __rsub__(self, value: Any) -> Object: + return self + + def __floordiv__(self, value: Any) -> Object: + return self + + def __rfloordiv__(self, value: Any) -> Object: + return self + + def __mul__(self, value: Any) -> Object: + return self + + def __rmul__(self, value: Any) -> Object: + return self + + def __pow__(self, value: Any) -> Object: + return self + + def __rpow__(self, value: Any) -> Object: + return self + + +AR_b: npt.NDArray[np.bool] = np.array([True]) +AR_u: npt.NDArray[np.uint32] = np.array([1], dtype=np.uint32) +AR_i: npt.NDArray[np.int64] = np.array([1]) +AR_integer: npt.NDArray[np.integer] = cast(npt.NDArray[np.integer], AR_i) +AR_f: npt.NDArray[np.float64] = np.array([1.0]) +AR_c: npt.NDArray[np.complex128] = np.array([1j]) +AR_m: npt.NDArray[np.timedelta64] = np.array([np.timedelta64(1, "D")]) +AR_M: npt.NDArray[np.datetime64] = np.array([np.datetime64(1, "D")]) +AR_O: npt.NDArray[np.object_] = np.array([Object()]) + +AR_LIKE_b = [True] +AR_LIKE_u = [np.uint32(1)] +AR_LIKE_i = [1] +AR_LIKE_f = [1.0] +AR_LIKE_c = [1j] +AR_LIKE_m = [np.timedelta64(1, "D")] +AR_LIKE_M = [np.datetime64(1, "D")] +AR_LIKE_O = [Object()] + +# Array subtractions + +AR_b - AR_LIKE_u +AR_b - AR_LIKE_i +AR_b - AR_LIKE_f +AR_b - AR_LIKE_c +AR_b - AR_LIKE_m +AR_b - AR_LIKE_O + +AR_LIKE_u - AR_b +AR_LIKE_i - AR_b +AR_LIKE_f - AR_b +AR_LIKE_c - AR_b +AR_LIKE_m - AR_b +AR_LIKE_M - AR_b +AR_LIKE_O - AR_b + +AR_u - AR_LIKE_b +AR_u - AR_LIKE_u +AR_u - AR_LIKE_i +AR_u - AR_LIKE_f +AR_u - AR_LIKE_c +AR_u - AR_LIKE_m +AR_u - AR_LIKE_O + +AR_LIKE_b - AR_u +AR_LIKE_u - AR_u +AR_LIKE_i - AR_u +AR_LIKE_f - AR_u +AR_LIKE_c - AR_u +AR_LIKE_m - AR_u +AR_LIKE_M - AR_u +AR_LIKE_O - AR_u + +AR_i - AR_LIKE_b +AR_i - AR_LIKE_u +AR_i - AR_LIKE_i +AR_i - AR_LIKE_f +AR_i - AR_LIKE_c +AR_i - AR_LIKE_m +AR_i - AR_LIKE_O + +AR_LIKE_b - AR_i +AR_LIKE_u - AR_i +AR_LIKE_i - AR_i +AR_LIKE_f - AR_i +AR_LIKE_c - AR_i +AR_LIKE_m - AR_i +AR_LIKE_M - AR_i +AR_LIKE_O - AR_i + +AR_f - AR_LIKE_b +AR_f - AR_LIKE_u +AR_f - AR_LIKE_i +AR_f - AR_LIKE_f +AR_f - AR_LIKE_c +AR_f - AR_LIKE_O + +AR_LIKE_b - AR_f +AR_LIKE_u - AR_f +AR_LIKE_i - AR_f +AR_LIKE_f - AR_f +AR_LIKE_c - AR_f +AR_LIKE_O - AR_f + +AR_c - AR_LIKE_b +AR_c - AR_LIKE_u +AR_c - AR_LIKE_i +AR_c - AR_LIKE_f +AR_c - AR_LIKE_c +AR_c - AR_LIKE_O + +AR_LIKE_b - AR_c +AR_LIKE_u - AR_c +AR_LIKE_i - AR_c +AR_LIKE_f - AR_c +AR_LIKE_c - AR_c +AR_LIKE_O - AR_c + +AR_m - AR_LIKE_b +AR_m - AR_LIKE_u +AR_m - AR_LIKE_i +AR_m - AR_LIKE_m + +AR_LIKE_b - AR_m +AR_LIKE_u - AR_m +AR_LIKE_i - AR_m +AR_LIKE_m - AR_m +AR_LIKE_M - AR_m + +AR_M - AR_LIKE_b +AR_M - AR_LIKE_u +AR_M - AR_LIKE_i +AR_M - AR_LIKE_m +AR_M - AR_LIKE_M + +AR_LIKE_M - AR_M + +AR_O - AR_LIKE_b +AR_O - AR_LIKE_u +AR_O - AR_LIKE_i +AR_O - AR_LIKE_f +AR_O - AR_LIKE_c +AR_O - AR_LIKE_O + +AR_LIKE_b - AR_O +AR_LIKE_u - AR_O +AR_LIKE_i - AR_O +AR_LIKE_f - AR_O +AR_LIKE_c - AR_O +AR_LIKE_O - AR_O + +AR_u += AR_b +AR_u += AR_u +AR_u += 1 # Allowed during runtime as long as 
the object is 0D and >=0 + +# Array floor division + +AR_b // AR_LIKE_b +AR_b // AR_LIKE_u +AR_b // AR_LIKE_i +AR_b // AR_LIKE_f +AR_b // AR_LIKE_O + +AR_LIKE_b // AR_b +AR_LIKE_u // AR_b +AR_LIKE_i // AR_b +AR_LIKE_f // AR_b +AR_LIKE_O // AR_b + +AR_u // AR_LIKE_b +AR_u // AR_LIKE_u +AR_u // AR_LIKE_i +AR_u // AR_LIKE_f +AR_u // AR_LIKE_O + +AR_LIKE_b // AR_u +AR_LIKE_u // AR_u +AR_LIKE_i // AR_u +AR_LIKE_f // AR_u +AR_LIKE_m // AR_u +AR_LIKE_O // AR_u + +AR_i // AR_LIKE_b +AR_i // AR_LIKE_u +AR_i // AR_LIKE_i +AR_i // AR_LIKE_f +AR_i // AR_LIKE_O + +AR_LIKE_b // AR_i +AR_LIKE_u // AR_i +AR_LIKE_i // AR_i +AR_LIKE_f // AR_i +AR_LIKE_m // AR_i +AR_LIKE_O // AR_i + +AR_f // AR_LIKE_b +AR_f // AR_LIKE_u +AR_f // AR_LIKE_i +AR_f // AR_LIKE_f +AR_f // AR_LIKE_O + +AR_LIKE_b // AR_f +AR_LIKE_u // AR_f +AR_LIKE_i // AR_f +AR_LIKE_f // AR_f +AR_LIKE_m // AR_f +AR_LIKE_O // AR_f + +AR_m // AR_LIKE_u +AR_m // AR_LIKE_i +AR_m // AR_LIKE_f +AR_m // AR_LIKE_m + +AR_LIKE_m // AR_m + +AR_m /= f +AR_m //= f +AR_m /= AR_f +AR_m /= AR_LIKE_f +AR_m //= AR_f +AR_m //= AR_LIKE_f + +AR_O // AR_LIKE_b +AR_O // AR_LIKE_u +AR_O // AR_LIKE_i +AR_O // AR_LIKE_f +AR_O // AR_LIKE_O + +AR_LIKE_b // AR_O +AR_LIKE_u // AR_O +AR_LIKE_i // AR_O +AR_LIKE_f // AR_O +AR_LIKE_O // AR_O + +# Inplace multiplication + +AR_b *= AR_LIKE_b + +AR_u *= AR_LIKE_b +AR_u *= AR_LIKE_u + +AR_i *= AR_LIKE_b +AR_i *= AR_LIKE_u +AR_i *= AR_LIKE_i + +AR_integer *= AR_LIKE_b +AR_integer *= AR_LIKE_u +AR_integer *= AR_LIKE_i + +AR_f *= AR_LIKE_b +AR_f *= AR_LIKE_u +AR_f *= AR_LIKE_i +AR_f *= AR_LIKE_f + +AR_c *= AR_LIKE_b +AR_c *= AR_LIKE_u +AR_c *= AR_LIKE_i +AR_c *= AR_LIKE_f +AR_c *= AR_LIKE_c + +AR_m *= AR_LIKE_b +AR_m *= AR_LIKE_u +AR_m *= AR_LIKE_i +AR_m *= AR_LIKE_f + +AR_O *= AR_LIKE_b +AR_O *= AR_LIKE_u +AR_O *= AR_LIKE_i +AR_O *= AR_LIKE_f +AR_O *= AR_LIKE_c +AR_O *= AR_LIKE_O + +# Inplace power + +AR_u **= AR_LIKE_b +AR_u **= AR_LIKE_u + +AR_i **= AR_LIKE_b +AR_i **= AR_LIKE_u +AR_i **= AR_LIKE_i + +AR_integer **= AR_LIKE_b +AR_integer **= AR_LIKE_u +AR_integer **= AR_LIKE_i + +AR_f **= AR_LIKE_b +AR_f **= AR_LIKE_u +AR_f **= AR_LIKE_i +AR_f **= AR_LIKE_f + +AR_c **= AR_LIKE_b +AR_c **= AR_LIKE_u +AR_c **= AR_LIKE_i +AR_c **= AR_LIKE_f +AR_c **= AR_LIKE_c + +AR_O **= AR_LIKE_b +AR_O **= AR_LIKE_u +AR_O **= AR_LIKE_i +AR_O **= AR_LIKE_f +AR_O **= AR_LIKE_c +AR_O **= AR_LIKE_O + +# unary ops + +-c16 +-c8 +-f8 +-f4 +-i8 +-i4 +with pytest.warns(RuntimeWarning): + -u8 + -u4 +-td +-AR_f + ++c16 ++c8 ++f8 ++f4 ++i8 ++i4 ++u8 ++u4 ++td ++AR_f + +abs(c16) +abs(c8) +abs(f8) +abs(f4) +abs(i8) +abs(i4) +abs(u8) +abs(u4) +abs(td) +abs(b_) +abs(AR_f) + +# Time structures + +dt + td +dt + i +dt + i4 +dt + i8 +dt - dt +dt - i +dt - i4 +dt - i8 + +td + td +td + i +td + i4 +td + i8 +td - td +td - i +td - i4 +td - i8 +td / f +td / f4 +td / f8 +td / td +td // td +td % td + + +# boolean + +b_ / b +b_ / b_ +b_ / i +b_ / i8 +b_ / i4 +b_ / u8 +b_ / u4 +b_ / f +b_ / f8 +b_ / f4 +b_ / c +b_ / c16 +b_ / c8 + +b / b_ +b_ / b_ +i / b_ +i8 / b_ +i4 / b_ +u8 / b_ +u4 / b_ +f / b_ +f8 / b_ +f4 / b_ +c / b_ +c16 / b_ +c8 / b_ + +# Complex + +c16 + c16 +c16 + f8 +c16 + i8 +c16 + c8 +c16 + f4 +c16 + i4 +c16 + b_ +c16 + b +c16 + c +c16 + f +c16 + i +c16 + AR_f + +c16 + c16 +f8 + c16 +i8 + c16 +c8 + c16 +f4 + c16 +i4 + c16 +b_ + c16 +b + c16 +c + c16 +f + c16 +i + c16 +AR_f + c16 + +c8 + c16 +c8 + f8 +c8 + i8 +c8 + c8 +c8 + f4 +c8 + i4 +c8 + b_ +c8 + b +c8 + c +c8 + f +c8 + i +c8 + AR_f + +c16 + c8 +f8 + c8 +i8 + c8 +c8 + c8 +f4 + c8 +i4 + c8 +b_ + c8 +b + c8 +c + c8 +f 
+ c8 +i + c8 +AR_f + c8 + +# Float + +f8 + f8 +f8 + i8 +f8 + f4 +f8 + i4 +f8 + b_ +f8 + b +f8 + c +f8 + f +f8 + i +f8 + AR_f + +f8 + f8 +i8 + f8 +f4 + f8 +i4 + f8 +b_ + f8 +b + f8 +c + f8 +f + f8 +i + f8 +AR_f + f8 + +f4 + f8 +f4 + i8 +f4 + f4 +f4 + i4 +f4 + b_ +f4 + b +f4 + c +f4 + f +f4 + i +f4 + AR_f + +f8 + f4 +i8 + f4 +f4 + f4 +i4 + f4 +b_ + f4 +b + f4 +c + f4 +f + f4 +i + f4 +AR_f + f4 + +# Int + +i8 + i8 +i8 + u8 +i8 + i4 +i8 + u4 +i8 + b_ +i8 + b +i8 + c +i8 + f +i8 + i +i8 + AR_f + +u8 + u8 +u8 + i4 +u8 + u4 +u8 + b_ +u8 + b +u8 + c +u8 + f +u8 + i +u8 + AR_f + +i8 + i8 +u8 + i8 +i4 + i8 +u4 + i8 +b_ + i8 +b + i8 +c + i8 +f + i8 +i + i8 +AR_f + i8 + +u8 + u8 +i4 + u8 +u4 + u8 +b_ + u8 +b + u8 +c + u8 +f + u8 +i + u8 +AR_f + u8 + +i4 + i8 +i4 + i4 +i4 + i +i4 + b_ +i4 + b +i4 + AR_f + +u4 + i8 +u4 + i4 +u4 + u8 +u4 + u4 +u4 + i +u4 + b_ +u4 + b +u4 + AR_f + +i8 + i4 +i4 + i4 +i + i4 +b_ + i4 +b + i4 +AR_f + i4 + +i8 + u4 +i4 + u4 +u8 + u4 +u4 + u4 +b_ + u4 +b + u4 +i + u4 +AR_f + u4 diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_constructors.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..17b6fab93ad877b5eff31e65afadabcda8bd1c9e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_constructors.py @@ -0,0 +1,137 @@ +from typing import Any + +import numpy as np +import numpy.typing as npt + +class Index: + def __index__(self) -> int: + return 0 + + +class SubClass(npt.NDArray[np.float64]): + pass + + +def func(i: int, j: int, **kwargs: Any) -> SubClass: + return B + + +i8 = np.int64(1) + +A = np.array([1]) +B = A.view(SubClass).copy() +B_stack = np.array([[1], [1]]).view(SubClass) +C = [1] + +np.ndarray(Index()) +np.ndarray([Index()]) + +np.array(1, dtype=float) +np.array(1, copy=None) +np.array(1, order='F') +np.array(1, order=None) +np.array(1, subok=True) +np.array(1, ndmin=3) +np.array(1, str, copy=True, order='C', subok=False, ndmin=2) + +np.asarray(A) +np.asarray(B) +np.asarray(C) + +np.asanyarray(A) +np.asanyarray(B) +np.asanyarray(B, dtype=int) +np.asanyarray(C) + +np.ascontiguousarray(A) +np.ascontiguousarray(B) +np.ascontiguousarray(C) + +np.asfortranarray(A) +np.asfortranarray(B) +np.asfortranarray(C) + +np.require(A) +np.require(B) +np.require(B, dtype=int) +np.require(B, requirements=None) +np.require(B, requirements="E") +np.require(B, requirements=["ENSUREARRAY"]) +np.require(B, requirements={"F", "E"}) +np.require(B, requirements=["C", "OWNDATA"]) +np.require(B, requirements="W") +np.require(B, requirements="A") +np.require(C) + +np.linspace(0, 2) +np.linspace(0.5, [0, 1, 2]) +np.linspace([0, 1, 2], 3) +np.linspace(0j, 2) +np.linspace(0, 2, num=10) +np.linspace(0, 2, endpoint=True) +np.linspace(0, 2, retstep=True) +np.linspace(0j, 2j, retstep=True) +np.linspace(0, 2, dtype=bool) +np.linspace([0, 1], [2, 3], axis=Index()) + +np.logspace(0, 2, base=2) +np.logspace(0, 2, base=2) +np.logspace(0, 2, base=[1j, 2j], num=2) + +np.geomspace(1, 2) + +np.zeros_like(A) +np.zeros_like(C) +np.zeros_like(B) +np.zeros_like(B, dtype=np.int64) + +np.ones_like(A) +np.ones_like(C) +np.ones_like(B) +np.ones_like(B, dtype=np.int64) + +np.empty_like(A) +np.empty_like(C) +np.empty_like(B) +np.empty_like(B, dtype=np.int64) + +np.full_like(A, i8) +np.full_like(C, i8) +np.full_like(B, i8) +np.full_like(B, i8, dtype=np.int64) + +np.ones(1) +np.ones([1, 1, 1]) + +np.full(1, i8) +np.full([1, 1, 1], i8) + 
+np.indices([1, 2, 3]) +np.indices([1, 2, 3], sparse=True) + +np.fromfunction(func, (3, 5)) + +np.identity(10) + +np.atleast_1d(C) +np.atleast_1d(A) +np.atleast_1d(C, C) +np.atleast_1d(C, A) +np.atleast_1d(A, A) + +np.atleast_2d(C) + +np.atleast_3d(C) + +np.vstack([C, C]) +np.vstack([C, A]) +np.vstack([A, A]) + +np.hstack([C, C]) + +np.stack([C, C]) +np.stack([C, C], axis=0) +np.stack([C, C], out=B_stack) + +np.block([[C, C], [C, C]]) +np.block(A) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_like.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_like.py new file mode 100644 index 0000000000000000000000000000000000000000..264ec55da053f2de0b37fc6ee8b0f7d4d450e8fe --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/array_like.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from numpy._typing import NDArray, ArrayLike, _SupportsArray + +x1: ArrayLike = True +x2: ArrayLike = 5 +x3: ArrayLike = 1.0 +x4: ArrayLike = 1 + 1j +x5: ArrayLike = np.int8(1) +x6: ArrayLike = np.float64(1) +x7: ArrayLike = np.complex128(1) +x8: ArrayLike = np.array([1, 2, 3]) +x9: ArrayLike = [1, 2, 3] +x10: ArrayLike = (1, 2, 3) +x11: ArrayLike = "foo" +x12: ArrayLike = memoryview(b'foo') + + +class A: + def __array__(self, dtype: np.dtype | None = None) -> NDArray[np.float64]: + return np.array([1.0, 2.0, 3.0]) + + +x13: ArrayLike = A() + +scalar: _SupportsArray[np.dtype[np.int64]] = np.int64(1) +scalar.__array__() +array: _SupportsArray[np.dtype[np.int_]] = np.array(1) +array.__array__() + +a: _SupportsArray[np.dtype[np.float64]] = A() +a.__array__() +a.__array__() + +# Escape hatch for when you mean to make something like an object +# array. 
+object_array_scalar: object = (i for i in range(10)) +np.array(object_array_scalar) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayprint.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayprint.py new file mode 100644 index 0000000000000000000000000000000000000000..6c704c755570d1508424af92a0eb5aa1353666a0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayprint.py @@ -0,0 +1,37 @@ +import numpy as np + +AR = np.arange(10) +AR.setflags(write=False) + +with np.printoptions(): + np.set_printoptions( + precision=1, + threshold=2, + edgeitems=3, + linewidth=4, + suppress=False, + nanstr="Bob", + infstr="Bill", + formatter={}, + sign="+", + floatmode="unique", + ) + np.get_printoptions() + str(AR) + + np.array2string( + AR, + max_line_width=5, + precision=2, + suppress_small=True, + separator=";", + prefix="test", + threshold=5, + floatmode="fixed", + suffix="?", + legacy="1.13", + ) + np.format_float_scientific(1, precision=5) + np.format_float_positional(1, trim="k") + np.array_repr(AR) + np.array_str(AR) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayterator.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayterator.py new file mode 100644 index 0000000000000000000000000000000000000000..572be5e2fe29ba978b78c8b65b116b5b54a4d01a --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/arrayterator.py @@ -0,0 +1,27 @@ + +from __future__ import annotations + +from typing import Any +import numpy as np + +AR_i8: np.ndarray[Any, np.dtype[np.int_]] = np.arange(10) +ar_iter = np.lib.Arrayterator(AR_i8) + +ar_iter.var +ar_iter.buf_size +ar_iter.start +ar_iter.stop +ar_iter.step +ar_iter.shape +ar_iter.flat + +ar_iter.__array__() + +for i in ar_iter: + pass + +ar_iter[0] +ar_iter[...] 
+ar_iter[:] +ar_iter[0, 0, 0] +ar_iter[..., 0, :] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/bitwise_ops.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/bitwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..22a245d2180979e6e0fa0cb5780713aa3ac86439 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/bitwise_ops.py @@ -0,0 +1,131 @@ +import numpy as np + +i8 = np.int64(1) +u8 = np.uint64(1) + +i4 = np.int32(1) +u4 = np.uint32(1) + +b_ = np.bool(1) + +b = bool(1) +i = int(1) + +AR = np.array([0, 1, 2], dtype=np.int32) +AR.setflags(write=False) + + +i8 << i8 +i8 >> i8 +i8 | i8 +i8 ^ i8 +i8 & i8 + +i << AR +i >> AR +i | AR +i ^ AR +i & AR + +i8 << AR +i8 >> AR +i8 | AR +i8 ^ AR +i8 & AR + +i4 << i4 +i4 >> i4 +i4 | i4 +i4 ^ i4 +i4 & i4 + +i8 << i4 +i8 >> i4 +i8 | i4 +i8 ^ i4 +i8 & i4 + +i8 << i +i8 >> i +i8 | i +i8 ^ i +i8 & i + +i8 << b_ +i8 >> b_ +i8 | b_ +i8 ^ b_ +i8 & b_ + +i8 << b +i8 >> b +i8 | b +i8 ^ b +i8 & b + +u8 << u8 +u8 >> u8 +u8 | u8 +u8 ^ u8 +u8 & u8 + +u4 << u4 +u4 >> u4 +u4 | u4 +u4 ^ u4 +u4 & u4 + +u4 << i4 +u4 >> i4 +u4 | i4 +u4 ^ i4 +u4 & i4 + +u4 << i +u4 >> i +u4 | i +u4 ^ i +u4 & i + +u8 << b_ +u8 >> b_ +u8 | b_ +u8 ^ b_ +u8 & b_ + +u8 << b +u8 >> b +u8 | b +u8 ^ b +u8 & b + +b_ << b_ +b_ >> b_ +b_ | b_ +b_ ^ b_ +b_ & b_ + +b_ << AR +b_ >> AR +b_ | AR +b_ ^ AR +b_ & AR + +b_ << b +b_ >> b +b_ | b +b_ ^ b +b_ & b + +b_ << i +b_ >> i +b_ | i +b_ ^ i +b_ & i + +~i8 +~i4 +~u8 +~u4 +~b_ +~AR diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/comparisons.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/comparisons.py new file mode 100644 index 0000000000000000000000000000000000000000..a461d8b660da6f38ca6ddd416d4857d7b865aa9e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/comparisons.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +from typing import cast, Any +import numpy as np + +c16 = np.complex128() +f8 = np.float64() +i8 = np.int64() +u8 = np.uint64() + +c8 = np.complex64() +f4 = np.float32() +i4 = np.int32() +u4 = np.uint32() + +dt = np.datetime64(0, "D") +td = np.timedelta64(0, "D") + +b_ = np.bool() + +b = bool() +c = complex() +f = float() +i = int() + +SEQ = (0, 1, 2, 3, 4) + +AR_b: np.ndarray[Any, np.dtype[np.bool]] = np.array([True]) +AR_u: np.ndarray[Any, np.dtype[np.uint32]] = np.array([1], dtype=np.uint32) +AR_i: np.ndarray[Any, np.dtype[np.int_]] = np.array([1]) +AR_f: np.ndarray[Any, np.dtype[np.float64]] = np.array([1.0]) +AR_c: np.ndarray[Any, np.dtype[np.complex128]] = np.array([1.0j]) +AR_S: np.ndarray[Any, np.dtype[np.bytes_]] = np.array([b"a"], "S") +AR_T = cast(np.ndarray[Any, np.dtypes.StringDType], np.array(["a"], "T")) +AR_U: np.ndarray[Any, np.dtype[np.str_]] = np.array(["a"], "U") +AR_m: np.ndarray[Any, np.dtype[np.timedelta64]] = np.array([np.timedelta64("1")]) +AR_M: np.ndarray[Any, np.dtype[np.datetime64]] = np.array([np.datetime64("1")]) +AR_O: np.ndarray[Any, np.dtype[np.object_]] = np.array([1], dtype=object) + +# Arrays + +AR_b > AR_b +AR_b > AR_u +AR_b > AR_i +AR_b > AR_f +AR_b > AR_c + +AR_u > AR_b +AR_u > AR_u +AR_u > AR_i +AR_u > AR_f +AR_u > AR_c + +AR_i > AR_b +AR_i > AR_u +AR_i > AR_i +AR_i > AR_f +AR_i > AR_c + +AR_f > AR_b +AR_f > AR_u +AR_f > AR_i +AR_f > AR_f +AR_f > AR_c + +AR_c > AR_b +AR_c > AR_u +AR_c > AR_i +AR_c > AR_f +AR_c > AR_c + +AR_S > AR_S +AR_S > b"" + +AR_T > AR_T +AR_T > AR_U +AR_T > "" + +AR_U > AR_U +AR_U > AR_T 
+AR_U > "" + +AR_m > AR_b +AR_m > AR_u +AR_m > AR_i +AR_b > AR_m +AR_u > AR_m +AR_i > AR_m + +AR_M > AR_M + +AR_O > AR_O +1 > AR_O +AR_O > 1 + +# Time structures + +dt > dt + +td > td +td > i +td > i4 +td > i8 +td > AR_i +td > SEQ + +# boolean + +b_ > b +b_ > b_ +b_ > i +b_ > i8 +b_ > i4 +b_ > u8 +b_ > u4 +b_ > f +b_ > f8 +b_ > f4 +b_ > c +b_ > c16 +b_ > c8 +b_ > AR_i +b_ > SEQ + +# Complex + +c16 > c16 +c16 > f8 +c16 > i8 +c16 > c8 +c16 > f4 +c16 > i4 +c16 > b_ +c16 > b +c16 > c +c16 > f +c16 > i +c16 > AR_i +c16 > SEQ + +c16 > c16 +f8 > c16 +i8 > c16 +c8 > c16 +f4 > c16 +i4 > c16 +b_ > c16 +b > c16 +c > c16 +f > c16 +i > c16 +AR_i > c16 +SEQ > c16 + +c8 > c16 +c8 > f8 +c8 > i8 +c8 > c8 +c8 > f4 +c8 > i4 +c8 > b_ +c8 > b +c8 > c +c8 > f +c8 > i +c8 > AR_i +c8 > SEQ + +c16 > c8 +f8 > c8 +i8 > c8 +c8 > c8 +f4 > c8 +i4 > c8 +b_ > c8 +b > c8 +c > c8 +f > c8 +i > c8 +AR_i > c8 +SEQ > c8 + +# Float + +f8 > f8 +f8 > i8 +f8 > f4 +f8 > i4 +f8 > b_ +f8 > b +f8 > c +f8 > f +f8 > i +f8 > AR_i +f8 > SEQ + +f8 > f8 +i8 > f8 +f4 > f8 +i4 > f8 +b_ > f8 +b > f8 +c > f8 +f > f8 +i > f8 +AR_i > f8 +SEQ > f8 + +f4 > f8 +f4 > i8 +f4 > f4 +f4 > i4 +f4 > b_ +f4 > b +f4 > c +f4 > f +f4 > i +f4 > AR_i +f4 > SEQ + +f8 > f4 +i8 > f4 +f4 > f4 +i4 > f4 +b_ > f4 +b > f4 +c > f4 +f > f4 +i > f4 +AR_i > f4 +SEQ > f4 + +# Int + +i8 > i8 +i8 > u8 +i8 > i4 +i8 > u4 +i8 > b_ +i8 > b +i8 > c +i8 > f +i8 > i +i8 > AR_i +i8 > SEQ + +u8 > u8 +u8 > i4 +u8 > u4 +u8 > b_ +u8 > b +u8 > c +u8 > f +u8 > i +u8 > AR_i +u8 > SEQ + +i8 > i8 +u8 > i8 +i4 > i8 +u4 > i8 +b_ > i8 +b > i8 +c > i8 +f > i8 +i > i8 +AR_i > i8 +SEQ > i8 + +u8 > u8 +i4 > u8 +u4 > u8 +b_ > u8 +b > u8 +c > u8 +f > u8 +i > u8 +AR_i > u8 +SEQ > u8 + +i4 > i8 +i4 > i4 +i4 > i +i4 > b_ +i4 > b +i4 > AR_i +i4 > SEQ + +u4 > i8 +u4 > i4 +u4 > u8 +u4 > u4 +u4 > i +u4 > b_ +u4 > b +u4 > AR_i +u4 > SEQ + +i8 > i4 +i4 > i4 +i > i4 +b_ > i4 +b > i4 +AR_i > i4 +SEQ > i4 + +i8 > u4 +i4 > u4 +u8 > u4 +u4 > u4 +b_ > u4 +b > u4 +i > u4 +AR_i > u4 +SEQ > u4 diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/dtype.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..9f11518276c842f2601e3f1420020f00e3b775fb --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/dtype.py @@ -0,0 +1,57 @@ +import numpy as np + +dtype_obj = np.dtype(np.str_) +void_dtype_obj = np.dtype([("f0", np.float64), ("f1", np.float32)]) + +np.dtype(dtype=np.int64) +np.dtype(int) +np.dtype("int") +np.dtype(None) + +np.dtype((int, 2)) +np.dtype((int, (1,))) + +np.dtype({"names": ["a", "b"], "formats": [int, float]}) +np.dtype({"names": ["a"], "formats": [int], "titles": [object]}) +np.dtype({"names": ["a"], "formats": [int], "titles": [object()]}) + +np.dtype([("name", np.str_, 16), ("grades", np.float64, (2,)), ("age", "int32")]) + +np.dtype( + { + "names": ["a", "b"], + "formats": [int, float], + "itemsize": 9, + "aligned": False, + "titles": ["x", "y"], + "offsets": [0, 1], + } +) + +np.dtype((np.float64, float)) + + +class Test: + dtype = np.dtype(float) + + +np.dtype(Test()) + +# Methods and attributes +dtype_obj.base +dtype_obj.subdtype +dtype_obj.newbyteorder() +dtype_obj.type +dtype_obj.name +dtype_obj.names + +dtype_obj * 0 +dtype_obj * 2 + +0 * dtype_obj +2 * dtype_obj + +void_dtype_obj["f0"] +void_dtype_obj[0] +void_dtype_obj[["f0", "f1"]] +void_dtype_obj[["f0"]] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/einsumfunc.py 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/einsumfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..429764e67eccc7855d363da20d432fdb45e66971 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/einsumfunc.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +AR_LIKE_b = [True, True, True] +AR_LIKE_u = [np.uint32(1), np.uint32(2), np.uint32(3)] +AR_LIKE_i = [1, 2, 3] +AR_LIKE_f = [1.0, 2.0, 3.0] +AR_LIKE_c = [1j, 2j, 3j] +AR_LIKE_U = ["1", "2", "3"] + +OUT_f: np.ndarray[Any, np.dtype[np.float64]] = np.empty(3, dtype=np.float64) +OUT_c: np.ndarray[Any, np.dtype[np.complex128]] = np.empty(3, dtype=np.complex128) + +np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b) +np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u) +np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i) +np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f) +np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c) +np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i) +np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c) + +np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16") +np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe") +np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, out=OUT_c) +np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=int, casting="unsafe", out=OUT_f) + +np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b) +np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u) +np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i) +np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f) +np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c) +np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i) +np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/flatiter.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/flatiter.py new file mode 100644 index 0000000000000000000000000000000000000000..e64e4261b8e70587c872f25875fadfb765120986 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/flatiter.py @@ -0,0 +1,19 @@ +import numpy as np + +a = np.empty((2, 2)).flat + +a.base +a.copy() +a.coords +a.index +iter(a) +next(a) +a[0] +a[[0, 1, 2]] +a[...] 
+a[:] +a.__array__() +a.__array__(np.dtype(np.float64)) + +b = np.array([1]).flat +a[b] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/fromnumeric.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/fromnumeric.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc2bcfd8b50f235081e737c3420e1db34191637 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/fromnumeric.py @@ -0,0 +1,272 @@ +"""Tests for :mod:`numpy._core.fromnumeric`.""" + +import numpy as np + +A = np.array(True, ndmin=2, dtype=bool) +B = np.array(1.0, ndmin=2, dtype=np.float32) +A.setflags(write=False) +B.setflags(write=False) + +a = np.bool(True) +b = np.float32(1.0) +c = 1.0 +d = np.array(1.0, dtype=np.float32) # writeable + +np.take(a, 0) +np.take(b, 0) +np.take(c, 0) +np.take(A, 0) +np.take(B, 0) +np.take(A, [0]) +np.take(B, [0]) + +np.reshape(a, 1) +np.reshape(b, 1) +np.reshape(c, 1) +np.reshape(A, 1) +np.reshape(B, 1) + +np.choose(a, [True, True]) +np.choose(A, [1.0, 1.0]) + +np.repeat(a, 1) +np.repeat(b, 1) +np.repeat(c, 1) +np.repeat(A, 1) +np.repeat(B, 1) + +np.swapaxes(A, 0, 0) +np.swapaxes(B, 0, 0) + +np.transpose(a) +np.transpose(b) +np.transpose(c) +np.transpose(A) +np.transpose(B) + +np.partition(a, 0, axis=None) +np.partition(b, 0, axis=None) +np.partition(c, 0, axis=None) +np.partition(A, 0) +np.partition(B, 0) + +np.argpartition(a, 0) +np.argpartition(b, 0) +np.argpartition(c, 0) +np.argpartition(A, 0) +np.argpartition(B, 0) + +np.sort(A, 0) +np.sort(B, 0) + +np.argsort(A, 0) +np.argsort(B, 0) + +np.argmax(A) +np.argmax(B) +np.argmax(A, axis=0) +np.argmax(B, axis=0) + +np.argmin(A) +np.argmin(B) +np.argmin(A, axis=0) +np.argmin(B, axis=0) + +np.searchsorted(A[0], 0) +np.searchsorted(B[0], 0) +np.searchsorted(A[0], [0]) +np.searchsorted(B[0], [0]) + +np.resize(a, (5, 5)) +np.resize(b, (5, 5)) +np.resize(c, (5, 5)) +np.resize(A, (5, 5)) +np.resize(B, (5, 5)) + +np.squeeze(a) +np.squeeze(b) +np.squeeze(c) +np.squeeze(A) +np.squeeze(B) + +np.diagonal(A) +np.diagonal(B) + +np.trace(A) +np.trace(B) + +np.ravel(a) +np.ravel(b) +np.ravel(c) +np.ravel(A) +np.ravel(B) + +np.nonzero(A) +np.nonzero(B) + +np.shape(a) +np.shape(b) +np.shape(c) +np.shape(A) +np.shape(B) + +np.compress([True], a) +np.compress([True], b) +np.compress([True], c) +np.compress([True], A) +np.compress([True], B) + +np.clip(a, 0, 1.0) +np.clip(b, -1, 1) +np.clip(a, 0, None) +np.clip(b, None, 1) +np.clip(c, 0, 1) +np.clip(A, 0, 1) +np.clip(B, 0, 1) +np.clip(B, [0, 1], [1, 2]) + +np.sum(a) +np.sum(b) +np.sum(c) +np.sum(A) +np.sum(B) +np.sum(A, axis=0) +np.sum(B, axis=0) + +np.all(a) +np.all(b) +np.all(c) +np.all(A) +np.all(B) +np.all(A, axis=0) +np.all(B, axis=0) +np.all(A, keepdims=True) +np.all(B, keepdims=True) + +np.any(a) +np.any(b) +np.any(c) +np.any(A) +np.any(B) +np.any(A, axis=0) +np.any(B, axis=0) +np.any(A, keepdims=True) +np.any(B, keepdims=True) + +np.cumsum(a) +np.cumsum(b) +np.cumsum(c) +np.cumsum(A) +np.cumsum(B) + +np.cumulative_sum(a) +np.cumulative_sum(b) +np.cumulative_sum(c) +np.cumulative_sum(A, axis=0) +np.cumulative_sum(B, axis=0) + +np.ptp(b) +np.ptp(c) +np.ptp(B) +np.ptp(B, axis=0) +np.ptp(B, keepdims=True) + +np.amax(a) +np.amax(b) +np.amax(c) +np.amax(A) +np.amax(B) +np.amax(A, axis=0) +np.amax(B, axis=0) +np.amax(A, keepdims=True) +np.amax(B, keepdims=True) + +np.amin(a) +np.amin(b) +np.amin(c) +np.amin(A) +np.amin(B) +np.amin(A, axis=0) +np.amin(B, axis=0) +np.amin(A, keepdims=True) +np.amin(B, keepdims=True) + +np.prod(a) 
+np.prod(b) +np.prod(c) +np.prod(A) +np.prod(B) +np.prod(a, dtype=None) +np.prod(A, dtype=None) +np.prod(A, axis=0) +np.prod(B, axis=0) +np.prod(A, keepdims=True) +np.prod(B, keepdims=True) +np.prod(b, out=d) +np.prod(B, out=d) + +np.cumprod(a) +np.cumprod(b) +np.cumprod(c) +np.cumprod(A) +np.cumprod(B) + +np.cumulative_prod(a) +np.cumulative_prod(b) +np.cumulative_prod(c) +np.cumulative_prod(A, axis=0) +np.cumulative_prod(B, axis=0) + +np.ndim(a) +np.ndim(b) +np.ndim(c) +np.ndim(A) +np.ndim(B) + +np.size(a) +np.size(b) +np.size(c) +np.size(A) +np.size(B) + +np.around(a) +np.around(b) +np.around(c) +np.around(A) +np.around(B) + +np.mean(a) +np.mean(b) +np.mean(c) +np.mean(A) +np.mean(B) +np.mean(A, axis=0) +np.mean(B, axis=0) +np.mean(A, keepdims=True) +np.mean(B, keepdims=True) +np.mean(b, out=d) +np.mean(B, out=d) + +np.std(a) +np.std(b) +np.std(c) +np.std(A) +np.std(B) +np.std(A, axis=0) +np.std(B, axis=0) +np.std(A, keepdims=True) +np.std(B, keepdims=True) +np.std(b, out=d) +np.std(B, out=d) + +np.var(a) +np.var(b) +np.var(c) +np.var(A) +np.var(B) +np.var(A, axis=0) +np.var(B, axis=0) +np.var(A, keepdims=True) +np.var(B, keepdims=True) +np.var(b, out=d) +np.var(B, out=d) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/index_tricks.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/index_tricks.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc4ff2f314aa2417cd5205a86386d4bd6211273 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/index_tricks.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from typing import Any +import numpy as np + +AR_LIKE_b = [[True, True], [True, True]] +AR_LIKE_i = [[1, 2], [3, 4]] +AR_LIKE_f = [[1.0, 2.0], [3.0, 4.0]] +AR_LIKE_U = [["1", "2"], ["3", "4"]] + +AR_i8: np.ndarray[Any, np.dtype[np.int64]] = np.array(AR_LIKE_i, dtype=np.int64) + +np.ndenumerate(AR_i8) +np.ndenumerate(AR_LIKE_f) +np.ndenumerate(AR_LIKE_U) + +next(np.ndenumerate(AR_i8)) +next(np.ndenumerate(AR_LIKE_f)) +next(np.ndenumerate(AR_LIKE_U)) + +iter(np.ndenumerate(AR_i8)) +iter(np.ndenumerate(AR_LIKE_f)) +iter(np.ndenumerate(AR_LIKE_U)) + +iter(np.ndindex(1, 2, 3)) +next(np.ndindex(1, 2, 3)) + +np.unravel_index([22, 41, 37], (7, 6)) +np.unravel_index([31, 41, 13], (7, 6), order='F') +np.unravel_index(1621, (6, 7, 8, 9)) + +np.ravel_multi_index(AR_LIKE_i, (7, 6)) +np.ravel_multi_index(AR_LIKE_i, (7, 6), order='F') +np.ravel_multi_index(AR_LIKE_i, (4, 6), mode='clip') +np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=('clip', 'wrap')) +np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)) + +np.mgrid[1:1:2] +np.mgrid[1:1:2, None:10] + +np.ogrid[1:1:2] +np.ogrid[1:1:2, None:10] + +np.index_exp[0:1] +np.index_exp[0:1, None:3] +np.index_exp[0, 0:1, ..., [0, 1, 3]] + +np.s_[0:1] +np.s_[0:1, None:3] +np.s_[0, 0:1, ..., [0, 1, 3]] + +np.ix_(AR_LIKE_b[0]) +np.ix_(AR_LIKE_i[0], AR_LIKE_f[0]) +np.ix_(AR_i8[0]) + +np.fill_diagonal(AR_i8, 5) + +np.diag_indices(4) +np.diag_indices(2, 3) + +np.diag_indices_from(AR_i8) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_user_array.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_user_array.py new file mode 100644 index 0000000000000000000000000000000000000000..62b7e85d7ff1e49fecbbe88a921c931a5b8ae745 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_user_array.py @@ -0,0 +1,22 @@ +"""Based on the `if __name__ == "__main__"` test code in `lib/_user_array_impl.py`.""" + 
+from __future__ import annotations + +import numpy as np +from numpy.lib.user_array import container + +N = 10_000 +W = H = int(N**0.5) + +a: np.ndarray[tuple[int, int], np.dtype[np.int32]] +ua: container[tuple[int, int], np.dtype[np.int32]] + +a = np.arange(N, dtype=np.int32).reshape(W, H) +ua = container(a) + +ua_small: container[tuple[int, int], np.dtype[np.int32]] = ua[:3, :5] +ua_small[0, 0] = 10 + +ua_bool: container[tuple[int, int], np.dtype[np.bool]] = ua_small > 1 + +# shape: tuple[int, int] = np.shape(ua) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_utils.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b3381e13d26adf44130fc62aa8bfaacfe6a658 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_utils.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from io import StringIO + +import numpy as np +import numpy.lib.array_utils as array_utils + +FILE = StringIO() +AR = np.arange(10, dtype=np.float64) + + +def func(a: int) -> bool: + return True + + +array_utils.byte_bounds(AR) +array_utils.byte_bounds(np.float64()) + +np.info(1, output=FILE) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_version.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_version.py new file mode 100644 index 0000000000000000000000000000000000000000..f3825eca524795f4d9742873a773cd7749636e2a --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/lib_version.py @@ -0,0 +1,18 @@ +from numpy.lib import NumpyVersion + +version = NumpyVersion("1.8.0") + +version.vstring +version.version +version.major +version.minor +version.bugfix +version.pre_release +version.is_devversion + +version == version +version != version +version < "1.8.0" +version <= version +version > version +version >= "1.8.0" diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/literal.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/literal.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fa476210e31968bbdd5106dc394e5b933e072a --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/literal.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from functools import partial + +import pytest +import numpy as np + +if TYPE_CHECKING: + from collections.abc import Callable + +AR = np.array(0) +AR.setflags(write=False) + +KACF = frozenset({None, "K", "A", "C", "F"}) +ACF = frozenset({None, "A", "C", "F"}) +CF = frozenset({None, "C", "F"}) + +order_list: list[tuple[frozenset[str | None], Callable[..., Any]]] = [ + (KACF, AR.tobytes), + (KACF, partial(AR.astype, int)), + (KACF, AR.copy), + (ACF, partial(AR.reshape, 1)), + (KACF, AR.flatten), + (KACF, AR.ravel), + (KACF, partial(np.array, 1)), + # NOTE: __call__ is needed due to mypy bugs (#17620, #17631) + (KACF, partial(np.ndarray.__call__, 1)), + (CF, partial(np.zeros.__call__, 1)), + (CF, partial(np.ones.__call__, 1)), + (CF, partial(np.empty.__call__, 1)), + (CF, partial(np.full, 1, 1)), + (KACF, partial(np.zeros_like, AR)), + (KACF, partial(np.ones_like, AR)), + (KACF, partial(np.empty_like, AR)), + (KACF, partial(np.full_like, AR, 1)), + (KACF, partial(np.add.__call__, 1, 1)), # i.e. 
np.ufunc.__call__ + (ACF, partial(np.reshape, AR, 1)), + (KACF, partial(np.ravel, AR)), + (KACF, partial(np.asarray, 1)), + (KACF, partial(np.asanyarray, 1)), +] + +for order_set, func in order_list: + for order in order_set: + func(order=order) + + invalid_orders = KACF - order_set + for order in invalid_orders: + with pytest.raises(ValueError): + func(order=order) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ma.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ma.py new file mode 100644 index 0000000000000000000000000000000000000000..e7915a583210aa43d96db0eb68fa83f3ead152ac --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ma.py @@ -0,0 +1,174 @@ +from typing import Any, TypeAlias, TypeVar, cast + +import numpy as np +import numpy.typing as npt +from numpy._typing import _Shape + +_ScalarT = TypeVar("_ScalarT", bound=np.generic) +MaskedArray: TypeAlias = np.ma.MaskedArray[_Shape, np.dtype[_ScalarT]] + +MAR_b: MaskedArray[np.bool] = np.ma.MaskedArray([True]) +MAR_u: MaskedArray[np.uint32] = np.ma.MaskedArray([1], dtype=np.uint32) +MAR_i: MaskedArray[np.int64] = np.ma.MaskedArray([1]) +MAR_f: MaskedArray[np.float64] = np.ma.MaskedArray([1.0]) +MAR_c: MaskedArray[np.complex128] = np.ma.MaskedArray([1j]) +MAR_td64: MaskedArray[np.timedelta64] = np.ma.MaskedArray([np.timedelta64(1, "D")]) +MAR_M_dt64: MaskedArray[np.datetime64] = np.ma.MaskedArray([np.datetime64(1, "D")]) +MAR_S: MaskedArray[np.bytes_] = np.ma.MaskedArray([b'foo'], dtype=np.bytes_) +MAR_U: MaskedArray[np.str_] = np.ma.MaskedArray(['foo'], dtype=np.str_) +MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType], + np.ma.MaskedArray(["a"], dtype="T")) + +AR_b: npt.NDArray[np.bool] = np.array([True, False, True]) + +AR_LIKE_b = [True] +AR_LIKE_u = [np.uint32(1)] +AR_LIKE_i = [1] +AR_LIKE_f = [1.0] +AR_LIKE_c = [1j] +AR_LIKE_m = [np.timedelta64(1, "D")] +AR_LIKE_M = [np.datetime64(1, "D")] + +MAR_f.mask = AR_b +MAR_f.mask = np.False_ + +# Inplace addition + +MAR_b += AR_LIKE_b + +MAR_u += AR_LIKE_b +MAR_u += AR_LIKE_u + +MAR_i += AR_LIKE_b +MAR_i += 2 +MAR_i += AR_LIKE_i + +MAR_f += AR_LIKE_b +MAR_f += 2 +MAR_f += AR_LIKE_u +MAR_f += AR_LIKE_i +MAR_f += AR_LIKE_f + +MAR_c += AR_LIKE_b +MAR_c += AR_LIKE_u +MAR_c += AR_LIKE_i +MAR_c += AR_LIKE_f +MAR_c += AR_LIKE_c + +MAR_td64 += AR_LIKE_b +MAR_td64 += AR_LIKE_u +MAR_td64 += AR_LIKE_i +MAR_td64 += AR_LIKE_m +MAR_M_dt64 += AR_LIKE_b +MAR_M_dt64 += AR_LIKE_u +MAR_M_dt64 += AR_LIKE_i +MAR_M_dt64 += AR_LIKE_m + +MAR_S += b'snakes' +MAR_U += 'snakes' +MAR_T += 'snakes' + +# Inplace subtraction + +MAR_u -= AR_LIKE_b +MAR_u -= AR_LIKE_u + +MAR_i -= AR_LIKE_b +MAR_i -= AR_LIKE_i + +MAR_f -= AR_LIKE_b +MAR_f -= AR_LIKE_u +MAR_f -= AR_LIKE_i +MAR_f -= AR_LIKE_f + +MAR_c -= AR_LIKE_b +MAR_c -= AR_LIKE_u +MAR_c -= AR_LIKE_i +MAR_c -= AR_LIKE_f +MAR_c -= AR_LIKE_c + +MAR_td64 -= AR_LIKE_b +MAR_td64 -= AR_LIKE_u +MAR_td64 -= AR_LIKE_i +MAR_td64 -= AR_LIKE_m +MAR_M_dt64 -= AR_LIKE_b +MAR_M_dt64 -= AR_LIKE_u +MAR_M_dt64 -= AR_LIKE_i +MAR_M_dt64 -= AR_LIKE_m + +# Inplace floor division + +MAR_f //= AR_LIKE_b +MAR_f //= 2 +MAR_f //= AR_LIKE_u +MAR_f //= AR_LIKE_i +MAR_f //= AR_LIKE_f + +MAR_td64 //= AR_LIKE_i + +# Inplace true division + +MAR_f /= AR_LIKE_b +MAR_f /= 2 +MAR_f /= AR_LIKE_u +MAR_f /= AR_LIKE_i +MAR_f /= AR_LIKE_f + +MAR_c /= AR_LIKE_b +MAR_c /= AR_LIKE_u +MAR_c /= AR_LIKE_i +MAR_c /= AR_LIKE_f +MAR_c /= AR_LIKE_c + +MAR_td64 /= AR_LIKE_i + +# Inplace multiplication + +MAR_b *= AR_LIKE_b + +MAR_u *= 
AR_LIKE_b +MAR_u *= AR_LIKE_u + +MAR_i *= AR_LIKE_b +MAR_i *= 2 +MAR_i *= AR_LIKE_i + +MAR_f *= AR_LIKE_b +MAR_f *= 2 +MAR_f *= AR_LIKE_u +MAR_f *= AR_LIKE_i +MAR_f *= AR_LIKE_f + +MAR_c *= AR_LIKE_b +MAR_c *= AR_LIKE_u +MAR_c *= AR_LIKE_i +MAR_c *= AR_LIKE_f +MAR_c *= AR_LIKE_c + +MAR_td64 *= AR_LIKE_b +MAR_td64 *= AR_LIKE_u +MAR_td64 *= AR_LIKE_i +MAR_td64 *= AR_LIKE_f + +MAR_S *= 2 +MAR_U *= 2 +MAR_T *= 2 + +# Inplace power + +MAR_u **= AR_LIKE_b +MAR_u **= AR_LIKE_u + +MAR_i **= AR_LIKE_b +MAR_i **= AR_LIKE_i + +MAR_f **= AR_LIKE_b +MAR_f **= AR_LIKE_u +MAR_f **= AR_LIKE_i +MAR_f **= AR_LIKE_f + +MAR_c **= AR_LIKE_b +MAR_c **= AR_LIKE_u +MAR_c **= AR_LIKE_i +MAR_c **= AR_LIKE_f +MAR_c **= AR_LIKE_c diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/mod.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/mod.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7e6cd85c73cfce452c3bd8bf91f174f134966c --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/mod.py @@ -0,0 +1,149 @@ +import numpy as np + +f8 = np.float64(1) +i8 = np.int64(1) +u8 = np.uint64(1) + +f4 = np.float32(1) +i4 = np.int32(1) +u4 = np.uint32(1) + +td = np.timedelta64(1, "D") +b_ = np.bool(1) + +b = bool(1) +f = float(1) +i = int(1) + +AR = np.array([1], dtype=np.bool) +AR.setflags(write=False) + +AR2 = np.array([1], dtype=np.timedelta64) +AR2.setflags(write=False) + +# Time structures + +td % td +td % AR2 +AR2 % td + +divmod(td, td) +divmod(td, AR2) +divmod(AR2, td) + +# Bool + +b_ % b +b_ % i +b_ % f +b_ % b_ +b_ % i8 +b_ % u8 +b_ % f8 +b_ % AR + +divmod(b_, b) +divmod(b_, i) +divmod(b_, f) +divmod(b_, b_) +divmod(b_, i8) +divmod(b_, u8) +divmod(b_, f8) +divmod(b_, AR) + +b % b_ +i % b_ +f % b_ +b_ % b_ +i8 % b_ +u8 % b_ +f8 % b_ +AR % b_ + +divmod(b, b_) +divmod(i, b_) +divmod(f, b_) +divmod(b_, b_) +divmod(i8, b_) +divmod(u8, b_) +divmod(f8, b_) +divmod(AR, b_) + +# int + +i8 % b +i8 % i +i8 % f +i8 % i8 +i8 % f8 +i4 % i8 +i4 % f8 +i4 % i4 +i4 % f4 +i8 % AR + +divmod(i8, b) +divmod(i8, i) +divmod(i8, f) +divmod(i8, i8) +divmod(i8, f8) +divmod(i8, i4) +divmod(i8, f4) +divmod(i4, i4) +divmod(i4, f4) +divmod(i8, AR) + +b % i8 +i % i8 +f % i8 +i8 % i8 +f8 % i8 +i8 % i4 +f8 % i4 +i4 % i4 +f4 % i4 +AR % i8 + +divmod(b, i8) +divmod(i, i8) +divmod(f, i8) +divmod(i8, i8) +divmod(f8, i8) +divmod(i4, i8) +divmod(f4, i8) +divmod(i4, i4) +divmod(f4, i4) +divmod(AR, i8) + +# float + +f8 % b +f8 % i +f8 % f +i8 % f4 +f4 % f4 +f8 % AR + +divmod(f8, b) +divmod(f8, i) +divmod(f8, f) +divmod(f8, f8) +divmod(f8, f4) +divmod(f4, f4) +divmod(f8, AR) + +b % f8 +i % f8 +f % f8 +f8 % f8 +f8 % f8 +f4 % f4 +AR % f8 + +divmod(b, f8) +divmod(i, f8) +divmod(f, f8) +divmod(f8, f8) +divmod(f4, f8) +divmod(f4, f4) +divmod(AR, f8) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/modules.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2fd4b7e7c3a2d0ed1be39928edf5b2d591e16c --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/modules.py @@ -0,0 +1,45 @@ +import numpy as np +from numpy import f2py + +np.char +np.ctypeslib +np.emath +np.fft +np.lib +np.linalg +np.ma +np.matrixlib +np.polynomial +np.random +np.rec +np.strings +np.testing +np.version + +np.lib.format +np.lib.mixins +np.lib.scimath +np.lib.stride_tricks +np.lib.array_utils +np.ma.extras +np.polynomial.chebyshev +np.polynomial.hermite 
+np.polynomial.hermite_e +np.polynomial.laguerre +np.polynomial.legendre +np.polynomial.polynomial + +np.__path__ +np.__version__ + +np.__all__ +np.char.__all__ +np.ctypeslib.__all__ +np.emath.__all__ +np.lib.__all__ +np.ma.__all__ +np.random.__all__ +np.rec.__all__ +np.strings.__all__ +np.testing.__all__ +f2py.__all__ diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/multiarray.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/multiarray.py new file mode 100644 index 0000000000000000000000000000000000000000..26cedfd77566e7b7865345e0775af88153e74ffc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/multiarray.py @@ -0,0 +1,76 @@ +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] = np.array([1.0]) +AR_i4 = np.array([1], dtype=np.int32) +AR_u1 = np.array([1], dtype=np.uint8) + +AR_LIKE_f = [1.5] +AR_LIKE_i = [1] + +b_f8 = np.broadcast(AR_f8) +b_i4_f8_f8 = np.broadcast(AR_i4, AR_f8, AR_f8) + +next(b_f8) +b_f8.reset() +b_f8.index +b_f8.iters +b_f8.nd +b_f8.ndim +b_f8.numiter +b_f8.shape +b_f8.size + +next(b_i4_f8_f8) +b_i4_f8_f8.reset() +b_i4_f8_f8.ndim +b_i4_f8_f8.index +b_i4_f8_f8.iters +b_i4_f8_f8.nd +b_i4_f8_f8.numiter +b_i4_f8_f8.shape +b_i4_f8_f8.size + +np.inner(AR_f8, AR_i4) + +np.where([True, True, False]) +np.where([True, True, False], 1, 0) + +np.lexsort([0, 1, 2]) + +np.can_cast(np.dtype("i8"), int) +np.can_cast(AR_f8, "f8") +np.can_cast(AR_f8, np.complex128, casting="unsafe") + +np.min_scalar_type([1]) +np.min_scalar_type(AR_f8) + +np.result_type(int, AR_i4) +np.result_type(AR_f8, AR_u1) +np.result_type(AR_f8, np.complex128) + +np.dot(AR_LIKE_f, AR_i4) +np.dot(AR_u1, 1) +np.dot(1.5j, 1) +np.dot(AR_u1, 1, out=AR_f8) + +np.vdot(AR_LIKE_f, AR_i4) +np.vdot(AR_u1, 1) +np.vdot(1.5j, 1) + +np.bincount(AR_i4) + +np.copyto(AR_f8, [1.6]) + +np.putmask(AR_f8, [True], 1.5) + +np.packbits(AR_i4) +np.packbits(AR_u1) + +np.unpackbits(AR_u1) + +np.shares_memory(1, 2) +np.shares_memory(AR_f8, AR_f8, max_work=1) + +np.may_share_memory(1, 2) +np.may_share_memory(AR_f8, AR_f8, max_work=1) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_conversion.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..76da1dadd327861fdb576f1f492fbc44a4881168 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_conversion.py @@ -0,0 +1,87 @@ +import os +import tempfile + +import numpy as np + +nd = np.array([[1, 2], [3, 4]]) +scalar_array = np.array(1) + +# item +scalar_array.item() +nd.item(1) +nd.item(0, 1) +nd.item((0, 1)) + +# tobytes +nd.tobytes() +nd.tobytes("C") +nd.tobytes(None) + +# tofile +if os.name != "nt": + with tempfile.NamedTemporaryFile(suffix=".txt") as tmp: + nd.tofile(tmp.name) + nd.tofile(tmp.name, "") + nd.tofile(tmp.name, sep="") + + nd.tofile(tmp.name, "", "%s") + nd.tofile(tmp.name, format="%s") + + nd.tofile(tmp) + +# dump is pretty simple +# dumps is pretty simple + +# astype +nd.astype("float") +nd.astype(float) + +nd.astype(float, "K") +nd.astype(float, order="K") + +nd.astype(float, "K", "unsafe") +nd.astype(float, casting="unsafe") + +nd.astype(float, "K", "unsafe", True) +nd.astype(float, subok=True) + +nd.astype(float, "K", "unsafe", True, True) +nd.astype(float, copy=True) + +# byteswap +nd.byteswap() +nd.byteswap(True) + +# copy +nd.copy() +nd.copy("C") + +# view +nd.view() +nd.view(np.int64) 
+nd.view(dtype=np.int64) +nd.view(np.int64, np.matrix) +nd.view(type=np.matrix) + +# getfield +complex_array = np.array([[1 + 1j, 0], [0, 1 - 1j]], dtype=np.complex128) + +complex_array.getfield("float") +complex_array.getfield(float) + +complex_array.getfield("float", 8) +complex_array.getfield(float, offset=8) + +# setflags +nd.setflags() + +nd.setflags(True) +nd.setflags(write=True) + +nd.setflags(True, True) +nd.setflags(write=True, align=True) + +nd.setflags(True, True, False) +nd.setflags(write=True, align=True, uic=False) + +# fill is pretty simple diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_misc.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7df182a49a2ec09eb55767141d50968bec31f2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_misc.py @@ -0,0 +1,203 @@ +""" +Tests for miscellaneous (non-magic) ``np.ndarray``/``np.generic`` methods. + +More extensive tests are performed for the methods' +function-based counterpart in `../from_numeric.py`. + +""" + +from __future__ import annotations + +import operator +from typing import cast, Any + +import numpy as np +import numpy.typing as npt + +class SubClass(npt.NDArray[np.float64]): ... +class IntSubClass(npt.NDArray[np.intp]): ... + +i4 = np.int32(1) +A: np.ndarray[Any, np.dtype[np.int32]] = np.array([[1]], dtype=np.int32) +B0 = np.empty((), dtype=np.int32).view(SubClass) +B1 = np.empty((1,), dtype=np.int32).view(SubClass) +B2 = np.empty((1, 1), dtype=np.int32).view(SubClass) +B_int0: IntSubClass = np.empty((), dtype=np.intp).view(IntSubClass) +C: np.ndarray[Any, np.dtype[np.int32]] = np.array([0, 1, 2], dtype=np.int32) +D = np.ones(3).view(SubClass) + +ctypes_obj = A.ctypes + +i4.all() +A.all() +A.all(axis=0) +A.all(keepdims=True) +A.all(out=B0) + +i4.any() +A.any() +A.any(axis=0) +A.any(keepdims=True) +A.any(out=B0) + +i4.argmax() +A.argmax() +A.argmax(axis=0) +A.argmax(out=B_int0) + +i4.argmin() +A.argmin() +A.argmin(axis=0) +A.argmin(out=B_int0) + +i4.argsort() +i4.argsort(stable=True) +A.argsort() +A.argsort(stable=True) + +A.sort() +A.sort(stable=True) + +i4.choose([()]) +_choices = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=np.int32) +C.choose(_choices) +C.choose(_choices, out=D) + +i4.clip(1) +A.clip(1) +A.clip(None, 1) +A.clip(1, out=B2) +A.clip(None, 1, out=B2) + +i4.compress([1]) +A.compress([1]) +A.compress([1], out=B1) + +i4.conj() +A.conj() +B0.conj() + +i4.conjugate() +A.conjugate() +B0.conjugate() + +i4.cumprod() +A.cumprod() +A.cumprod(out=B1) + +i4.cumsum() +A.cumsum() +A.cumsum(out=B1) + +i4.max() +A.max() +A.max(axis=0) +A.max(keepdims=True) +A.max(out=B0) + +i4.mean() +A.mean() +A.mean(axis=0) +A.mean(keepdims=True) +A.mean(out=B0) + +i4.min() +A.min() +A.min(axis=0) +A.min(keepdims=True) +A.min(out=B0) + +i4.prod() +A.prod() +A.prod(axis=0) +A.prod(keepdims=True) +A.prod(out=B0) + +i4.round() +A.round() +A.round(out=B2) + +i4.repeat(1) +A.repeat(1) +B0.repeat(1) + +i4.std() +A.std() +A.std(axis=0) +A.std(keepdims=True) +A.std(out=B0.astype(np.float64)) + +i4.sum() +A.sum() +A.sum(axis=0) +A.sum(keepdims=True) +A.sum(out=B0) + +i4.take(0) +A.take(0) +A.take([0]) +A.take(0, out=B0) +A.take([0], out=B1) + +i4.var() +A.var() +A.var(axis=0) +A.var(keepdims=True) +A.var(out=B0) + +A.argpartition([0]) + +A.diagonal() + +A.dot(1) +A.dot(1, out=B2) + +A.nonzero() + +C.searchsorted(1) + +A.trace() +A.trace(out=B0) + +void = 
cast(np.void, np.array(1, dtype=[("f", np.float64)]).take(0)) +void.setfield(10, np.float64) + +A.item(0) +C.item(0) + +A.ravel() +C.ravel() + +A.flatten() +C.flatten() + +A.reshape(1) +C.reshape(3) + +int(np.array(1.0, dtype=np.float64)) +int(np.array("1", dtype=np.str_)) + +float(np.array(1.0, dtype=np.float64)) +float(np.array("1", dtype=np.str_)) + +complex(np.array(1.0, dtype=np.float64)) + +operator.index(np.array(1, dtype=np.int64)) + +# this fails on numpy 2.2.1 +# https://github.com/scipy/scipy/blob/a755ee77ec47a64849abe42c349936475a6c2f24/scipy/io/arff/tests/test_arffread.py#L41-L44 +A_float = np.array([[1, 5], [2, 4], [np.nan, np.nan]]) +A_void: npt.NDArray[np.void] = np.empty(3, [("yop", float), ("yap", float)]) +A_void["yop"] = A_float[:, 0] +A_void["yap"] = A_float[:, 1] + +# deprecated + +with np.testing.assert_warns(DeprecationWarning): + ctypes_obj.get_data() # type: ignore[deprecated] # pyright: ignore[reportDeprecated] +with np.testing.assert_warns(DeprecationWarning): + ctypes_obj.get_shape() # type: ignore[deprecated] # pyright: ignore[reportDeprecated] +with np.testing.assert_warns(DeprecationWarning): + ctypes_obj.get_strides() # type: ignore[deprecated] # pyright: ignore[reportDeprecated] +with np.testing.assert_warns(DeprecationWarning): + ctypes_obj.get_as_parameter() # type: ignore[deprecated] # pyright: ignore[reportDeprecated] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_shape_manipulation.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_shape_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca3dff392e12a4122740ade2ceb71b3d6e7bb08 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ndarray_shape_manipulation.py @@ -0,0 +1,47 @@ +import numpy as np + +nd1 = np.array([[1, 2], [3, 4]]) + +# reshape +nd1.reshape(4) +nd1.reshape(2, 2) +nd1.reshape((2, 2)) + +nd1.reshape((2, 2), order="C") +nd1.reshape(4, order="C") + +# resize +nd1.resize() +nd1.resize(4) +nd1.resize(2, 2) +nd1.resize((2, 2)) + +nd1.resize((2, 2), refcheck=True) +nd1.resize(4, refcheck=True) + +nd2 = np.array([[1, 2], [3, 4]]) + +# transpose +nd2.transpose() +nd2.transpose(1, 0) +nd2.transpose((1, 0)) + +# swapaxes +nd2.swapaxes(0, 1) + +# flatten +nd2.flatten() +nd2.flatten("C") + +# ravel +nd2.ravel() +nd2.ravel("C") + +# squeeze +nd2.squeeze() + +nd3 = np.array([[1, 2]]) +nd3.squeeze(0) + +nd4 = np.array([[[1, 2]]]) +nd4.squeeze((0, 1)) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/nditer.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/nditer.py new file mode 100644 index 0000000000000000000000000000000000000000..25a5b44d7aecf6486be76ae0dca0e3e8ef60699b --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/nditer.py @@ -0,0 +1,4 @@ +import numpy as np + +arr = np.array([1]) +np.nditer([arr, None]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numeric.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb14cf3a2a2e363df36ca641b3a89733f827991 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numeric.py @@ -0,0 +1,95 @@ +""" +Tests for :mod:`numpy._core.numeric`. + +Does not include tests which fall under ``array_constructors``. 
+ +""" + +from __future__ import annotations +from typing import cast + +import numpy as np +import numpy.typing as npt + +class SubClass(npt.NDArray[np.float64]): ... + + +i8 = np.int64(1) + +A = cast( + np.ndarray[tuple[int, int, int], np.dtype[np.intp]], + np.arange(27).reshape(3, 3, 3), +) +B: list[list[list[int]]] = A.tolist() +C = np.empty((27, 27)).view(SubClass) + +np.count_nonzero(i8) +np.count_nonzero(A) +np.count_nonzero(B) +np.count_nonzero(A, keepdims=True) +np.count_nonzero(A, axis=0) + +np.isfortran(i8) +np.isfortran(A) + +np.argwhere(i8) +np.argwhere(A) + +np.flatnonzero(i8) +np.flatnonzero(A) + +np.correlate(B[0][0], A.ravel(), mode="valid") +np.correlate(A.ravel(), A.ravel(), mode="same") + +np.convolve(B[0][0], A.ravel(), mode="valid") +np.convolve(A.ravel(), A.ravel(), mode="same") + +np.outer(i8, A) +np.outer(B, A) +np.outer(A, A) +np.outer(A, A, out=C) + +np.tensordot(B, A) +np.tensordot(A, A) +np.tensordot(A, A, axes=0) +np.tensordot(A, A, axes=(0, 1)) + +np.isscalar(i8) +np.isscalar(A) +np.isscalar(B) + +np.roll(A, 1) +np.roll(A, (1, 2)) +np.roll(B, 1) + +np.rollaxis(A, 0, 1) + +np.moveaxis(A, 0, 1) +np.moveaxis(A, (0, 1), (1, 2)) + +np.cross(B, A) +np.cross(A, A) + +np.indices([0, 1, 2]) +np.indices([0, 1, 2], sparse=False) +np.indices([0, 1, 2], sparse=True) + +np.binary_repr(1) + +np.base_repr(1) + +np.allclose(i8, A) +np.allclose(B, A) +np.allclose(A, A) + +np.isclose(i8, A) +np.isclose(B, A) +np.isclose(A, A) + +np.array_equal(i8, A) +np.array_equal(B, A) +np.array_equal(A, A) + +np.array_equiv(i8, A) +np.array_equiv(B, A) +np.array_equiv(A, A) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numerictypes.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numerictypes.py new file mode 100644 index 0000000000000000000000000000000000000000..24e1a9986d88ab320bdfab54a050ed27aec4d8b5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/numerictypes.py @@ -0,0 +1,17 @@ +import numpy as np + +np.isdtype(np.float64, (np.int64, np.float64)) +np.isdtype(np.int64, "signed integer") + +np.issubdtype("S1", np.bytes_) +np.issubdtype(np.float64, np.float32) + +np.ScalarType +np.ScalarType[0] +np.ScalarType[3] +np.ScalarType[8] +np.ScalarType[10] + +np.typecodes["Character"] +np.typecodes["Complex"] +np.typecodes["All"] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/random.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/random.py new file mode 100644 index 0000000000000000000000000000000000000000..bce204a7378e8a8c8761d02a08686206f59ca6ce --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/random.py @@ -0,0 +1,1497 @@ +from __future__ import annotations + +from typing import Any +import numpy as np + +SEED_NONE = None +SEED_INT = 4579435749574957634658964293569 +SEED_ARR: np.ndarray[Any, np.dtype[np.int64]] = np.array([1, 2, 3, 4], dtype=np.int64) +SEED_ARRLIKE: list[int] = [1, 2, 3, 4] +SEED_SEED_SEQ: np.random.SeedSequence = np.random.SeedSequence(0) +SEED_MT19937: np.random.MT19937 = np.random.MT19937(0) +SEED_PCG64: np.random.PCG64 = np.random.PCG64(0) +SEED_PHILOX: np.random.Philox = np.random.Philox(0) +SEED_SFC64: np.random.SFC64 = np.random.SFC64(0) + +# default rng +np.random.default_rng() +np.random.default_rng(SEED_NONE) +np.random.default_rng(SEED_INT) +np.random.default_rng(SEED_ARR) +np.random.default_rng(SEED_ARRLIKE) +np.random.default_rng(SEED_SEED_SEQ) +np.random.default_rng(SEED_MT19937) 
+np.random.default_rng(SEED_PCG64) +np.random.default_rng(SEED_PHILOX) +np.random.default_rng(SEED_SFC64) + +# Seed Sequence +np.random.SeedSequence(SEED_NONE) +np.random.SeedSequence(SEED_INT) +np.random.SeedSequence(SEED_ARR) +np.random.SeedSequence(SEED_ARRLIKE) + +# Bit Generators +np.random.MT19937(SEED_NONE) +np.random.MT19937(SEED_INT) +np.random.MT19937(SEED_ARR) +np.random.MT19937(SEED_ARRLIKE) +np.random.MT19937(SEED_SEED_SEQ) + +np.random.PCG64(SEED_NONE) +np.random.PCG64(SEED_INT) +np.random.PCG64(SEED_ARR) +np.random.PCG64(SEED_ARRLIKE) +np.random.PCG64(SEED_SEED_SEQ) + +np.random.Philox(SEED_NONE) +np.random.Philox(SEED_INT) +np.random.Philox(SEED_ARR) +np.random.Philox(SEED_ARRLIKE) +np.random.Philox(SEED_SEED_SEQ) + +np.random.SFC64(SEED_NONE) +np.random.SFC64(SEED_INT) +np.random.SFC64(SEED_ARR) +np.random.SFC64(SEED_ARRLIKE) +np.random.SFC64(SEED_SEED_SEQ) + +seed_seq: np.random.bit_generator.SeedSequence = np.random.SeedSequence(SEED_NONE) +seed_seq.spawn(10) +seed_seq.generate_state(3) +seed_seq.generate_state(3, "u4") +seed_seq.generate_state(3, "uint32") +seed_seq.generate_state(3, "u8") +seed_seq.generate_state(3, "uint64") +seed_seq.generate_state(3, np.uint32) +seed_seq.generate_state(3, np.uint64) + + +def_gen: np.random.Generator = np.random.default_rng() + +D_arr_0p1: np.ndarray[Any, np.dtype[np.float64]] = np.array([0.1]) +D_arr_0p5: np.ndarray[Any, np.dtype[np.float64]] = np.array([0.5]) +D_arr_0p9: np.ndarray[Any, np.dtype[np.float64]] = np.array([0.9]) +D_arr_1p5: np.ndarray[Any, np.dtype[np.float64]] = np.array([1.5]) +I_arr_10: np.ndarray[Any, np.dtype[np.int_]] = np.array([10], dtype=np.int_) +I_arr_20: np.ndarray[Any, np.dtype[np.int_]] = np.array([20], dtype=np.int_) +D_arr_like_0p1: list[float] = [0.1] +D_arr_like_0p5: list[float] = [0.5] +D_arr_like_0p9: list[float] = [0.9] +D_arr_like_1p5: list[float] = [1.5] +I_arr_like_10: list[int] = [10] +I_arr_like_20: list[int] = [20] +D_2D_like: list[list[float]] = [[1, 2], [2, 3], [3, 4], [4, 5.1]] +D_2D: np.ndarray[Any, np.dtype[np.float64]] = np.array(D_2D_like) + +S_out: np.ndarray[Any, np.dtype[np.float32]] = np.empty(1, dtype=np.float32) +D_out: np.ndarray[Any, np.dtype[np.float64]] = np.empty(1) + +def_gen.standard_normal() +def_gen.standard_normal(dtype=np.float32) +def_gen.standard_normal(dtype="float32") +def_gen.standard_normal(dtype="double") +def_gen.standard_normal(dtype=np.float64) +def_gen.standard_normal(size=None) +def_gen.standard_normal(size=1) +def_gen.standard_normal(size=1, dtype=np.float32) +def_gen.standard_normal(size=1, dtype="f4") +def_gen.standard_normal(size=1, dtype="float32", out=S_out) +def_gen.standard_normal(dtype=np.float32, out=S_out) +def_gen.standard_normal(size=1, dtype=np.float64) +def_gen.standard_normal(size=1, dtype="float64") +def_gen.standard_normal(size=1, dtype="f8") +def_gen.standard_normal(out=D_out) +def_gen.standard_normal(size=1, dtype="float64") +def_gen.standard_normal(size=1, dtype="float64", out=D_out) + +def_gen.random() +def_gen.random(dtype=np.float32) +def_gen.random(dtype="float32") +def_gen.random(dtype="double") +def_gen.random(dtype=np.float64) +def_gen.random(size=None) +def_gen.random(size=1) +def_gen.random(size=1, dtype=np.float32) +def_gen.random(size=1, dtype="f4") +def_gen.random(size=1, dtype="float32", out=S_out) +def_gen.random(dtype=np.float32, out=S_out) +def_gen.random(size=1, dtype=np.float64) +def_gen.random(size=1, dtype="float64") +def_gen.random(size=1, dtype="f8") +def_gen.random(out=D_out) +def_gen.random(size=1, 
dtype="float64") +def_gen.random(size=1, dtype="float64", out=D_out) + +def_gen.standard_cauchy() +def_gen.standard_cauchy(size=None) +def_gen.standard_cauchy(size=1) + +def_gen.standard_exponential() +def_gen.standard_exponential(method="inv") +def_gen.standard_exponential(dtype=np.float32) +def_gen.standard_exponential(dtype="float32") +def_gen.standard_exponential(dtype="double") +def_gen.standard_exponential(dtype=np.float64) +def_gen.standard_exponential(size=None) +def_gen.standard_exponential(size=None, method="inv") +def_gen.standard_exponential(size=1, method="inv") +def_gen.standard_exponential(size=1, dtype=np.float32) +def_gen.standard_exponential(size=1, dtype="f4", method="inv") +def_gen.standard_exponential(size=1, dtype="float32", out=S_out) +def_gen.standard_exponential(dtype=np.float32, out=S_out) +def_gen.standard_exponential(size=1, dtype=np.float64, method="inv") +def_gen.standard_exponential(size=1, dtype="float64") +def_gen.standard_exponential(size=1, dtype="f8") +def_gen.standard_exponential(out=D_out) +def_gen.standard_exponential(size=1, dtype="float64") +def_gen.standard_exponential(size=1, dtype="float64", out=D_out) + +def_gen.zipf(1.5) +def_gen.zipf(1.5, size=None) +def_gen.zipf(1.5, size=1) +def_gen.zipf(D_arr_1p5) +def_gen.zipf(D_arr_1p5, size=1) +def_gen.zipf(D_arr_like_1p5) +def_gen.zipf(D_arr_like_1p5, size=1) + +def_gen.weibull(0.5) +def_gen.weibull(0.5, size=None) +def_gen.weibull(0.5, size=1) +def_gen.weibull(D_arr_0p5) +def_gen.weibull(D_arr_0p5, size=1) +def_gen.weibull(D_arr_like_0p5) +def_gen.weibull(D_arr_like_0p5, size=1) + +def_gen.standard_t(0.5) +def_gen.standard_t(0.5, size=None) +def_gen.standard_t(0.5, size=1) +def_gen.standard_t(D_arr_0p5) +def_gen.standard_t(D_arr_0p5, size=1) +def_gen.standard_t(D_arr_like_0p5) +def_gen.standard_t(D_arr_like_0p5, size=1) + +def_gen.poisson(0.5) +def_gen.poisson(0.5, size=None) +def_gen.poisson(0.5, size=1) +def_gen.poisson(D_arr_0p5) +def_gen.poisson(D_arr_0p5, size=1) +def_gen.poisson(D_arr_like_0p5) +def_gen.poisson(D_arr_like_0p5, size=1) + +def_gen.power(0.5) +def_gen.power(0.5, size=None) +def_gen.power(0.5, size=1) +def_gen.power(D_arr_0p5) +def_gen.power(D_arr_0p5, size=1) +def_gen.power(D_arr_like_0p5) +def_gen.power(D_arr_like_0p5, size=1) + +def_gen.pareto(0.5) +def_gen.pareto(0.5, size=None) +def_gen.pareto(0.5, size=1) +def_gen.pareto(D_arr_0p5) +def_gen.pareto(D_arr_0p5, size=1) +def_gen.pareto(D_arr_like_0p5) +def_gen.pareto(D_arr_like_0p5, size=1) + +def_gen.chisquare(0.5) +def_gen.chisquare(0.5, size=None) +def_gen.chisquare(0.5, size=1) +def_gen.chisquare(D_arr_0p5) +def_gen.chisquare(D_arr_0p5, size=1) +def_gen.chisquare(D_arr_like_0p5) +def_gen.chisquare(D_arr_like_0p5, size=1) + +def_gen.exponential(0.5) +def_gen.exponential(0.5, size=None) +def_gen.exponential(0.5, size=1) +def_gen.exponential(D_arr_0p5) +def_gen.exponential(D_arr_0p5, size=1) +def_gen.exponential(D_arr_like_0p5) +def_gen.exponential(D_arr_like_0p5, size=1) + +def_gen.geometric(0.5) +def_gen.geometric(0.5, size=None) +def_gen.geometric(0.5, size=1) +def_gen.geometric(D_arr_0p5) +def_gen.geometric(D_arr_0p5, size=1) +def_gen.geometric(D_arr_like_0p5) +def_gen.geometric(D_arr_like_0p5, size=1) + +def_gen.logseries(0.5) +def_gen.logseries(0.5, size=None) +def_gen.logseries(0.5, size=1) +def_gen.logseries(D_arr_0p5) +def_gen.logseries(D_arr_0p5, size=1) +def_gen.logseries(D_arr_like_0p5) +def_gen.logseries(D_arr_like_0p5, size=1) + +def_gen.rayleigh(0.5) +def_gen.rayleigh(0.5, size=None) +def_gen.rayleigh(0.5, size=1) 
+def_gen.rayleigh(D_arr_0p5) +def_gen.rayleigh(D_arr_0p5, size=1) +def_gen.rayleigh(D_arr_like_0p5) +def_gen.rayleigh(D_arr_like_0p5, size=1) + +def_gen.standard_gamma(0.5) +def_gen.standard_gamma(0.5, size=None) +def_gen.standard_gamma(0.5, dtype="float32") +def_gen.standard_gamma(0.5, size=None, dtype="float32") +def_gen.standard_gamma(0.5, size=1) +def_gen.standard_gamma(D_arr_0p5) +def_gen.standard_gamma(D_arr_0p5, dtype="f4") +def_gen.standard_gamma(0.5, size=1, dtype="float32", out=S_out) +def_gen.standard_gamma(D_arr_0p5, dtype=np.float32, out=S_out) +def_gen.standard_gamma(D_arr_0p5, size=1) +def_gen.standard_gamma(D_arr_like_0p5) +def_gen.standard_gamma(D_arr_like_0p5, size=1) +def_gen.standard_gamma(0.5, out=D_out) +def_gen.standard_gamma(D_arr_like_0p5, out=D_out) +def_gen.standard_gamma(D_arr_like_0p5, size=1) +def_gen.standard_gamma(D_arr_like_0p5, size=1, out=D_out, dtype=np.float64) + +def_gen.vonmises(0.5, 0.5) +def_gen.vonmises(0.5, 0.5, size=None) +def_gen.vonmises(0.5, 0.5, size=1) +def_gen.vonmises(D_arr_0p5, 0.5) +def_gen.vonmises(0.5, D_arr_0p5) +def_gen.vonmises(D_arr_0p5, 0.5, size=1) +def_gen.vonmises(0.5, D_arr_0p5, size=1) +def_gen.vonmises(D_arr_like_0p5, 0.5) +def_gen.vonmises(0.5, D_arr_like_0p5) +def_gen.vonmises(D_arr_0p5, D_arr_0p5) +def_gen.vonmises(D_arr_like_0p5, D_arr_like_0p5) +def_gen.vonmises(D_arr_0p5, D_arr_0p5, size=1) +def_gen.vonmises(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.wald(0.5, 0.5) +def_gen.wald(0.5, 0.5, size=None) +def_gen.wald(0.5, 0.5, size=1) +def_gen.wald(D_arr_0p5, 0.5) +def_gen.wald(0.5, D_arr_0p5) +def_gen.wald(D_arr_0p5, 0.5, size=1) +def_gen.wald(0.5, D_arr_0p5, size=1) +def_gen.wald(D_arr_like_0p5, 0.5) +def_gen.wald(0.5, D_arr_like_0p5) +def_gen.wald(D_arr_0p5, D_arr_0p5) +def_gen.wald(D_arr_like_0p5, D_arr_like_0p5) +def_gen.wald(D_arr_0p5, D_arr_0p5, size=1) +def_gen.wald(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.uniform(0.5, 0.5) +def_gen.uniform(0.5, 0.5, size=None) +def_gen.uniform(0.5, 0.5, size=1) +def_gen.uniform(D_arr_0p5, 0.5) +def_gen.uniform(0.5, D_arr_0p5) +def_gen.uniform(D_arr_0p5, 0.5, size=1) +def_gen.uniform(0.5, D_arr_0p5, size=1) +def_gen.uniform(D_arr_like_0p5, 0.5) +def_gen.uniform(0.5, D_arr_like_0p5) +def_gen.uniform(D_arr_0p5, D_arr_0p5) +def_gen.uniform(D_arr_like_0p5, D_arr_like_0p5) +def_gen.uniform(D_arr_0p5, D_arr_0p5, size=1) +def_gen.uniform(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.beta(0.5, 0.5) +def_gen.beta(0.5, 0.5, size=None) +def_gen.beta(0.5, 0.5, size=1) +def_gen.beta(D_arr_0p5, 0.5) +def_gen.beta(0.5, D_arr_0p5) +def_gen.beta(D_arr_0p5, 0.5, size=1) +def_gen.beta(0.5, D_arr_0p5, size=1) +def_gen.beta(D_arr_like_0p5, 0.5) +def_gen.beta(0.5, D_arr_like_0p5) +def_gen.beta(D_arr_0p5, D_arr_0p5) +def_gen.beta(D_arr_like_0p5, D_arr_like_0p5) +def_gen.beta(D_arr_0p5, D_arr_0p5, size=1) +def_gen.beta(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.f(0.5, 0.5) +def_gen.f(0.5, 0.5, size=None) +def_gen.f(0.5, 0.5, size=1) +def_gen.f(D_arr_0p5, 0.5) +def_gen.f(0.5, D_arr_0p5) +def_gen.f(D_arr_0p5, 0.5, size=1) +def_gen.f(0.5, D_arr_0p5, size=1) +def_gen.f(D_arr_like_0p5, 0.5) +def_gen.f(0.5, D_arr_like_0p5) +def_gen.f(D_arr_0p5, D_arr_0p5) +def_gen.f(D_arr_like_0p5, D_arr_like_0p5) +def_gen.f(D_arr_0p5, D_arr_0p5, size=1) +def_gen.f(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.gamma(0.5, 0.5) +def_gen.gamma(0.5, 0.5, size=None) +def_gen.gamma(0.5, 0.5, size=1) +def_gen.gamma(D_arr_0p5, 0.5) +def_gen.gamma(0.5, D_arr_0p5) +def_gen.gamma(D_arr_0p5, 0.5, size=1) 
+def_gen.gamma(0.5, D_arr_0p5, size=1) +def_gen.gamma(D_arr_like_0p5, 0.5) +def_gen.gamma(0.5, D_arr_like_0p5) +def_gen.gamma(D_arr_0p5, D_arr_0p5) +def_gen.gamma(D_arr_like_0p5, D_arr_like_0p5) +def_gen.gamma(D_arr_0p5, D_arr_0p5, size=1) +def_gen.gamma(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.gumbel(0.5, 0.5) +def_gen.gumbel(0.5, 0.5, size=None) +def_gen.gumbel(0.5, 0.5, size=1) +def_gen.gumbel(D_arr_0p5, 0.5) +def_gen.gumbel(0.5, D_arr_0p5) +def_gen.gumbel(D_arr_0p5, 0.5, size=1) +def_gen.gumbel(0.5, D_arr_0p5, size=1) +def_gen.gumbel(D_arr_like_0p5, 0.5) +def_gen.gumbel(0.5, D_arr_like_0p5) +def_gen.gumbel(D_arr_0p5, D_arr_0p5) +def_gen.gumbel(D_arr_like_0p5, D_arr_like_0p5) +def_gen.gumbel(D_arr_0p5, D_arr_0p5, size=1) +def_gen.gumbel(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.laplace(0.5, 0.5) +def_gen.laplace(0.5, 0.5, size=None) +def_gen.laplace(0.5, 0.5, size=1) +def_gen.laplace(D_arr_0p5, 0.5) +def_gen.laplace(0.5, D_arr_0p5) +def_gen.laplace(D_arr_0p5, 0.5, size=1) +def_gen.laplace(0.5, D_arr_0p5, size=1) +def_gen.laplace(D_arr_like_0p5, 0.5) +def_gen.laplace(0.5, D_arr_like_0p5) +def_gen.laplace(D_arr_0p5, D_arr_0p5) +def_gen.laplace(D_arr_like_0p5, D_arr_like_0p5) +def_gen.laplace(D_arr_0p5, D_arr_0p5, size=1) +def_gen.laplace(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.logistic(0.5, 0.5) +def_gen.logistic(0.5, 0.5, size=None) +def_gen.logistic(0.5, 0.5, size=1) +def_gen.logistic(D_arr_0p5, 0.5) +def_gen.logistic(0.5, D_arr_0p5) +def_gen.logistic(D_arr_0p5, 0.5, size=1) +def_gen.logistic(0.5, D_arr_0p5, size=1) +def_gen.logistic(D_arr_like_0p5, 0.5) +def_gen.logistic(0.5, D_arr_like_0p5) +def_gen.logistic(D_arr_0p5, D_arr_0p5) +def_gen.logistic(D_arr_like_0p5, D_arr_like_0p5) +def_gen.logistic(D_arr_0p5, D_arr_0p5, size=1) +def_gen.logistic(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.lognormal(0.5, 0.5) +def_gen.lognormal(0.5, 0.5, size=None) +def_gen.lognormal(0.5, 0.5, size=1) +def_gen.lognormal(D_arr_0p5, 0.5) +def_gen.lognormal(0.5, D_arr_0p5) +def_gen.lognormal(D_arr_0p5, 0.5, size=1) +def_gen.lognormal(0.5, D_arr_0p5, size=1) +def_gen.lognormal(D_arr_like_0p5, 0.5) +def_gen.lognormal(0.5, D_arr_like_0p5) +def_gen.lognormal(D_arr_0p5, D_arr_0p5) +def_gen.lognormal(D_arr_like_0p5, D_arr_like_0p5) +def_gen.lognormal(D_arr_0p5, D_arr_0p5, size=1) +def_gen.lognormal(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.noncentral_chisquare(0.5, 0.5) +def_gen.noncentral_chisquare(0.5, 0.5, size=None) +def_gen.noncentral_chisquare(0.5, 0.5, size=1) +def_gen.noncentral_chisquare(D_arr_0p5, 0.5) +def_gen.noncentral_chisquare(0.5, D_arr_0p5) +def_gen.noncentral_chisquare(D_arr_0p5, 0.5, size=1) +def_gen.noncentral_chisquare(0.5, D_arr_0p5, size=1) +def_gen.noncentral_chisquare(D_arr_like_0p5, 0.5) +def_gen.noncentral_chisquare(0.5, D_arr_like_0p5) +def_gen.noncentral_chisquare(D_arr_0p5, D_arr_0p5) +def_gen.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5) +def_gen.noncentral_chisquare(D_arr_0p5, D_arr_0p5, size=1) +def_gen.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.normal(0.5, 0.5) +def_gen.normal(0.5, 0.5, size=None) +def_gen.normal(0.5, 0.5, size=1) +def_gen.normal(D_arr_0p5, 0.5) +def_gen.normal(0.5, D_arr_0p5) +def_gen.normal(D_arr_0p5, 0.5, size=1) +def_gen.normal(0.5, D_arr_0p5, size=1) +def_gen.normal(D_arr_like_0p5, 0.5) +def_gen.normal(0.5, D_arr_like_0p5) +def_gen.normal(D_arr_0p5, D_arr_0p5) +def_gen.normal(D_arr_like_0p5, D_arr_like_0p5) +def_gen.normal(D_arr_0p5, D_arr_0p5, size=1) 
+def_gen.normal(D_arr_like_0p5, D_arr_like_0p5, size=1) + +def_gen.triangular(0.1, 0.5, 0.9) +def_gen.triangular(0.1, 0.5, 0.9, size=None) +def_gen.triangular(0.1, 0.5, 0.9, size=1) +def_gen.triangular(D_arr_0p1, 0.5, 0.9) +def_gen.triangular(0.1, D_arr_0p5, 0.9) +def_gen.triangular(D_arr_0p1, 0.5, D_arr_like_0p9, size=1) +def_gen.triangular(0.1, D_arr_0p5, 0.9, size=1) +def_gen.triangular(D_arr_like_0p1, 0.5, D_arr_0p9) +def_gen.triangular(0.5, D_arr_like_0p5, 0.9) +def_gen.triangular(D_arr_0p1, D_arr_0p5, 0.9) +def_gen.triangular(D_arr_like_0p1, D_arr_like_0p5, 0.9) +def_gen.triangular(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1) +def_gen.triangular(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1) + +def_gen.noncentral_f(0.1, 0.5, 0.9) +def_gen.noncentral_f(0.1, 0.5, 0.9, size=None) +def_gen.noncentral_f(0.1, 0.5, 0.9, size=1) +def_gen.noncentral_f(D_arr_0p1, 0.5, 0.9) +def_gen.noncentral_f(0.1, D_arr_0p5, 0.9) +def_gen.noncentral_f(D_arr_0p1, 0.5, D_arr_like_0p9, size=1) +def_gen.noncentral_f(0.1, D_arr_0p5, 0.9, size=1) +def_gen.noncentral_f(D_arr_like_0p1, 0.5, D_arr_0p9) +def_gen.noncentral_f(0.5, D_arr_like_0p5, 0.9) +def_gen.noncentral_f(D_arr_0p1, D_arr_0p5, 0.9) +def_gen.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, 0.9) +def_gen.noncentral_f(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1) +def_gen.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1) + +def_gen.binomial(10, 0.5) +def_gen.binomial(10, 0.5, size=None) +def_gen.binomial(10, 0.5, size=1) +def_gen.binomial(I_arr_10, 0.5) +def_gen.binomial(10, D_arr_0p5) +def_gen.binomial(I_arr_10, 0.5, size=1) +def_gen.binomial(10, D_arr_0p5, size=1) +def_gen.binomial(I_arr_like_10, 0.5) +def_gen.binomial(10, D_arr_like_0p5) +def_gen.binomial(I_arr_10, D_arr_0p5) +def_gen.binomial(I_arr_like_10, D_arr_like_0p5) +def_gen.binomial(I_arr_10, D_arr_0p5, size=1) +def_gen.binomial(I_arr_like_10, D_arr_like_0p5, size=1) + +def_gen.negative_binomial(10, 0.5) +def_gen.negative_binomial(10, 0.5, size=None) +def_gen.negative_binomial(10, 0.5, size=1) +def_gen.negative_binomial(I_arr_10, 0.5) +def_gen.negative_binomial(10, D_arr_0p5) +def_gen.negative_binomial(I_arr_10, 0.5, size=1) +def_gen.negative_binomial(10, D_arr_0p5, size=1) +def_gen.negative_binomial(I_arr_like_10, 0.5) +def_gen.negative_binomial(10, D_arr_like_0p5) +def_gen.negative_binomial(I_arr_10, D_arr_0p5) +def_gen.negative_binomial(I_arr_like_10, D_arr_like_0p5) +def_gen.negative_binomial(I_arr_10, D_arr_0p5, size=1) +def_gen.negative_binomial(I_arr_like_10, D_arr_like_0p5, size=1) + +def_gen.hypergeometric(20, 20, 10) +def_gen.hypergeometric(20, 20, 10, size=None) +def_gen.hypergeometric(20, 20, 10, size=1) +def_gen.hypergeometric(I_arr_20, 20, 10) +def_gen.hypergeometric(20, I_arr_20, 10) +def_gen.hypergeometric(I_arr_20, 20, I_arr_like_10, size=1) +def_gen.hypergeometric(20, I_arr_20, 10, size=1) +def_gen.hypergeometric(I_arr_like_20, 20, I_arr_10) +def_gen.hypergeometric(20, I_arr_like_20, 10) +def_gen.hypergeometric(I_arr_20, I_arr_20, 10) +def_gen.hypergeometric(I_arr_like_20, I_arr_like_20, 10) +def_gen.hypergeometric(I_arr_20, I_arr_20, I_arr_10, size=1) +def_gen.hypergeometric(I_arr_like_20, I_arr_like_20, I_arr_like_10, size=1) + +I_int64_100: np.ndarray[Any, np.dtype[np.int64]] = np.array([100], dtype=np.int64) + +def_gen.integers(0, 100) +def_gen.integers(100) +def_gen.integers([100]) +def_gen.integers(0, [100]) + +I_bool_low: np.ndarray[Any, np.dtype[np.bool]] = np.array([0], dtype=np.bool) +I_bool_low_like: list[int] = [0] +I_bool_high_open: 
np.ndarray[Any, np.dtype[np.bool]] = np.array([1], dtype=np.bool) +I_bool_high_closed: np.ndarray[Any, np.dtype[np.bool]] = np.array([1], dtype=np.bool) + +def_gen.integers(2, dtype=bool) +def_gen.integers(0, 2, dtype=bool) +def_gen.integers(1, dtype=bool, endpoint=True) +def_gen.integers(0, 1, dtype=bool, endpoint=True) +def_gen.integers(I_bool_low_like, 1, dtype=bool, endpoint=True) +def_gen.integers(I_bool_high_open, dtype=bool) +def_gen.integers(I_bool_low, I_bool_high_open, dtype=bool) +def_gen.integers(0, I_bool_high_open, dtype=bool) +def_gen.integers(I_bool_high_closed, dtype=bool, endpoint=True) +def_gen.integers(I_bool_low, I_bool_high_closed, dtype=bool, endpoint=True) +def_gen.integers(0, I_bool_high_closed, dtype=bool, endpoint=True) + +def_gen.integers(2, dtype=np.bool) +def_gen.integers(0, 2, dtype=np.bool) +def_gen.integers(1, dtype=np.bool, endpoint=True) +def_gen.integers(0, 1, dtype=np.bool, endpoint=True) +def_gen.integers(I_bool_low_like, 1, dtype=np.bool, endpoint=True) +def_gen.integers(I_bool_high_open, dtype=np.bool) +def_gen.integers(I_bool_low, I_bool_high_open, dtype=np.bool) +def_gen.integers(0, I_bool_high_open, dtype=np.bool) +def_gen.integers(I_bool_high_closed, dtype=np.bool, endpoint=True) +def_gen.integers(I_bool_low, I_bool_high_closed, dtype=np.bool, endpoint=True) +def_gen.integers(0, I_bool_high_closed, dtype=np.bool, endpoint=True) + +I_u1_low: np.ndarray[Any, np.dtype[np.uint8]] = np.array([0], dtype=np.uint8) +I_u1_low_like: list[int] = [0] +I_u1_high_open: np.ndarray[Any, np.dtype[np.uint8]] = np.array([255], dtype=np.uint8) +I_u1_high_closed: np.ndarray[Any, np.dtype[np.uint8]] = np.array([255], dtype=np.uint8) + +def_gen.integers(256, dtype="u1") +def_gen.integers(0, 256, dtype="u1") +def_gen.integers(255, dtype="u1", endpoint=True) +def_gen.integers(0, 255, dtype="u1", endpoint=True) +def_gen.integers(I_u1_low_like, 255, dtype="u1", endpoint=True) +def_gen.integers(I_u1_high_open, dtype="u1") +def_gen.integers(I_u1_low, I_u1_high_open, dtype="u1") +def_gen.integers(0, I_u1_high_open, dtype="u1") +def_gen.integers(I_u1_high_closed, dtype="u1", endpoint=True) +def_gen.integers(I_u1_low, I_u1_high_closed, dtype="u1", endpoint=True) +def_gen.integers(0, I_u1_high_closed, dtype="u1", endpoint=True) + +def_gen.integers(256, dtype="uint8") +def_gen.integers(0, 256, dtype="uint8") +def_gen.integers(255, dtype="uint8", endpoint=True) +def_gen.integers(0, 255, dtype="uint8", endpoint=True) +def_gen.integers(I_u1_low_like, 255, dtype="uint8", endpoint=True) +def_gen.integers(I_u1_high_open, dtype="uint8") +def_gen.integers(I_u1_low, I_u1_high_open, dtype="uint8") +def_gen.integers(0, I_u1_high_open, dtype="uint8") +def_gen.integers(I_u1_high_closed, dtype="uint8", endpoint=True) +def_gen.integers(I_u1_low, I_u1_high_closed, dtype="uint8", endpoint=True) +def_gen.integers(0, I_u1_high_closed, dtype="uint8", endpoint=True) + +def_gen.integers(256, dtype=np.uint8) +def_gen.integers(0, 256, dtype=np.uint8) +def_gen.integers(255, dtype=np.uint8, endpoint=True) +def_gen.integers(0, 255, dtype=np.uint8, endpoint=True) +def_gen.integers(I_u1_low_like, 255, dtype=np.uint8, endpoint=True) +def_gen.integers(I_u1_high_open, dtype=np.uint8) +def_gen.integers(I_u1_low, I_u1_high_open, dtype=np.uint8) +def_gen.integers(0, I_u1_high_open, dtype=np.uint8) +def_gen.integers(I_u1_high_closed, dtype=np.uint8, endpoint=True) +def_gen.integers(I_u1_low, I_u1_high_closed, dtype=np.uint8, endpoint=True) +def_gen.integers(0, I_u1_high_closed, dtype=np.uint8, endpoint=True) + 
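+# --- Editor's sketch: illustrative only, not part of the upstream numpy test
+# file. The calls above exercise the two bound conventions of
+# `Generator.integers`: with the default endpoint=False, `high` is exclusive,
+# so integers(0, 256, dtype=np.uint8) draws from 0..255; with endpoint=True,
+# `high` is inclusive, so integers(0, 255, ..., endpoint=True) covers the same
+# range. The names `_u1_open`/`_u1_closed` are hypothetical; `def_gen` and
+# `np` are the objects already defined in this file.
+_u1_open = def_gen.integers(0, 256, dtype=np.uint8)
+_u1_closed = def_gen.integers(0, 255, dtype=np.uint8, endpoint=True)
+assert 0 <= int(_u1_open) <= 255
+assert 0 <= int(_u1_closed) <= 255
+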
+I_u2_low: np.ndarray[Any, np.dtype[np.uint16]] = np.array([0], dtype=np.uint16) +I_u2_low_like: list[int] = [0] +I_u2_high_open: np.ndarray[Any, np.dtype[np.uint16]] = np.array([65535], dtype=np.uint16) +I_u2_high_closed: np.ndarray[Any, np.dtype[np.uint16]] = np.array([65535], dtype=np.uint16) + +def_gen.integers(65536, dtype="u2") +def_gen.integers(0, 65536, dtype="u2") +def_gen.integers(65535, dtype="u2", endpoint=True) +def_gen.integers(0, 65535, dtype="u2", endpoint=True) +def_gen.integers(I_u2_low_like, 65535, dtype="u2", endpoint=True) +def_gen.integers(I_u2_high_open, dtype="u2") +def_gen.integers(I_u2_low, I_u2_high_open, dtype="u2") +def_gen.integers(0, I_u2_high_open, dtype="u2") +def_gen.integers(I_u2_high_closed, dtype="u2", endpoint=True) +def_gen.integers(I_u2_low, I_u2_high_closed, dtype="u2", endpoint=True) +def_gen.integers(0, I_u2_high_closed, dtype="u2", endpoint=True) + +def_gen.integers(65536, dtype="uint16") +def_gen.integers(0, 65536, dtype="uint16") +def_gen.integers(65535, dtype="uint16", endpoint=True) +def_gen.integers(0, 65535, dtype="uint16", endpoint=True) +def_gen.integers(I_u2_low_like, 65535, dtype="uint16", endpoint=True) +def_gen.integers(I_u2_high_open, dtype="uint16") +def_gen.integers(I_u2_low, I_u2_high_open, dtype="uint16") +def_gen.integers(0, I_u2_high_open, dtype="uint16") +def_gen.integers(I_u2_high_closed, dtype="uint16", endpoint=True) +def_gen.integers(I_u2_low, I_u2_high_closed, dtype="uint16", endpoint=True) +def_gen.integers(0, I_u2_high_closed, dtype="uint16", endpoint=True) + +def_gen.integers(65536, dtype=np.uint16) +def_gen.integers(0, 65536, dtype=np.uint16) +def_gen.integers(65535, dtype=np.uint16, endpoint=True) +def_gen.integers(0, 65535, dtype=np.uint16, endpoint=True) +def_gen.integers(I_u2_low_like, 65535, dtype=np.uint16, endpoint=True) +def_gen.integers(I_u2_high_open, dtype=np.uint16) +def_gen.integers(I_u2_low, I_u2_high_open, dtype=np.uint16) +def_gen.integers(0, I_u2_high_open, dtype=np.uint16) +def_gen.integers(I_u2_high_closed, dtype=np.uint16, endpoint=True) +def_gen.integers(I_u2_low, I_u2_high_closed, dtype=np.uint16, endpoint=True) +def_gen.integers(0, I_u2_high_closed, dtype=np.uint16, endpoint=True) + +I_u4_low: np.ndarray[Any, np.dtype[np.uint32]] = np.array([0], dtype=np.uint32) +I_u4_low_like: list[int] = [0] +I_u4_high_open: np.ndarray[Any, np.dtype[np.uint32]] = np.array([4294967295], dtype=np.uint32) +I_u4_high_closed: np.ndarray[Any, np.dtype[np.uint32]] = np.array([4294967295], dtype=np.uint32) + +def_gen.integers(4294967296, dtype="u4") +def_gen.integers(0, 4294967296, dtype="u4") +def_gen.integers(4294967295, dtype="u4", endpoint=True) +def_gen.integers(0, 4294967295, dtype="u4", endpoint=True) +def_gen.integers(I_u4_low_like, 4294967295, dtype="u4", endpoint=True) +def_gen.integers(I_u4_high_open, dtype="u4") +def_gen.integers(I_u4_low, I_u4_high_open, dtype="u4") +def_gen.integers(0, I_u4_high_open, dtype="u4") +def_gen.integers(I_u4_high_closed, dtype="u4", endpoint=True) +def_gen.integers(I_u4_low, I_u4_high_closed, dtype="u4", endpoint=True) +def_gen.integers(0, I_u4_high_closed, dtype="u4", endpoint=True) + +def_gen.integers(4294967296, dtype="uint32") +def_gen.integers(0, 4294967296, dtype="uint32") +def_gen.integers(4294967295, dtype="uint32", endpoint=True) +def_gen.integers(0, 4294967295, dtype="uint32", endpoint=True) +def_gen.integers(I_u4_low_like, 4294967295, dtype="uint32", endpoint=True) +def_gen.integers(I_u4_high_open, dtype="uint32") +def_gen.integers(I_u4_low, I_u4_high_open, 
dtype="uint32") +def_gen.integers(0, I_u4_high_open, dtype="uint32") +def_gen.integers(I_u4_high_closed, dtype="uint32", endpoint=True) +def_gen.integers(I_u4_low, I_u4_high_closed, dtype="uint32", endpoint=True) +def_gen.integers(0, I_u4_high_closed, dtype="uint32", endpoint=True) + +def_gen.integers(4294967296, dtype=np.uint32) +def_gen.integers(0, 4294967296, dtype=np.uint32) +def_gen.integers(4294967295, dtype=np.uint32, endpoint=True) +def_gen.integers(0, 4294967295, dtype=np.uint32, endpoint=True) +def_gen.integers(I_u4_low_like, 4294967295, dtype=np.uint32, endpoint=True) +def_gen.integers(I_u4_high_open, dtype=np.uint32) +def_gen.integers(I_u4_low, I_u4_high_open, dtype=np.uint32) +def_gen.integers(0, I_u4_high_open, dtype=np.uint32) +def_gen.integers(I_u4_high_closed, dtype=np.uint32, endpoint=True) +def_gen.integers(I_u4_low, I_u4_high_closed, dtype=np.uint32, endpoint=True) +def_gen.integers(0, I_u4_high_closed, dtype=np.uint32, endpoint=True) + +I_u8_low: np.ndarray[Any, np.dtype[np.uint64]] = np.array([0], dtype=np.uint64) +I_u8_low_like: list[int] = [0] +I_u8_high_open: np.ndarray[Any, np.dtype[np.uint64]] = np.array([18446744073709551615], dtype=np.uint64) +I_u8_high_closed: np.ndarray[Any, np.dtype[np.uint64]] = np.array([18446744073709551615], dtype=np.uint64) + +def_gen.integers(18446744073709551616, dtype="u8") +def_gen.integers(0, 18446744073709551616, dtype="u8") +def_gen.integers(18446744073709551615, dtype="u8", endpoint=True) +def_gen.integers(0, 18446744073709551615, dtype="u8", endpoint=True) +def_gen.integers(I_u8_low_like, 18446744073709551615, dtype="u8", endpoint=True) +def_gen.integers(I_u8_high_open, dtype="u8") +def_gen.integers(I_u8_low, I_u8_high_open, dtype="u8") +def_gen.integers(0, I_u8_high_open, dtype="u8") +def_gen.integers(I_u8_high_closed, dtype="u8", endpoint=True) +def_gen.integers(I_u8_low, I_u8_high_closed, dtype="u8", endpoint=True) +def_gen.integers(0, I_u8_high_closed, dtype="u8", endpoint=True) + +def_gen.integers(18446744073709551616, dtype="uint64") +def_gen.integers(0, 18446744073709551616, dtype="uint64") +def_gen.integers(18446744073709551615, dtype="uint64", endpoint=True) +def_gen.integers(0, 18446744073709551615, dtype="uint64", endpoint=True) +def_gen.integers(I_u8_low_like, 18446744073709551615, dtype="uint64", endpoint=True) +def_gen.integers(I_u8_high_open, dtype="uint64") +def_gen.integers(I_u8_low, I_u8_high_open, dtype="uint64") +def_gen.integers(0, I_u8_high_open, dtype="uint64") +def_gen.integers(I_u8_high_closed, dtype="uint64", endpoint=True) +def_gen.integers(I_u8_low, I_u8_high_closed, dtype="uint64", endpoint=True) +def_gen.integers(0, I_u8_high_closed, dtype="uint64", endpoint=True) + +def_gen.integers(18446744073709551616, dtype=np.uint64) +def_gen.integers(0, 18446744073709551616, dtype=np.uint64) +def_gen.integers(18446744073709551615, dtype=np.uint64, endpoint=True) +def_gen.integers(0, 18446744073709551615, dtype=np.uint64, endpoint=True) +def_gen.integers(I_u8_low_like, 18446744073709551615, dtype=np.uint64, endpoint=True) +def_gen.integers(I_u8_high_open, dtype=np.uint64) +def_gen.integers(I_u8_low, I_u8_high_open, dtype=np.uint64) +def_gen.integers(0, I_u8_high_open, dtype=np.uint64) +def_gen.integers(I_u8_high_closed, dtype=np.uint64, endpoint=True) +def_gen.integers(I_u8_low, I_u8_high_closed, dtype=np.uint64, endpoint=True) +def_gen.integers(0, I_u8_high_closed, dtype=np.uint64, endpoint=True) + +I_i1_low: np.ndarray[Any, np.dtype[np.int8]] = np.array([-128], dtype=np.int8) +I_i1_low_like: list[int] = 
[-128] +I_i1_high_open: np.ndarray[Any, np.dtype[np.int8]] = np.array([127], dtype=np.int8) +I_i1_high_closed: np.ndarray[Any, np.dtype[np.int8]] = np.array([127], dtype=np.int8) + +def_gen.integers(128, dtype="i1") +def_gen.integers(-128, 128, dtype="i1") +def_gen.integers(127, dtype="i1", endpoint=True) +def_gen.integers(-128, 127, dtype="i1", endpoint=True) +def_gen.integers(I_i1_low_like, 127, dtype="i1", endpoint=True) +def_gen.integers(I_i1_high_open, dtype="i1") +def_gen.integers(I_i1_low, I_i1_high_open, dtype="i1") +def_gen.integers(-128, I_i1_high_open, dtype="i1") +def_gen.integers(I_i1_high_closed, dtype="i1", endpoint=True) +def_gen.integers(I_i1_low, I_i1_high_closed, dtype="i1", endpoint=True) +def_gen.integers(-128, I_i1_high_closed, dtype="i1", endpoint=True) + +def_gen.integers(128, dtype="int8") +def_gen.integers(-128, 128, dtype="int8") +def_gen.integers(127, dtype="int8", endpoint=True) +def_gen.integers(-128, 127, dtype="int8", endpoint=True) +def_gen.integers(I_i1_low_like, 127, dtype="int8", endpoint=True) +def_gen.integers(I_i1_high_open, dtype="int8") +def_gen.integers(I_i1_low, I_i1_high_open, dtype="int8") +def_gen.integers(-128, I_i1_high_open, dtype="int8") +def_gen.integers(I_i1_high_closed, dtype="int8", endpoint=True) +def_gen.integers(I_i1_low, I_i1_high_closed, dtype="int8", endpoint=True) +def_gen.integers(-128, I_i1_high_closed, dtype="int8", endpoint=True) + +def_gen.integers(128, dtype=np.int8) +def_gen.integers(-128, 128, dtype=np.int8) +def_gen.integers(127, dtype=np.int8, endpoint=True) +def_gen.integers(-128, 127, dtype=np.int8, endpoint=True) +def_gen.integers(I_i1_low_like, 127, dtype=np.int8, endpoint=True) +def_gen.integers(I_i1_high_open, dtype=np.int8) +def_gen.integers(I_i1_low, I_i1_high_open, dtype=np.int8) +def_gen.integers(-128, I_i1_high_open, dtype=np.int8) +def_gen.integers(I_i1_high_closed, dtype=np.int8, endpoint=True) +def_gen.integers(I_i1_low, I_i1_high_closed, dtype=np.int8, endpoint=True) +def_gen.integers(-128, I_i1_high_closed, dtype=np.int8, endpoint=True) + +I_i2_low: np.ndarray[Any, np.dtype[np.int16]] = np.array([-32768], dtype=np.int16) +I_i2_low_like: list[int] = [-32768] +I_i2_high_open: np.ndarray[Any, np.dtype[np.int16]] = np.array([32767], dtype=np.int16) +I_i2_high_closed: np.ndarray[Any, np.dtype[np.int16]] = np.array([32767], dtype=np.int16) + +def_gen.integers(32768, dtype="i2") +def_gen.integers(-32768, 32768, dtype="i2") +def_gen.integers(32767, dtype="i2", endpoint=True) +def_gen.integers(-32768, 32767, dtype="i2", endpoint=True) +def_gen.integers(I_i2_low_like, 32767, dtype="i2", endpoint=True) +def_gen.integers(I_i2_high_open, dtype="i2") +def_gen.integers(I_i2_low, I_i2_high_open, dtype="i2") +def_gen.integers(-32768, I_i2_high_open, dtype="i2") +def_gen.integers(I_i2_high_closed, dtype="i2", endpoint=True) +def_gen.integers(I_i2_low, I_i2_high_closed, dtype="i2", endpoint=True) +def_gen.integers(-32768, I_i2_high_closed, dtype="i2", endpoint=True) + +def_gen.integers(32768, dtype="int16") +def_gen.integers(-32768, 32768, dtype="int16") +def_gen.integers(32767, dtype="int16", endpoint=True) +def_gen.integers(-32768, 32767, dtype="int16", endpoint=True) +def_gen.integers(I_i2_low_like, 32767, dtype="int16", endpoint=True) +def_gen.integers(I_i2_high_open, dtype="int16") +def_gen.integers(I_i2_low, I_i2_high_open, dtype="int16") +def_gen.integers(-32768, I_i2_high_open, dtype="int16") +def_gen.integers(I_i2_high_closed, dtype="int16", endpoint=True) +def_gen.integers(I_i2_low, I_i2_high_closed, 
dtype="int16", endpoint=True) +def_gen.integers(-32768, I_i2_high_closed, dtype="int16", endpoint=True) + +def_gen.integers(32768, dtype=np.int16) +def_gen.integers(-32768, 32768, dtype=np.int16) +def_gen.integers(32767, dtype=np.int16, endpoint=True) +def_gen.integers(-32768, 32767, dtype=np.int16, endpoint=True) +def_gen.integers(I_i2_low_like, 32767, dtype=np.int16, endpoint=True) +def_gen.integers(I_i2_high_open, dtype=np.int16) +def_gen.integers(I_i2_low, I_i2_high_open, dtype=np.int16) +def_gen.integers(-32768, I_i2_high_open, dtype=np.int16) +def_gen.integers(I_i2_high_closed, dtype=np.int16, endpoint=True) +def_gen.integers(I_i2_low, I_i2_high_closed, dtype=np.int16, endpoint=True) +def_gen.integers(-32768, I_i2_high_closed, dtype=np.int16, endpoint=True) + +I_i4_low: np.ndarray[Any, np.dtype[np.int32]] = np.array([-2147483648], dtype=np.int32) +I_i4_low_like: list[int] = [-2147483648] +I_i4_high_open: np.ndarray[Any, np.dtype[np.int32]] = np.array([2147483647], dtype=np.int32) +I_i4_high_closed: np.ndarray[Any, np.dtype[np.int32]] = np.array([2147483647], dtype=np.int32) + +def_gen.integers(2147483648, dtype="i4") +def_gen.integers(-2147483648, 2147483648, dtype="i4") +def_gen.integers(2147483647, dtype="i4", endpoint=True) +def_gen.integers(-2147483648, 2147483647, dtype="i4", endpoint=True) +def_gen.integers(I_i4_low_like, 2147483647, dtype="i4", endpoint=True) +def_gen.integers(I_i4_high_open, dtype="i4") +def_gen.integers(I_i4_low, I_i4_high_open, dtype="i4") +def_gen.integers(-2147483648, I_i4_high_open, dtype="i4") +def_gen.integers(I_i4_high_closed, dtype="i4", endpoint=True) +def_gen.integers(I_i4_low, I_i4_high_closed, dtype="i4", endpoint=True) +def_gen.integers(-2147483648, I_i4_high_closed, dtype="i4", endpoint=True) + +def_gen.integers(2147483648, dtype="int32") +def_gen.integers(-2147483648, 2147483648, dtype="int32") +def_gen.integers(2147483647, dtype="int32", endpoint=True) +def_gen.integers(-2147483648, 2147483647, dtype="int32", endpoint=True) +def_gen.integers(I_i4_low_like, 2147483647, dtype="int32", endpoint=True) +def_gen.integers(I_i4_high_open, dtype="int32") +def_gen.integers(I_i4_low, I_i4_high_open, dtype="int32") +def_gen.integers(-2147483648, I_i4_high_open, dtype="int32") +def_gen.integers(I_i4_high_closed, dtype="int32", endpoint=True) +def_gen.integers(I_i4_low, I_i4_high_closed, dtype="int32", endpoint=True) +def_gen.integers(-2147483648, I_i4_high_closed, dtype="int32", endpoint=True) + +def_gen.integers(2147483648, dtype=np.int32) +def_gen.integers(-2147483648, 2147483648, dtype=np.int32) +def_gen.integers(2147483647, dtype=np.int32, endpoint=True) +def_gen.integers(-2147483648, 2147483647, dtype=np.int32, endpoint=True) +def_gen.integers(I_i4_low_like, 2147483647, dtype=np.int32, endpoint=True) +def_gen.integers(I_i4_high_open, dtype=np.int32) +def_gen.integers(I_i4_low, I_i4_high_open, dtype=np.int32) +def_gen.integers(-2147483648, I_i4_high_open, dtype=np.int32) +def_gen.integers(I_i4_high_closed, dtype=np.int32, endpoint=True) +def_gen.integers(I_i4_low, I_i4_high_closed, dtype=np.int32, endpoint=True) +def_gen.integers(-2147483648, I_i4_high_closed, dtype=np.int32, endpoint=True) + +I_i8_low: np.ndarray[Any, np.dtype[np.int64]] = np.array([-9223372036854775808], dtype=np.int64) +I_i8_low_like: list[int] = [-9223372036854775808] +I_i8_high_open: np.ndarray[Any, np.dtype[np.int64]] = np.array([9223372036854775807], dtype=np.int64) +I_i8_high_closed: np.ndarray[Any, np.dtype[np.int64]] = np.array([9223372036854775807], dtype=np.int64) + 
+def_gen.integers(9223372036854775808, dtype="i8") +def_gen.integers(-9223372036854775808, 9223372036854775808, dtype="i8") +def_gen.integers(9223372036854775807, dtype="i8", endpoint=True) +def_gen.integers(-9223372036854775808, 9223372036854775807, dtype="i8", endpoint=True) +def_gen.integers(I_i8_low_like, 9223372036854775807, dtype="i8", endpoint=True) +def_gen.integers(I_i8_high_open, dtype="i8") +def_gen.integers(I_i8_low, I_i8_high_open, dtype="i8") +def_gen.integers(-9223372036854775808, I_i8_high_open, dtype="i8") +def_gen.integers(I_i8_high_closed, dtype="i8", endpoint=True) +def_gen.integers(I_i8_low, I_i8_high_closed, dtype="i8", endpoint=True) +def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype="i8", endpoint=True) + +def_gen.integers(9223372036854775808, dtype="int64") +def_gen.integers(-9223372036854775808, 9223372036854775808, dtype="int64") +def_gen.integers(9223372036854775807, dtype="int64", endpoint=True) +def_gen.integers(-9223372036854775808, 9223372036854775807, dtype="int64", endpoint=True) +def_gen.integers(I_i8_low_like, 9223372036854775807, dtype="int64", endpoint=True) +def_gen.integers(I_i8_high_open, dtype="int64") +def_gen.integers(I_i8_low, I_i8_high_open, dtype="int64") +def_gen.integers(-9223372036854775808, I_i8_high_open, dtype="int64") +def_gen.integers(I_i8_high_closed, dtype="int64", endpoint=True) +def_gen.integers(I_i8_low, I_i8_high_closed, dtype="int64", endpoint=True) +def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype="int64", endpoint=True) + +def_gen.integers(9223372036854775808, dtype=np.int64) +def_gen.integers(-9223372036854775808, 9223372036854775808, dtype=np.int64) +def_gen.integers(9223372036854775807, dtype=np.int64, endpoint=True) +def_gen.integers(-9223372036854775808, 9223372036854775807, dtype=np.int64, endpoint=True) +def_gen.integers(I_i8_low_like, 9223372036854775807, dtype=np.int64, endpoint=True) +def_gen.integers(I_i8_high_open, dtype=np.int64) +def_gen.integers(I_i8_low, I_i8_high_open, dtype=np.int64) +def_gen.integers(-9223372036854775808, I_i8_high_open, dtype=np.int64) +def_gen.integers(I_i8_high_closed, dtype=np.int64, endpoint=True) +def_gen.integers(I_i8_low, I_i8_high_closed, dtype=np.int64, endpoint=True) +def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype=np.int64, endpoint=True) + + +def_gen.bit_generator + +def_gen.bytes(2) + +def_gen.choice(5) +def_gen.choice(5, 3) +def_gen.choice(5, 3, replace=True) +def_gen.choice(5, 3, p=[1 / 5] * 5) +def_gen.choice(5, 3, p=[1 / 5] * 5, replace=False) + +def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"]) +def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3) +def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, p=[1 / 4] * 4) +def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=True) +def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=False, p=np.array([1 / 8, 1 / 8, 1 / 2, 1 / 4])) + +def_gen.dirichlet([0.5, 0.5]) +def_gen.dirichlet(np.array([0.5, 0.5])) +def_gen.dirichlet(np.array([0.5, 0.5]), size=3) + +def_gen.multinomial(20, [1 / 6.0] * 6) +def_gen.multinomial(20, np.array([0.5, 0.5])) +def_gen.multinomial(20, [1 / 6.0] * 6, size=2) +def_gen.multinomial([[10], [20]], [1 / 6.0] * 6, size=(2, 2)) +def_gen.multinomial(np.array([[10], [20]]), np.array([0.5, 0.5]), size=(2, 2)) + +def_gen.multivariate_hypergeometric([3, 5, 7], 2) +def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2) +def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, size=4) 
+def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, size=(4, 7)) +def_gen.multivariate_hypergeometric([3, 5, 7], 2, method="count") +def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, method="marginals") + +def_gen.multivariate_normal([0.0], [[1.0]]) +def_gen.multivariate_normal([0.0], np.array([[1.0]])) +def_gen.multivariate_normal(np.array([0.0]), [[1.0]]) +def_gen.multivariate_normal([0.0], np.array([[1.0]])) + +def_gen.permutation(10) +def_gen.permutation([1, 2, 3, 4]) +def_gen.permutation(np.array([1, 2, 3, 4])) +def_gen.permutation(D_2D, axis=1) +def_gen.permuted(D_2D) +def_gen.permuted(D_2D_like) +def_gen.permuted(D_2D, axis=1) +def_gen.permuted(D_2D, out=D_2D) +def_gen.permuted(D_2D_like, out=D_2D) +def_gen.permuted(D_2D_like, out=D_2D) +def_gen.permuted(D_2D, axis=1, out=D_2D) + +def_gen.shuffle(np.arange(10)) +def_gen.shuffle([1, 2, 3, 4, 5]) +def_gen.shuffle(D_2D, axis=1) + +def_gen.__str__() +def_gen.__repr__() +def_gen.__setstate__(dict(def_gen.bit_generator.state)) + +# RandomState +random_st: np.random.RandomState = np.random.RandomState() + +random_st.standard_normal() +random_st.standard_normal(size=None) +random_st.standard_normal(size=1) + +random_st.random() +random_st.random(size=None) +random_st.random(size=1) + +random_st.standard_cauchy() +random_st.standard_cauchy(size=None) +random_st.standard_cauchy(size=1) + +random_st.standard_exponential() +random_st.standard_exponential(size=None) +random_st.standard_exponential(size=1) + +random_st.zipf(1.5) +random_st.zipf(1.5, size=None) +random_st.zipf(1.5, size=1) +random_st.zipf(D_arr_1p5) +random_st.zipf(D_arr_1p5, size=1) +random_st.zipf(D_arr_like_1p5) +random_st.zipf(D_arr_like_1p5, size=1) + +random_st.weibull(0.5) +random_st.weibull(0.5, size=None) +random_st.weibull(0.5, size=1) +random_st.weibull(D_arr_0p5) +random_st.weibull(D_arr_0p5, size=1) +random_st.weibull(D_arr_like_0p5) +random_st.weibull(D_arr_like_0p5, size=1) + +random_st.standard_t(0.5) +random_st.standard_t(0.5, size=None) +random_st.standard_t(0.5, size=1) +random_st.standard_t(D_arr_0p5) +random_st.standard_t(D_arr_0p5, size=1) +random_st.standard_t(D_arr_like_0p5) +random_st.standard_t(D_arr_like_0p5, size=1) + +random_st.poisson(0.5) +random_st.poisson(0.5, size=None) +random_st.poisson(0.5, size=1) +random_st.poisson(D_arr_0p5) +random_st.poisson(D_arr_0p5, size=1) +random_st.poisson(D_arr_like_0p5) +random_st.poisson(D_arr_like_0p5, size=1) + +random_st.power(0.5) +random_st.power(0.5, size=None) +random_st.power(0.5, size=1) +random_st.power(D_arr_0p5) +random_st.power(D_arr_0p5, size=1) +random_st.power(D_arr_like_0p5) +random_st.power(D_arr_like_0p5, size=1) + +random_st.pareto(0.5) +random_st.pareto(0.5, size=None) +random_st.pareto(0.5, size=1) +random_st.pareto(D_arr_0p5) +random_st.pareto(D_arr_0p5, size=1) +random_st.pareto(D_arr_like_0p5) +random_st.pareto(D_arr_like_0p5, size=1) + +random_st.chisquare(0.5) +random_st.chisquare(0.5, size=None) +random_st.chisquare(0.5, size=1) +random_st.chisquare(D_arr_0p5) +random_st.chisquare(D_arr_0p5, size=1) +random_st.chisquare(D_arr_like_0p5) +random_st.chisquare(D_arr_like_0p5, size=1) + +random_st.exponential(0.5) +random_st.exponential(0.5, size=None) +random_st.exponential(0.5, size=1) +random_st.exponential(D_arr_0p5) +random_st.exponential(D_arr_0p5, size=1) +random_st.exponential(D_arr_like_0p5) +random_st.exponential(D_arr_like_0p5, size=1) + +random_st.geometric(0.5) +random_st.geometric(0.5, size=None) +random_st.geometric(0.5, size=1) 
+random_st.geometric(D_arr_0p5) +random_st.geometric(D_arr_0p5, size=1) +random_st.geometric(D_arr_like_0p5) +random_st.geometric(D_arr_like_0p5, size=1) + +random_st.logseries(0.5) +random_st.logseries(0.5, size=None) +random_st.logseries(0.5, size=1) +random_st.logseries(D_arr_0p5) +random_st.logseries(D_arr_0p5, size=1) +random_st.logseries(D_arr_like_0p5) +random_st.logseries(D_arr_like_0p5, size=1) + +random_st.rayleigh(0.5) +random_st.rayleigh(0.5, size=None) +random_st.rayleigh(0.5, size=1) +random_st.rayleigh(D_arr_0p5) +random_st.rayleigh(D_arr_0p5, size=1) +random_st.rayleigh(D_arr_like_0p5) +random_st.rayleigh(D_arr_like_0p5, size=1) + +random_st.standard_gamma(0.5) +random_st.standard_gamma(0.5, size=None) +random_st.standard_gamma(0.5, size=1) +random_st.standard_gamma(D_arr_0p5) +random_st.standard_gamma(D_arr_0p5, size=1) +random_st.standard_gamma(D_arr_like_0p5) +random_st.standard_gamma(D_arr_like_0p5, size=1) +random_st.standard_gamma(D_arr_like_0p5, size=1) + +random_st.vonmises(0.5, 0.5) +random_st.vonmises(0.5, 0.5, size=None) +random_st.vonmises(0.5, 0.5, size=1) +random_st.vonmises(D_arr_0p5, 0.5) +random_st.vonmises(0.5, D_arr_0p5) +random_st.vonmises(D_arr_0p5, 0.5, size=1) +random_st.vonmises(0.5, D_arr_0p5, size=1) +random_st.vonmises(D_arr_like_0p5, 0.5) +random_st.vonmises(0.5, D_arr_like_0p5) +random_st.vonmises(D_arr_0p5, D_arr_0p5) +random_st.vonmises(D_arr_like_0p5, D_arr_like_0p5) +random_st.vonmises(D_arr_0p5, D_arr_0p5, size=1) +random_st.vonmises(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.wald(0.5, 0.5) +random_st.wald(0.5, 0.5, size=None) +random_st.wald(0.5, 0.5, size=1) +random_st.wald(D_arr_0p5, 0.5) +random_st.wald(0.5, D_arr_0p5) +random_st.wald(D_arr_0p5, 0.5, size=1) +random_st.wald(0.5, D_arr_0p5, size=1) +random_st.wald(D_arr_like_0p5, 0.5) +random_st.wald(0.5, D_arr_like_0p5) +random_st.wald(D_arr_0p5, D_arr_0p5) +random_st.wald(D_arr_like_0p5, D_arr_like_0p5) +random_st.wald(D_arr_0p5, D_arr_0p5, size=1) +random_st.wald(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.uniform(0.5, 0.5) +random_st.uniform(0.5, 0.5, size=None) +random_st.uniform(0.5, 0.5, size=1) +random_st.uniform(D_arr_0p5, 0.5) +random_st.uniform(0.5, D_arr_0p5) +random_st.uniform(D_arr_0p5, 0.5, size=1) +random_st.uniform(0.5, D_arr_0p5, size=1) +random_st.uniform(D_arr_like_0p5, 0.5) +random_st.uniform(0.5, D_arr_like_0p5) +random_st.uniform(D_arr_0p5, D_arr_0p5) +random_st.uniform(D_arr_like_0p5, D_arr_like_0p5) +random_st.uniform(D_arr_0p5, D_arr_0p5, size=1) +random_st.uniform(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.beta(0.5, 0.5) +random_st.beta(0.5, 0.5, size=None) +random_st.beta(0.5, 0.5, size=1) +random_st.beta(D_arr_0p5, 0.5) +random_st.beta(0.5, D_arr_0p5) +random_st.beta(D_arr_0p5, 0.5, size=1) +random_st.beta(0.5, D_arr_0p5, size=1) +random_st.beta(D_arr_like_0p5, 0.5) +random_st.beta(0.5, D_arr_like_0p5) +random_st.beta(D_arr_0p5, D_arr_0p5) +random_st.beta(D_arr_like_0p5, D_arr_like_0p5) +random_st.beta(D_arr_0p5, D_arr_0p5, size=1) +random_st.beta(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.f(0.5, 0.5) +random_st.f(0.5, 0.5, size=None) +random_st.f(0.5, 0.5, size=1) +random_st.f(D_arr_0p5, 0.5) +random_st.f(0.5, D_arr_0p5) +random_st.f(D_arr_0p5, 0.5, size=1) +random_st.f(0.5, D_arr_0p5, size=1) +random_st.f(D_arr_like_0p5, 0.5) +random_st.f(0.5, D_arr_like_0p5) +random_st.f(D_arr_0p5, D_arr_0p5) +random_st.f(D_arr_like_0p5, D_arr_like_0p5) +random_st.f(D_arr_0p5, D_arr_0p5, size=1) +random_st.f(D_arr_like_0p5, 
D_arr_like_0p5, size=1) + +random_st.gamma(0.5, 0.5) +random_st.gamma(0.5, 0.5, size=None) +random_st.gamma(0.5, 0.5, size=1) +random_st.gamma(D_arr_0p5, 0.5) +random_st.gamma(0.5, D_arr_0p5) +random_st.gamma(D_arr_0p5, 0.5, size=1) +random_st.gamma(0.5, D_arr_0p5, size=1) +random_st.gamma(D_arr_like_0p5, 0.5) +random_st.gamma(0.5, D_arr_like_0p5) +random_st.gamma(D_arr_0p5, D_arr_0p5) +random_st.gamma(D_arr_like_0p5, D_arr_like_0p5) +random_st.gamma(D_arr_0p5, D_arr_0p5, size=1) +random_st.gamma(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.gumbel(0.5, 0.5) +random_st.gumbel(0.5, 0.5, size=None) +random_st.gumbel(0.5, 0.5, size=1) +random_st.gumbel(D_arr_0p5, 0.5) +random_st.gumbel(0.5, D_arr_0p5) +random_st.gumbel(D_arr_0p5, 0.5, size=1) +random_st.gumbel(0.5, D_arr_0p5, size=1) +random_st.gumbel(D_arr_like_0p5, 0.5) +random_st.gumbel(0.5, D_arr_like_0p5) +random_st.gumbel(D_arr_0p5, D_arr_0p5) +random_st.gumbel(D_arr_like_0p5, D_arr_like_0p5) +random_st.gumbel(D_arr_0p5, D_arr_0p5, size=1) +random_st.gumbel(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.laplace(0.5, 0.5) +random_st.laplace(0.5, 0.5, size=None) +random_st.laplace(0.5, 0.5, size=1) +random_st.laplace(D_arr_0p5, 0.5) +random_st.laplace(0.5, D_arr_0p5) +random_st.laplace(D_arr_0p5, 0.5, size=1) +random_st.laplace(0.5, D_arr_0p5, size=1) +random_st.laplace(D_arr_like_0p5, 0.5) +random_st.laplace(0.5, D_arr_like_0p5) +random_st.laplace(D_arr_0p5, D_arr_0p5) +random_st.laplace(D_arr_like_0p5, D_arr_like_0p5) +random_st.laplace(D_arr_0p5, D_arr_0p5, size=1) +random_st.laplace(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.logistic(0.5, 0.5) +random_st.logistic(0.5, 0.5, size=None) +random_st.logistic(0.5, 0.5, size=1) +random_st.logistic(D_arr_0p5, 0.5) +random_st.logistic(0.5, D_arr_0p5) +random_st.logistic(D_arr_0p5, 0.5, size=1) +random_st.logistic(0.5, D_arr_0p5, size=1) +random_st.logistic(D_arr_like_0p5, 0.5) +random_st.logistic(0.5, D_arr_like_0p5) +random_st.logistic(D_arr_0p5, D_arr_0p5) +random_st.logistic(D_arr_like_0p5, D_arr_like_0p5) +random_st.logistic(D_arr_0p5, D_arr_0p5, size=1) +random_st.logistic(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.lognormal(0.5, 0.5) +random_st.lognormal(0.5, 0.5, size=None) +random_st.lognormal(0.5, 0.5, size=1) +random_st.lognormal(D_arr_0p5, 0.5) +random_st.lognormal(0.5, D_arr_0p5) +random_st.lognormal(D_arr_0p5, 0.5, size=1) +random_st.lognormal(0.5, D_arr_0p5, size=1) +random_st.lognormal(D_arr_like_0p5, 0.5) +random_st.lognormal(0.5, D_arr_like_0p5) +random_st.lognormal(D_arr_0p5, D_arr_0p5) +random_st.lognormal(D_arr_like_0p5, D_arr_like_0p5) +random_st.lognormal(D_arr_0p5, D_arr_0p5, size=1) +random_st.lognormal(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.noncentral_chisquare(0.5, 0.5) +random_st.noncentral_chisquare(0.5, 0.5, size=None) +random_st.noncentral_chisquare(0.5, 0.5, size=1) +random_st.noncentral_chisquare(D_arr_0p5, 0.5) +random_st.noncentral_chisquare(0.5, D_arr_0p5) +random_st.noncentral_chisquare(D_arr_0p5, 0.5, size=1) +random_st.noncentral_chisquare(0.5, D_arr_0p5, size=1) +random_st.noncentral_chisquare(D_arr_like_0p5, 0.5) +random_st.noncentral_chisquare(0.5, D_arr_like_0p5) +random_st.noncentral_chisquare(D_arr_0p5, D_arr_0p5) +random_st.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5) +random_st.noncentral_chisquare(D_arr_0p5, D_arr_0p5, size=1) +random_st.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.normal(0.5, 0.5) +random_st.normal(0.5, 0.5, size=None) +random_st.normal(0.5, 0.5, 
size=1) +random_st.normal(D_arr_0p5, 0.5) +random_st.normal(0.5, D_arr_0p5) +random_st.normal(D_arr_0p5, 0.5, size=1) +random_st.normal(0.5, D_arr_0p5, size=1) +random_st.normal(D_arr_like_0p5, 0.5) +random_st.normal(0.5, D_arr_like_0p5) +random_st.normal(D_arr_0p5, D_arr_0p5) +random_st.normal(D_arr_like_0p5, D_arr_like_0p5) +random_st.normal(D_arr_0p5, D_arr_0p5, size=1) +random_st.normal(D_arr_like_0p5, D_arr_like_0p5, size=1) + +random_st.triangular(0.1, 0.5, 0.9) +random_st.triangular(0.1, 0.5, 0.9, size=None) +random_st.triangular(0.1, 0.5, 0.9, size=1) +random_st.triangular(D_arr_0p1, 0.5, 0.9) +random_st.triangular(0.1, D_arr_0p5, 0.9) +random_st.triangular(D_arr_0p1, 0.5, D_arr_like_0p9, size=1) +random_st.triangular(0.1, D_arr_0p5, 0.9, size=1) +random_st.triangular(D_arr_like_0p1, 0.5, D_arr_0p9) +random_st.triangular(0.5, D_arr_like_0p5, 0.9) +random_st.triangular(D_arr_0p1, D_arr_0p5, 0.9) +random_st.triangular(D_arr_like_0p1, D_arr_like_0p5, 0.9) +random_st.triangular(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1) +random_st.triangular(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1) + +random_st.noncentral_f(0.1, 0.5, 0.9) +random_st.noncentral_f(0.1, 0.5, 0.9, size=None) +random_st.noncentral_f(0.1, 0.5, 0.9, size=1) +random_st.noncentral_f(D_arr_0p1, 0.5, 0.9) +random_st.noncentral_f(0.1, D_arr_0p5, 0.9) +random_st.noncentral_f(D_arr_0p1, 0.5, D_arr_like_0p9, size=1) +random_st.noncentral_f(0.1, D_arr_0p5, 0.9, size=1) +random_st.noncentral_f(D_arr_like_0p1, 0.5, D_arr_0p9) +random_st.noncentral_f(0.5, D_arr_like_0p5, 0.9) +random_st.noncentral_f(D_arr_0p1, D_arr_0p5, 0.9) +random_st.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, 0.9) +random_st.noncentral_f(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1) +random_st.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1) + +random_st.binomial(10, 0.5) +random_st.binomial(10, 0.5, size=None) +random_st.binomial(10, 0.5, size=1) +random_st.binomial(I_arr_10, 0.5) +random_st.binomial(10, D_arr_0p5) +random_st.binomial(I_arr_10, 0.5, size=1) +random_st.binomial(10, D_arr_0p5, size=1) +random_st.binomial(I_arr_like_10, 0.5) +random_st.binomial(10, D_arr_like_0p5) +random_st.binomial(I_arr_10, D_arr_0p5) +random_st.binomial(I_arr_like_10, D_arr_like_0p5) +random_st.binomial(I_arr_10, D_arr_0p5, size=1) +random_st.binomial(I_arr_like_10, D_arr_like_0p5, size=1) + +random_st.negative_binomial(10, 0.5) +random_st.negative_binomial(10, 0.5, size=None) +random_st.negative_binomial(10, 0.5, size=1) +random_st.negative_binomial(I_arr_10, 0.5) +random_st.negative_binomial(10, D_arr_0p5) +random_st.negative_binomial(I_arr_10, 0.5, size=1) +random_st.negative_binomial(10, D_arr_0p5, size=1) +random_st.negative_binomial(I_arr_like_10, 0.5) +random_st.negative_binomial(10, D_arr_like_0p5) +random_st.negative_binomial(I_arr_10, D_arr_0p5) +random_st.negative_binomial(I_arr_like_10, D_arr_like_0p5) +random_st.negative_binomial(I_arr_10, D_arr_0p5, size=1) +random_st.negative_binomial(I_arr_like_10, D_arr_like_0p5, size=1) + +random_st.hypergeometric(20, 20, 10) +random_st.hypergeometric(20, 20, 10, size=None) +random_st.hypergeometric(20, 20, 10, size=1) +random_st.hypergeometric(I_arr_20, 20, 10) +random_st.hypergeometric(20, I_arr_20, 10) +random_st.hypergeometric(I_arr_20, 20, I_arr_like_10, size=1) +random_st.hypergeometric(20, I_arr_20, 10, size=1) +random_st.hypergeometric(I_arr_like_20, 20, I_arr_10) +random_st.hypergeometric(20, I_arr_like_20, 10) +random_st.hypergeometric(I_arr_20, I_arr_20, 10) 
+random_st.hypergeometric(I_arr_like_20, I_arr_like_20, 10) +random_st.hypergeometric(I_arr_20, I_arr_20, I_arr_10, size=1) +random_st.hypergeometric(I_arr_like_20, I_arr_like_20, I_arr_like_10, size=1) + +random_st.randint(0, 100) +random_st.randint(100) +random_st.randint([100]) +random_st.randint(0, [100]) + +random_st.randint(2, dtype=bool) +random_st.randint(0, 2, dtype=bool) +random_st.randint(I_bool_high_open, dtype=bool) +random_st.randint(I_bool_low, I_bool_high_open, dtype=bool) +random_st.randint(0, I_bool_high_open, dtype=bool) + +random_st.randint(2, dtype=np.bool) +random_st.randint(0, 2, dtype=np.bool) +random_st.randint(I_bool_high_open, dtype=np.bool) +random_st.randint(I_bool_low, I_bool_high_open, dtype=np.bool) +random_st.randint(0, I_bool_high_open, dtype=np.bool) + +random_st.randint(256, dtype="u1") +random_st.randint(0, 256, dtype="u1") +random_st.randint(I_u1_high_open, dtype="u1") +random_st.randint(I_u1_low, I_u1_high_open, dtype="u1") +random_st.randint(0, I_u1_high_open, dtype="u1") + +random_st.randint(256, dtype="uint8") +random_st.randint(0, 256, dtype="uint8") +random_st.randint(I_u1_high_open, dtype="uint8") +random_st.randint(I_u1_low, I_u1_high_open, dtype="uint8") +random_st.randint(0, I_u1_high_open, dtype="uint8") + +random_st.randint(256, dtype=np.uint8) +random_st.randint(0, 256, dtype=np.uint8) +random_st.randint(I_u1_high_open, dtype=np.uint8) +random_st.randint(I_u1_low, I_u1_high_open, dtype=np.uint8) +random_st.randint(0, I_u1_high_open, dtype=np.uint8) + +random_st.randint(65536, dtype="u2") +random_st.randint(0, 65536, dtype="u2") +random_st.randint(I_u2_high_open, dtype="u2") +random_st.randint(I_u2_low, I_u2_high_open, dtype="u2") +random_st.randint(0, I_u2_high_open, dtype="u2") + +random_st.randint(65536, dtype="uint16") +random_st.randint(0, 65536, dtype="uint16") +random_st.randint(I_u2_high_open, dtype="uint16") +random_st.randint(I_u2_low, I_u2_high_open, dtype="uint16") +random_st.randint(0, I_u2_high_open, dtype="uint16") + +random_st.randint(65536, dtype=np.uint16) +random_st.randint(0, 65536, dtype=np.uint16) +random_st.randint(I_u2_high_open, dtype=np.uint16) +random_st.randint(I_u2_low, I_u2_high_open, dtype=np.uint16) +random_st.randint(0, I_u2_high_open, dtype=np.uint16) + +random_st.randint(4294967296, dtype="u4") +random_st.randint(0, 4294967296, dtype="u4") +random_st.randint(I_u4_high_open, dtype="u4") +random_st.randint(I_u4_low, I_u4_high_open, dtype="u4") +random_st.randint(0, I_u4_high_open, dtype="u4") + +random_st.randint(4294967296, dtype="uint32") +random_st.randint(0, 4294967296, dtype="uint32") +random_st.randint(I_u4_high_open, dtype="uint32") +random_st.randint(I_u4_low, I_u4_high_open, dtype="uint32") +random_st.randint(0, I_u4_high_open, dtype="uint32") + +random_st.randint(4294967296, dtype=np.uint32) +random_st.randint(0, 4294967296, dtype=np.uint32) +random_st.randint(I_u4_high_open, dtype=np.uint32) +random_st.randint(I_u4_low, I_u4_high_open, dtype=np.uint32) +random_st.randint(0, I_u4_high_open, dtype=np.uint32) + + +random_st.randint(18446744073709551616, dtype="u8") +random_st.randint(0, 18446744073709551616, dtype="u8") +random_st.randint(I_u8_high_open, dtype="u8") +random_st.randint(I_u8_low, I_u8_high_open, dtype="u8") +random_st.randint(0, I_u8_high_open, dtype="u8") + +random_st.randint(18446744073709551616, dtype="uint64") +random_st.randint(0, 18446744073709551616, dtype="uint64") +random_st.randint(I_u8_high_open, dtype="uint64") +random_st.randint(I_u8_low, I_u8_high_open, dtype="uint64") 
+random_st.randint(0, I_u8_high_open, dtype="uint64")
+
+random_st.randint(18446744073709551616, dtype=np.uint64)
+random_st.randint(0, 18446744073709551616, dtype=np.uint64)
+random_st.randint(I_u8_high_open, dtype=np.uint64)
+random_st.randint(I_u8_low, I_u8_high_open, dtype=np.uint64)
+random_st.randint(0, I_u8_high_open, dtype=np.uint64)
+
+random_st.randint(128, dtype="i1")
+random_st.randint(-128, 128, dtype="i1")
+random_st.randint(I_i1_high_open, dtype="i1")
+random_st.randint(I_i1_low, I_i1_high_open, dtype="i1")
+random_st.randint(-128, I_i1_high_open, dtype="i1")
+
+random_st.randint(128, dtype="int8")
+random_st.randint(-128, 128, dtype="int8")
+random_st.randint(I_i1_high_open, dtype="int8")
+random_st.randint(I_i1_low, I_i1_high_open, dtype="int8")
+random_st.randint(-128, I_i1_high_open, dtype="int8")
+
+random_st.randint(128, dtype=np.int8)
+random_st.randint(-128, 128, dtype=np.int8)
+random_st.randint(I_i1_high_open, dtype=np.int8)
+random_st.randint(I_i1_low, I_i1_high_open, dtype=np.int8)
+random_st.randint(-128, I_i1_high_open, dtype=np.int8)
+
+random_st.randint(32768, dtype="i2")
+random_st.randint(-32768, 32768, dtype="i2")
+random_st.randint(I_i2_high_open, dtype="i2")
+random_st.randint(I_i2_low, I_i2_high_open, dtype="i2")
+random_st.randint(-32768, I_i2_high_open, dtype="i2")
+
+random_st.randint(32768, dtype="int16")
+random_st.randint(-32768, 32768, dtype="int16")
+random_st.randint(I_i2_high_open, dtype="int16")
+random_st.randint(I_i2_low, I_i2_high_open, dtype="int16")
+random_st.randint(-32768, I_i2_high_open, dtype="int16")
+
+random_st.randint(32768, dtype=np.int16)
+random_st.randint(-32768, 32768, dtype=np.int16)
+random_st.randint(I_i2_high_open, dtype=np.int16)
+random_st.randint(I_i2_low, I_i2_high_open, dtype=np.int16)
+random_st.randint(-32768, I_i2_high_open, dtype=np.int16)
+
+random_st.randint(2147483648, dtype="i4")
+random_st.randint(-2147483648, 2147483648, dtype="i4")
+random_st.randint(I_i4_high_open, dtype="i4")
+random_st.randint(I_i4_low, I_i4_high_open, dtype="i4")
+random_st.randint(-2147483648, I_i4_high_open, dtype="i4")
+
+random_st.randint(2147483648, dtype="int32")
+random_st.randint(-2147483648, 2147483648, dtype="int32")
+random_st.randint(I_i4_high_open, dtype="int32")
+random_st.randint(I_i4_low, I_i4_high_open, dtype="int32")
+random_st.randint(-2147483648, I_i4_high_open, dtype="int32")
+
+random_st.randint(2147483648, dtype=np.int32)
+random_st.randint(-2147483648, 2147483648, dtype=np.int32)
+random_st.randint(I_i4_high_open, dtype=np.int32)
+random_st.randint(I_i4_low, I_i4_high_open, dtype=np.int32)
+random_st.randint(-2147483648, I_i4_high_open, dtype=np.int32)
+
+random_st.randint(9223372036854775808, dtype="i8")
+random_st.randint(-9223372036854775808, 9223372036854775808, dtype="i8")
+random_st.randint(I_i8_high_open, dtype="i8")
+random_st.randint(I_i8_low, I_i8_high_open, dtype="i8")
+random_st.randint(-9223372036854775808, I_i8_high_open, dtype="i8")
+
+random_st.randint(9223372036854775808, dtype="int64")
+random_st.randint(-9223372036854775808, 9223372036854775808, dtype="int64")
+random_st.randint(I_i8_high_open, dtype="int64")
+random_st.randint(I_i8_low, I_i8_high_open, dtype="int64")
+random_st.randint(-9223372036854775808, I_i8_high_open, dtype="int64")
+
+random_st.randint(9223372036854775808, dtype=np.int64)
+random_st.randint(-9223372036854775808, 9223372036854775808, dtype=np.int64)
+random_st.randint(I_i8_high_open, dtype=np.int64)
+random_st.randint(I_i8_low, I_i8_high_open, dtype=np.int64)
+random_st.randint(-9223372036854775808, I_i8_high_open, dtype=np.int64) + +bg: np.random.BitGenerator = random_st._bit_generator + +random_st.bytes(2) + +random_st.choice(5) +random_st.choice(5, 3) +random_st.choice(5, 3, replace=True) +random_st.choice(5, 3, p=[1 / 5] * 5) +random_st.choice(5, 3, p=[1 / 5] * 5, replace=False) + +random_st.choice(["pooh", "rabbit", "piglet", "Christopher"]) +random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3) +random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, p=[1 / 4] * 4) +random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=True) +random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=False, p=np.array([1 / 8, 1 / 8, 1 / 2, 1 / 4])) + +random_st.dirichlet([0.5, 0.5]) +random_st.dirichlet(np.array([0.5, 0.5])) +random_st.dirichlet(np.array([0.5, 0.5]), size=3) + +random_st.multinomial(20, [1 / 6.0] * 6) +random_st.multinomial(20, np.array([0.5, 0.5])) +random_st.multinomial(20, [1 / 6.0] * 6, size=2) + +random_st.multivariate_normal([0.0], [[1.0]]) +random_st.multivariate_normal([0.0], np.array([[1.0]])) +random_st.multivariate_normal(np.array([0.0]), [[1.0]]) +random_st.multivariate_normal([0.0], np.array([[1.0]])) + +random_st.permutation(10) +random_st.permutation([1, 2, 3, 4]) +random_st.permutation(np.array([1, 2, 3, 4])) +random_st.permutation(D_2D) + +random_st.shuffle(np.arange(10)) +random_st.shuffle([1, 2, 3, 4, 5]) +random_st.shuffle(D_2D) + +np.random.RandomState(SEED_PCG64) +np.random.RandomState(0) +np.random.RandomState([0, 1, 2]) +random_st.__str__() +random_st.__repr__() +random_st_state = random_st.__getstate__() +random_st.__setstate__(random_st_state) +random_st.seed() +random_st.seed(1) +random_st.seed([0, 1]) +random_st_get_state = random_st.get_state() +random_st_get_state_legacy = random_st.get_state(legacy=True) +random_st.set_state(random_st_get_state) + +random_st.rand() +random_st.rand(1) +random_st.rand(1, 2) +random_st.randn() +random_st.randn(1) +random_st.randn(1, 2) +random_st.random_sample() +random_st.random_sample(1) +random_st.random_sample(size=(1, 2)) + +random_st.tomaxint() +random_st.tomaxint(1) +random_st.tomaxint((1,)) + +np.random.mtrand.set_bit_generator(SEED_PCG64) +np.random.mtrand.get_bit_generator() diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/recfunctions.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/recfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..52a3d78a76220e8549e322211c2b6c46663ca048 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/recfunctions.py @@ -0,0 +1,161 @@ +"""These tests are based on the doctests from `numpy/lib/recfunctions.py`.""" + +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy.lib import recfunctions as rfn + + +def test_recursive_fill_fields() -> None: + a: npt.NDArray[np.void] = np.array( + [(1, 10.0), (2, 20.0)], + dtype=[("A", np.int64), ("B", np.float64)], + ) + b = np.zeros((int(3),), dtype=a.dtype) + out = rfn.recursive_fill_fields(a, b) + assert_type(out, np.ndarray[tuple[int], np.dtype[np.void]]) + + +def test_get_names() -> None: + names: tuple[str | Any, ...] 
+    names = rfn.get_names(np.empty((1,), dtype=[("A", int)]).dtype)
+    names = rfn.get_names(np.empty((1,), dtype=[("A", int), ("B", float)]).dtype)
+
+    adtype = np.dtype([("a", int), ("b", [("b_a", int), ("b_b", int)])])
+    names = rfn.get_names(adtype)
+
+
+def test_get_names_flat() -> None:
+    names: tuple[str, ...]
+    names = rfn.get_names_flat(np.empty((1,), dtype=[("A", int)]).dtype)
+    names = rfn.get_names_flat(np.empty((1,), dtype=[("A", int), ("B", float)]).dtype)
+
+    adtype = np.dtype([("a", int), ("b", [("b_a", int), ("b_b", int)])])
+    names = rfn.get_names_flat(adtype)
+
+
+def test_flatten_descr() -> None:
+    ndtype = np.dtype([("a", "<i4"), ("b", [("ba", "<f8"), ("bb", "<i4")])])
+    assert_type(rfn.flatten_descr(ndtype), tuple[tuple[str, np.dtype[Any]], ...])
+
+
+def test_get_fieldstructure() -> None:
+    ndtype = np.dtype([
+        ("A", int),
+        ("B", [("B_A", int), ("B_B", [("B_B_A", int), ("B_B_B", int)])]),
+    ])
+    assert_type(rfn.get_fieldstructure(ndtype), dict[str, list[str]])
+
+
+def test_merge_arrays() -> None:
+    assert_type(
+        rfn.merge_arrays((
+            np.ones((int(2),), np.int_),
+            np.ones((int(3),), np.float64),
+        )),
+        np.recarray[tuple[int], np.dtype[np.void]],
+    )
+
+
+def test_drop_fields() -> None:
+    ndtype = [("a", np.int64), ("b", [("b_a", np.double), ("b_b", np.int64)])]
+    a = np.ones((int(3),), dtype=ndtype)
+
+    assert_type(
+        rfn.drop_fields(a, "a"),
+        np.ndarray[tuple[int], np.dtype[np.void]],
+    )
+    assert_type(
+        rfn.drop_fields(a, "a", asrecarray=True),
+        np.rec.recarray[tuple[int], np.dtype[np.void]],
+    )
+    assert_type(
+        rfn.rec_drop_fields(a, "a"),
+        np.rec.recarray[tuple[int], np.dtype[np.void]],
+    )
+
+
+def test_rename_fields() -> None:
+    ndtype = [("a", np.int64), ("b", [("b_a", np.double), ("b_b", np.int64)])]
+    a = np.ones((int(3),), dtype=ndtype)
+
+    assert_type(
+        rfn.rename_fields(a, {"a": "A", "b_b": "B_B"}),
+        np.ndarray[tuple[int], np.dtype[np.void]],
+    )
+
+
+def test_repack_fields() -> None:
+    dt: np.dtype[np.void] = np.dtype("u1, <i8, <f8", align=True)
+    assert_type(rfn.repack_fields(dt), np.dtype[np.void])
+
+
+def test_structured_to_unstructured() -> None:
+    a = np.zeros(4, dtype=[("a", "i4"), ("b", "f4,u2"), ("c", "f4", 2)])
+    assert_type(rfn.structured_to_unstructured(a), npt.NDArray[Any])
+
+
+def test_unstructured_to_structured() -> None:
+    dt: np.dtype[np.void] = np.dtype([("a", "i4"), ("b", "f4,u2"), ("c", "f4", 2)])
+    a = np.arange(20, dtype=np.int32).reshape((4, 5))
+    assert_type(rfn.unstructured_to_structured(a, dt), npt.NDArray[np.void])
+
+
+def test_apply_along_fields() -> None:
+    b = np.ones(4, dtype=[("x", "i4"), ("y", "f4"), ("z", "f8")])
+    assert_type(
+        rfn.apply_along_fields(np.mean, b),
+        np.ndarray[tuple[int], np.dtype[np.void]],
+    )
+
+
+def test_assign_fields_by_name() -> None:
+    b = np.ones(4, dtype=[("x", "i4"), ("y", "f4"), ("z", "f8")])
+    b_new = np.zeros_like(b)
+    assert_type(rfn.assign_fields_by_name(b_new, b), None)
+
+
+def test_require_fields() -> None:
+    a = np.ones(4, dtype=[("a", "i4"), ("b", "f8"), ("c", "u1")])
+    assert_type(
+        rfn.require_fields(a, [("b", "f4"), ("c", "u1")]),
+        np.ndarray[tuple[int], np.dtype[np.void]],
+    )
+
+
+def test_stack_arrays() -> None:
+    x = np.zeros((int(2),), np.int32)
+    assert_type(
+        rfn.stack_arrays(x),
+        np.ndarray[tuple[int], np.dtype[np.int32]],
+    )
+
+    z = np.ones((int(2),), [("A", "|S3"), ("B", float)])
+    zz = np.ones((int(2),), [("A", "|S3"), ("B", np.float64), ("C", np.float64)])
+    assert_type(
+        rfn.stack_arrays((z, zz)),
+        np.ma.MaskedArray[tuple[Any, ...], np.dtype[np.void]],
+    )
+
+
+def test_find_duplicates() -> None:
+    ndtype = np.dtype([("a", int)])
+
+    a = np.ma.ones(7, mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
+    assert_type(rfn.find_duplicates(a), np.ma.MaskedArray[Any, np.dtype[np.void]])
+    assert_type(
+        rfn.find_duplicates(a,
ignoremask=True, return_index=True), + tuple[ + np.ma.MaskedArray[Any, np.dtype[np.void]], + np.ndarray[Any, np.dtype[np.int_]], + ], + ) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/scalars.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/scalars.py new file mode 100644 index 0000000000000000000000000000000000000000..655903a50bcec0ee683e92ca0dfa126b1080481e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/scalars.py @@ -0,0 +1,248 @@ +import datetime as dt + +import pytest +import numpy as np + +b = np.bool() +b_ = np.bool_() +u8 = np.uint64() +i8 = np.int64() +f8 = np.float64() +c16 = np.complex128() +U = np.str_() +S = np.bytes_() + + +# Construction +class D: + def __index__(self) -> int: + return 0 + + +class C: + def __complex__(self) -> complex: + return 3j + + +class B: + def __int__(self) -> int: + return 4 + + +class A: + def __float__(self) -> float: + return 4.0 + + +np.complex64(3j) +np.complex64(A()) +np.complex64(C()) +np.complex128(3j) +np.complex128(C()) +np.complex128(None) +np.complex64("1.2") +np.complex128(b"2j") + +np.int8(4) +np.int16(3.4) +np.int32(4) +np.int64(-1) +np.uint8(B()) +np.uint32() +np.int32("1") +np.int64(b"2") + +np.float16(A()) +np.float32(16) +np.float64(3.0) +np.float64(None) +np.float32("1") +np.float16(b"2.5") + +np.uint64(D()) +np.float32(D()) +np.complex64(D()) + +np.bytes_(b"hello") +np.bytes_("hello", 'utf-8') +np.bytes_("hello", encoding='utf-8') +np.str_("hello") +np.str_(b"hello", 'utf-8') +np.str_(b"hello", encoding='utf-8') + +# Array-ish semantics +np.int8().real +np.int16().imag +np.int32().data +np.int64().flags + +np.uint8().itemsize * 2 +np.uint16().ndim + 1 +np.uint32().strides +np.uint64().shape + +# Time structures +np.datetime64() +np.datetime64(0, "D") +np.datetime64(0, b"D") +np.datetime64(0, ('ms', 3)) +np.datetime64("2019") +np.datetime64(b"2019") +np.datetime64("2019", "D") +np.datetime64("2019", "us") +np.datetime64("2019", "as") +np.datetime64(np.datetime64()) +np.datetime64(np.datetime64()) +np.datetime64(dt.datetime(2000, 5, 3)) +np.datetime64(dt.datetime(2000, 5, 3), "D") +np.datetime64(dt.datetime(2000, 5, 3), "us") +np.datetime64(dt.datetime(2000, 5, 3), "as") +np.datetime64(dt.date(2000, 5, 3)) +np.datetime64(dt.date(2000, 5, 3), "D") +np.datetime64(dt.date(2000, 5, 3), "us") +np.datetime64(dt.date(2000, 5, 3), "as") +np.datetime64(None) +np.datetime64(None, "D") + +np.timedelta64() +np.timedelta64(0) +np.timedelta64(0, "D") +np.timedelta64(0, ('ms', 3)) +np.timedelta64(0, b"D") +np.timedelta64("3") +np.timedelta64(b"5") +np.timedelta64(np.timedelta64(2)) +np.timedelta64(dt.timedelta(2)) +np.timedelta64(None) +np.timedelta64(None, "D") + +np.void(1) +np.void(np.int64(1)) +np.void(True) +np.void(np.bool(True)) +np.void(b"test") +np.void(np.bytes_("test")) +np.void(object(), [("a", "O"), ("b", "O")]) +np.void(object(), dtype=[("a", "O"), ("b", "O")]) + +# Protocols +i8 = np.int64() +u8 = np.uint64() +f8 = np.float64() +c16 = np.complex128() +b = np.bool() +td = np.timedelta64() +U = np.str_("1") +S = np.bytes_("1") +AR = np.array(1, dtype=np.float64) + +int(i8) +int(u8) +int(f8) +int(b) +int(td) +int(U) +int(S) +int(AR) +with pytest.warns(np.exceptions.ComplexWarning): + int(c16) + +float(i8) +float(u8) +float(f8) +float(b_) +float(td) +float(U) +float(S) +float(AR) +with pytest.warns(np.exceptions.ComplexWarning): + float(c16) + +complex(i8) +complex(u8) +complex(f8) +complex(c16) +complex(b_) +complex(td) +complex(U) +complex(AR) + + 
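+# Editorial note (not part of the upstream test file): the conversions above
+# type-check because NumPy scalars implement the standard numeric protocols
+# (__int__, __float__, __complex__); converting a complex scalar discards the
+# imaginary part, which is why int(c16) and float(c16) are wrapped in
+# pytest.warns(np.exceptions.ComplexWarning). For example:
+#
+#     float(np.float32(0.5))  # 0.5, via __float__
+#     int(np.uint64(2))       # 2, via __int__
+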
+# Misc +c16.dtype +c16.real +c16.imag +c16.real.real +c16.real.imag +c16.ndim +c16.size +c16.itemsize +c16.shape +c16.strides +c16.squeeze() +c16.byteswap() +c16.transpose() + +# Aliases +np.byte() +np.short() +np.intc() +np.intp() +np.int_() +np.longlong() + +np.ubyte() +np.ushort() +np.uintc() +np.uintp() +np.uint() +np.ulonglong() + +np.half() +np.single() +np.double() +np.longdouble() + +np.csingle() +np.cdouble() +np.clongdouble() + +b.item() +i8.item() +u8.item() +f8.item() +c16.item() +U.item() +S.item() + +b.tolist() +i8.tolist() +u8.tolist() +f8.tolist() +c16.tolist() +U.tolist() +S.tolist() + +b.ravel() +i8.ravel() +u8.ravel() +f8.ravel() +c16.ravel() +U.ravel() +S.ravel() + +b.flatten() +i8.flatten() +u8.flatten() +f8.flatten() +c16.flatten() +U.flatten() +S.flatten() + +b.reshape(1) +i8.reshape(1) +u8.reshape(1) +f8.reshape(1) +c16.reshape(1) +U.reshape(1) +S.reshape(1) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/shape.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/shape.py new file mode 100644 index 0000000000000000000000000000000000000000..286c8a81dacf3351dc6e128437e39a86b4e146bb --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/shape.py @@ -0,0 +1,19 @@ +from typing import Any, NamedTuple, cast + +import numpy as np + + +# Subtype of tuple[int, int] +class XYGrid(NamedTuple): + x_axis: int + y_axis: int + +# Test variance of _ShapeT_co +def accepts_2d(a: np.ndarray[tuple[int, int], Any]) -> None: + return None + + +accepts_2d(np.empty(XYGrid(2, 2))) +accepts_2d(np.zeros(XYGrid(2, 2), dtype=int)) +accepts_2d(np.ones(XYGrid(2, 2), dtype=int)) +accepts_2d(np.full(XYGrid(2, 2), fill_value=5, dtype=int)) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..8f44e6e76f8353d55a29f6cc0608d39e7db433dc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple.py @@ -0,0 +1,168 @@ +"""Simple expression that should pass with mypy.""" +import operator + +import numpy as np +import numpy.typing as npt +from collections.abc import Iterable + +# Basic checks +array = np.array([1, 2]) + + +def ndarray_func(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]: + return x + + +ndarray_func(np.array([1, 2], dtype=np.float64)) +array == 1 +array.dtype == float + +# Dtype construction +np.dtype(float) +np.dtype(np.float64) +np.dtype(None) +np.dtype("float64") +np.dtype(np.dtype(float)) +np.dtype(("U", 10)) +np.dtype((np.int32, (2, 2))) +# Define the arguments on the previous line to prevent bidirectional +# type inference in mypy from broadening the types. 
+two_tuples_dtype = [("R", "u1"), ("G", "u1"), ("B", "u1")] +np.dtype(two_tuples_dtype) + +three_tuples_dtype = [("R", "u1", 2)] +np.dtype(three_tuples_dtype) + +mixed_tuples_dtype = [("R", "u1"), ("G", np.str_, 1)] +np.dtype(mixed_tuples_dtype) + +shape_tuple_dtype = [("R", "u1", (2, 2))] +np.dtype(shape_tuple_dtype) + +shape_like_dtype = [("R", "u1", (2, 2)), ("G", np.str_, 1)] +np.dtype(shape_like_dtype) + +object_dtype = [("field1", object)] +np.dtype(object_dtype) + +np.dtype((np.int32, (np.int8, 4))) + +# Dtype comparison +np.dtype(float) == float +np.dtype(float) != np.float64 +np.dtype(float) < None +np.dtype(float) <= "float64" +np.dtype(float) > np.dtype(float) +np.dtype(float) >= np.dtype(("U", 10)) + +# Iteration and indexing +def iterable_func(x: Iterable[object]) -> Iterable[object]: + return x + + +iterable_func(array) +list(array) +iter(array) +zip(array, array) +array[1] +array[:] +array[...] +array[:] = 0 + +array_2d = np.ones((3, 3)) +array_2d[:2, :2] +array_2d[:2, :2] = 0 +array_2d[..., 0] +array_2d[..., 0] = 2 +array_2d[-1, -1] = None + +array_obj = np.zeros(1, dtype=np.object_) +array_obj[0] = slice(None) + +# Other special methods +len(array) +str(array) +array_scalar = np.array(1) +int(array_scalar) +float(array_scalar) +complex(array_scalar) +bytes(array_scalar) +operator.index(array_scalar) +bool(array_scalar) + +# comparisons +array < 1 +array <= 1 +array == 1 +array != 1 +array > 1 +array >= 1 +1 < array +1 <= array +1 == array +1 != array +1 > array +1 >= array + +# binary arithmetic +array + 1 +1 + array +array += 1 + +array - 1 +1 - array +array -= 1 + +array * 1 +1 * array +array *= 1 + +nonzero_array = np.array([1, 2]) +array / 1 +1 / nonzero_array +float_array = np.array([1.0, 2.0]) +float_array /= 1 + +array // 1 +1 // nonzero_array +array //= 1 + +array % 1 +1 % nonzero_array +array %= 1 + +divmod(array, 1) +divmod(1, nonzero_array) + +array ** 1 +1 ** array +array **= 1 + +array << 1 +1 << array +array <<= 1 + +array >> 1 +1 >> array +array >>= 1 + +array & 1 +1 & array +array &= 1 + +array ^ 1 +1 ^ array +array ^= 1 + +array | 1 +1 | array +array |= 1 + +# unary arithmetic +-array ++array +abs(array) +~array + +# Other methods +np.array([1, 2]).transpose() diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple_py3.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple_py3.py new file mode 100644 index 0000000000000000000000000000000000000000..c05a1ce612ac5e40d0914732c4c72ad1d3f2d552 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/simple_py3.py @@ -0,0 +1,6 @@ +import numpy as np + +array = np.array([1, 2]) + +# The @ operator is not in python 2 +array @ array diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunc_config.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunc_config.py new file mode 100644 index 0000000000000000000000000000000000000000..778e1b57f7e3831464f0bc9ba37e4baddcdf49d8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunc_config.py @@ -0,0 +1,64 @@ +"""Typing tests for `numpy._core._ufunc_config`.""" + +import numpy as np + + +def func1(a: str, b: int) -> None: + return None + + +def func2(a: str, b: int, c: float = 1.0) -> None: + return None + + +def func3(a: str, b: int) -> int: + return 0 + + +class Write1: + def write(self, a: str) -> None: + return None + + +class Write2: + def write(self, a: str, b: int = 1) -> None: + return None + + +class Write3: + def 
write(self, a: str) -> int:
+        return 0
+
+
+_err_default = np.geterr()
+_bufsize_default = np.getbufsize()
+_errcall_default = np.geterrcall()
+
+try:
+    np.seterr(all=None)
+    np.seterr(divide="ignore")
+    np.seterr(over="warn")
+    np.seterr(under="call")
+    np.seterr(invalid="raise")
+    np.geterr()
+
+    np.setbufsize(4096)
+    np.getbufsize()
+
+    np.seterrcall(func1)
+    np.seterrcall(func2)
+    np.seterrcall(func3)
+    np.seterrcall(Write1())
+    np.seterrcall(Write2())
+    np.seterrcall(Write3())
+    np.geterrcall()
+
+    with np.errstate(call=func1, all="call"):
+        pass
+    with np.errstate(call=Write1(), divide="log", over="log"):
+        pass
+
+finally:
+    np.seterr(**_err_default)
+    np.setbufsize(_bufsize_default)
+    np.seterrcall(_errcall_default)
diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunclike.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunclike.py
new file mode 100644
index 0000000000000000000000000000000000000000..f993939ddba12851c5c93d3eed0f6d0923d0b98d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufunclike.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+from typing import Any
+import numpy as np
+
+
+class Object:
+    def __ceil__(self) -> Object:
+        return self
+
+    def __floor__(self) -> Object:
+        return self
+
+    def __ge__(self, value: object) -> bool:
+        return True
+
+    def __array__(self, dtype: np.typing.DTypeLike | None = None,
+                  copy: bool | None = None) -> np.ndarray[Any, np.dtype[np.object_]]:
+        ret = np.empty((), dtype=object)
+        ret[()] = self
+        return ret
+
+
+AR_LIKE_b = [True, True, False]
+AR_LIKE_u = [np.uint32(1), np.uint32(2), np.uint32(3)]
+AR_LIKE_i = [1, 2, 3]
+AR_LIKE_f = [1.0, 2.0, 3.0]
+AR_LIKE_O = [Object(), Object(), Object()]
+AR_U: np.ndarray[Any, np.dtype[np.str_]] = np.zeros(3, dtype="U5")
+
+np.fix(AR_LIKE_b)
+np.fix(AR_LIKE_u)
+np.fix(AR_LIKE_i)
+np.fix(AR_LIKE_f)
+np.fix(AR_LIKE_O)
+np.fix(AR_LIKE_f, out=AR_U)
+
+np.isposinf(AR_LIKE_b)
+np.isposinf(AR_LIKE_u)
+np.isposinf(AR_LIKE_i)
+np.isposinf(AR_LIKE_f)
+np.isposinf(AR_LIKE_f, out=AR_U)
+
+np.isneginf(AR_LIKE_b)
+np.isneginf(AR_LIKE_u)
+np.isneginf(AR_LIKE_i)
+np.isneginf(AR_LIKE_f)
+np.isneginf(AR_LIKE_f, out=AR_U)
diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufuncs.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufuncs.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc61bb0b17b594865bc909787798238c5b32bf5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/ufuncs.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+np.sin(1)
+np.sin([1, 2, 3])
+np.sin(1, out=np.empty(1))
+np.matmul(np.ones((2, 2, 2)), np.ones((2, 2, 2)), axes=[(0, 1), (0, 1), (0, 1)])
+np.sin(1, signature="D->D")
+# NOTE: `np.generic` subclasses are not guaranteed to support addition;
+# re-enable this once we can infer the exact return type of `np.sin(...)`.
+# +# np.sin(1) + np.sin(1) +np.sin.types[0] +np.sin.__name__ +np.sin.__doc__ + +np.abs(np.array([1])) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/warnings_and_errors.py b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/warnings_and_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..c351afb084c57b2d1264a3ff020ed1e63703a623 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/pass/warnings_and_errors.py @@ -0,0 +1,6 @@ +import numpy.exceptions as ex + +ex.AxisError("test") +ex.AxisError(1, ndim=2) +ex.AxisError(1, ndim=2, msg_prefix="error") +ex.AxisError(1, ndim=2, msg_prefix=None) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arithmetic.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arithmetic.pyi new file mode 100644 index 0000000000000000000000000000000000000000..763ff3914f53163a10af81f8ed5e459b73eeac57 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arithmetic.pyi @@ -0,0 +1,720 @@ +import datetime as dt +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy._typing import _64Bit, _128Bit + +b: bool +c: complex +f: float +i: int + +c16: np.complex128 +c8: np.complex64 + +# Can't directly import `np.float128` as it is not available on all platforms +f16: np.floating[_128Bit] +f8: np.float64 +f4: np.float32 + +i8: np.int64 +i4: np.int32 + +u8: np.uint64 +u4: np.uint32 + +b_: np.bool + +M8: np.datetime64 +M8_none: np.datetime64[None] +M8_date: np.datetime64[dt.date] +M8_time: np.datetime64[dt.datetime] +M8_int: np.datetime64[int] +date: dt.date +time: dt.datetime + +m8: np.timedelta64 +m8_none: np.timedelta64[None] +m8_int: np.timedelta64[int] +m8_delta: np.timedelta64[dt.timedelta] +delta: dt.timedelta + +AR_b: npt.NDArray[np.bool] +AR_u: npt.NDArray[np.uint32] +AR_i: npt.NDArray[np.int64] +AR_f: npt.NDArray[np.float64] +AR_c: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] +AR_O: npt.NDArray[np.object_] +AR_S: npt.NDArray[np.bytes_] +AR_U: npt.NDArray[np.str_] +AR_T: np.ndarray[tuple[Any, ...], np.dtypes.StringDType] +AR_floating: npt.NDArray[np.floating] +AR_number: npt.NDArray[np.number] +AR_Any: npt.NDArray[Any] + +AR_LIKE_b: list[bool] +AR_LIKE_u: list[np.uint32] +AR_LIKE_i: list[int] +AR_LIKE_f: list[float] +AR_LIKE_c: list[complex] +AR_LIKE_m: list[np.timedelta64] +AR_LIKE_M: list[np.datetime64] +AR_LIKE_O: list[np.object_] + + +# Array subtraction + +assert_type(AR_number - AR_number, npt.NDArray[np.number]) + +assert_type(AR_b - AR_LIKE_u, npt.NDArray[np.uint32]) +assert_type(AR_b - AR_LIKE_i, npt.NDArray[np.signedinteger]) +assert_type(AR_b - AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_b - AR_LIKE_c, npt.NDArray[np.complexfloating]) +assert_type(AR_b - AR_LIKE_m, npt.NDArray[np.timedelta64]) +assert_type(AR_b - AR_LIKE_O, Any) + +assert_type(AR_LIKE_u - AR_b, npt.NDArray[np.uint32]) +assert_type(AR_LIKE_i - AR_b, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f - AR_b, npt.NDArray[np.floating]) +assert_type(AR_LIKE_c - AR_b, npt.NDArray[np.complexfloating]) +assert_type(AR_LIKE_m - AR_b, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_M - AR_b, npt.NDArray[np.datetime64]) +assert_type(AR_LIKE_O - AR_b, Any) + +assert_type(AR_u - AR_LIKE_b, npt.NDArray[np.uint32]) +assert_type(AR_u - AR_LIKE_u, npt.NDArray[np.unsignedinteger]) +assert_type(AR_u - AR_LIKE_i, 
npt.NDArray[np.signedinteger]) +assert_type(AR_u - AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_u - AR_LIKE_c, npt.NDArray[np.complexfloating]) +assert_type(AR_u - AR_LIKE_m, npt.NDArray[np.timedelta64]) +assert_type(AR_u - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_u, npt.NDArray[np.uint32]) +assert_type(AR_LIKE_u - AR_u, npt.NDArray[np.unsignedinteger]) +assert_type(AR_LIKE_i - AR_u, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f - AR_u, npt.NDArray[np.floating]) +assert_type(AR_LIKE_c - AR_u, npt.NDArray[np.complexfloating]) +assert_type(AR_LIKE_m - AR_u, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_M - AR_u, npt.NDArray[np.datetime64]) +assert_type(AR_LIKE_O - AR_u, Any) + +assert_type(AR_i - AR_LIKE_b, npt.NDArray[np.int64]) +assert_type(AR_i - AR_LIKE_u, npt.NDArray[np.signedinteger]) +assert_type(AR_i - AR_LIKE_i, npt.NDArray[np.signedinteger]) +assert_type(AR_i - AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_i - AR_LIKE_c, npt.NDArray[np.complexfloating]) +assert_type(AR_i - AR_LIKE_m, npt.NDArray[np.timedelta64]) +assert_type(AR_i - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_i, npt.NDArray[np.int64]) +assert_type(AR_LIKE_u - AR_i, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_i - AR_i, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f - AR_i, npt.NDArray[np.floating]) +assert_type(AR_LIKE_c - AR_i, npt.NDArray[np.complexfloating]) +assert_type(AR_LIKE_m - AR_i, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_M - AR_i, npt.NDArray[np.datetime64]) +assert_type(AR_LIKE_O - AR_i, Any) + +assert_type(AR_f - AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_f - AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_f - AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_f - AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_f - AR_LIKE_c, npt.NDArray[np.complexfloating]) +assert_type(AR_f - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u - AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i - AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f - AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_c - AR_f, npt.NDArray[np.complexfloating]) +assert_type(AR_LIKE_O - AR_f, Any) + +assert_type(AR_c - AR_LIKE_b, npt.NDArray[np.complex128]) +assert_type(AR_c - AR_LIKE_u, npt.NDArray[np.complex128]) +assert_type(AR_c - AR_LIKE_i, npt.NDArray[np.complex128]) +assert_type(AR_c - AR_LIKE_f, npt.NDArray[np.complex128]) +assert_type(AR_c - AR_LIKE_c, npt.NDArray[np.complex128]) +assert_type(AR_c - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_c, npt.NDArray[np.complex128]) +assert_type(AR_LIKE_u - AR_c, npt.NDArray[np.complex128]) +assert_type(AR_LIKE_i - AR_c, npt.NDArray[np.complex128]) +assert_type(AR_LIKE_f - AR_c, npt.NDArray[np.complex128]) +assert_type(AR_LIKE_c - AR_c, npt.NDArray[np.complex128]) +assert_type(AR_LIKE_O - AR_c, Any) + +assert_type(AR_m - AR_LIKE_b, npt.NDArray[np.timedelta64]) +assert_type(AR_m - AR_LIKE_u, npt.NDArray[np.timedelta64]) +assert_type(AR_m - AR_LIKE_i, npt.NDArray[np.timedelta64]) +assert_type(AR_m - AR_LIKE_m, npt.NDArray[np.timedelta64]) +assert_type(AR_m - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_m, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_u - AR_m, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_i - AR_m, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_m - AR_m, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_M - AR_m, npt.NDArray[np.datetime64]) +assert_type(AR_LIKE_O - AR_m, Any) + +assert_type(AR_M - AR_LIKE_b, npt.NDArray[np.datetime64]) 
+assert_type(AR_M - AR_LIKE_u, npt.NDArray[np.datetime64]) +assert_type(AR_M - AR_LIKE_i, npt.NDArray[np.datetime64]) +assert_type(AR_M - AR_LIKE_m, npt.NDArray[np.datetime64]) +assert_type(AR_M - AR_LIKE_M, npt.NDArray[np.timedelta64]) +assert_type(AR_M - AR_LIKE_O, Any) + +assert_type(AR_LIKE_M - AR_M, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O - AR_M, Any) + +assert_type(AR_O - AR_LIKE_b, Any) +assert_type(AR_O - AR_LIKE_u, Any) +assert_type(AR_O - AR_LIKE_i, Any) +assert_type(AR_O - AR_LIKE_f, Any) +assert_type(AR_O - AR_LIKE_c, Any) +assert_type(AR_O - AR_LIKE_m, Any) +assert_type(AR_O - AR_LIKE_M, Any) +assert_type(AR_O - AR_LIKE_O, Any) + +assert_type(AR_LIKE_b - AR_O, Any) +assert_type(AR_LIKE_u - AR_O, Any) +assert_type(AR_LIKE_i - AR_O, Any) +assert_type(AR_LIKE_f - AR_O, Any) +assert_type(AR_LIKE_c - AR_O, Any) +assert_type(AR_LIKE_m - AR_O, Any) +assert_type(AR_LIKE_M - AR_O, Any) +assert_type(AR_LIKE_O - AR_O, Any) + +# Array "true" division + +assert_type(AR_f / b, npt.NDArray[np.float64]) +assert_type(AR_f / i, npt.NDArray[np.float64]) +assert_type(AR_f / f, npt.NDArray[np.float64]) + +assert_type(b / AR_f, npt.NDArray[np.float64]) +assert_type(i / AR_f, npt.NDArray[np.float64]) +assert_type(f / AR_f, npt.NDArray[np.float64]) + +assert_type(AR_b / AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_b / AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_b / AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_b / AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_b / AR_LIKE_O, Any) + +assert_type(AR_LIKE_b / AR_b, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u / AR_b, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i / AR_b, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f / AR_b, npt.NDArray[np.float64]) +assert_type(AR_LIKE_O / AR_b, Any) + +assert_type(AR_u / AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_u / AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_u / AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_u / AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_u / AR_LIKE_O, Any) + +assert_type(AR_LIKE_b / AR_u, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u / AR_u, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i / AR_u, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f / AR_u, npt.NDArray[np.float64]) +assert_type(AR_LIKE_m / AR_u, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O / AR_u, Any) + +assert_type(AR_i / AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_i / AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_i / AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_i / AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_i / AR_LIKE_O, Any) + +assert_type(AR_LIKE_b / AR_i, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u / AR_i, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i / AR_i, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f / AR_i, npt.NDArray[np.float64]) +assert_type(AR_LIKE_m / AR_i, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O / AR_i, Any) + +assert_type(AR_f / AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_f / AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_f / AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_f / AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_f / AR_LIKE_O, Any) + +assert_type(AR_LIKE_b / AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u / AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i / AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f / AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_m / AR_f, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O / AR_f, Any) + 
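+# Editorial note (not in the upstream file): timedelta64 true division keeps
+# timedelta64 when the divisor is numeric, but m / m cancels the unit:
+#     np.timedelta64(10, "s") / 4                       # still np.timedelta64
+#     np.timedelta64(10, "s") / np.timedelta64(4, "s")  # np.float64(2.5)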
+assert_type(AR_m / AR_LIKE_u, npt.NDArray[np.timedelta64]) +assert_type(AR_m / AR_LIKE_i, npt.NDArray[np.timedelta64]) +assert_type(AR_m / AR_LIKE_f, npt.NDArray[np.timedelta64]) +assert_type(AR_m / AR_LIKE_m, npt.NDArray[np.float64]) +assert_type(AR_m / AR_LIKE_O, Any) + +assert_type(AR_LIKE_m / AR_m, npt.NDArray[np.float64]) +assert_type(AR_LIKE_O / AR_m, Any) + +assert_type(AR_O / AR_LIKE_b, Any) +assert_type(AR_O / AR_LIKE_u, Any) +assert_type(AR_O / AR_LIKE_i, Any) +assert_type(AR_O / AR_LIKE_f, Any) +assert_type(AR_O / AR_LIKE_m, Any) +assert_type(AR_O / AR_LIKE_M, Any) +assert_type(AR_O / AR_LIKE_O, Any) + +assert_type(AR_LIKE_b / AR_O, Any) +assert_type(AR_LIKE_u / AR_O, Any) +assert_type(AR_LIKE_i / AR_O, Any) +assert_type(AR_LIKE_f / AR_O, Any) +assert_type(AR_LIKE_m / AR_O, Any) +assert_type(AR_LIKE_M / AR_O, Any) +assert_type(AR_LIKE_O / AR_O, Any) + +# Array floor division + +assert_type(AR_b // AR_LIKE_b, npt.NDArray[np.int8]) +assert_type(AR_b // AR_LIKE_u, npt.NDArray[np.uint32]) +assert_type(AR_b // AR_LIKE_i, npt.NDArray[np.signedinteger]) +assert_type(AR_b // AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_b // AR_LIKE_O, Any) + +assert_type(AR_LIKE_b // AR_b, npt.NDArray[np.int8]) +assert_type(AR_LIKE_u // AR_b, npt.NDArray[np.uint32]) +assert_type(AR_LIKE_i // AR_b, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f // AR_b, npt.NDArray[np.floating]) +assert_type(AR_LIKE_O // AR_b, Any) + +assert_type(AR_u // AR_LIKE_b, npt.NDArray[np.uint32]) +assert_type(AR_u // AR_LIKE_u, npt.NDArray[np.unsignedinteger]) +assert_type(AR_u // AR_LIKE_i, npt.NDArray[np.signedinteger]) +assert_type(AR_u // AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_u // AR_LIKE_O, Any) + +assert_type(AR_LIKE_b // AR_u, npt.NDArray[np.uint32]) +assert_type(AR_LIKE_u // AR_u, npt.NDArray[np.unsignedinteger]) +assert_type(AR_LIKE_i // AR_u, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f // AR_u, npt.NDArray[np.floating]) +assert_type(AR_LIKE_m // AR_u, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O // AR_u, Any) + +assert_type(AR_i // AR_LIKE_b, npt.NDArray[np.int64]) +assert_type(AR_i // AR_LIKE_u, npt.NDArray[np.signedinteger]) +assert_type(AR_i // AR_LIKE_i, npt.NDArray[np.signedinteger]) +assert_type(AR_i // AR_LIKE_f, npt.NDArray[np.floating]) +assert_type(AR_i // AR_LIKE_O, Any) + +assert_type(AR_LIKE_b // AR_i, npt.NDArray[np.int64]) +assert_type(AR_LIKE_u // AR_i, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_i // AR_i, npt.NDArray[np.signedinteger]) +assert_type(AR_LIKE_f // AR_i, npt.NDArray[np.floating]) +assert_type(AR_LIKE_m // AR_i, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O // AR_i, Any) + +assert_type(AR_f // AR_LIKE_b, npt.NDArray[np.float64]) +assert_type(AR_f // AR_LIKE_u, npt.NDArray[np.float64]) +assert_type(AR_f // AR_LIKE_i, npt.NDArray[np.float64]) +assert_type(AR_f // AR_LIKE_f, npt.NDArray[np.float64]) +assert_type(AR_f // AR_LIKE_O, Any) + +assert_type(AR_LIKE_b // AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_u // AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_i // AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_f // AR_f, npt.NDArray[np.float64]) +assert_type(AR_LIKE_m // AR_f, npt.NDArray[np.timedelta64]) +assert_type(AR_LIKE_O // AR_f, Any) + +assert_type(AR_m // AR_LIKE_u, npt.NDArray[np.timedelta64]) +assert_type(AR_m // AR_LIKE_i, npt.NDArray[np.timedelta64]) +assert_type(AR_m // AR_LIKE_f, npt.NDArray[np.timedelta64]) +assert_type(AR_m // AR_LIKE_m, npt.NDArray[np.int64]) +assert_type(AR_m // AR_LIKE_O, Any) + 
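+# Editorial note (not in the upstream file): floor division of two timedelta64
+# values strips the unit and yields int64, e.g.
+#     np.timedelta64(10, "s") // np.timedelta64(3, "s")  # np.int64(3)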
+assert_type(AR_LIKE_m // AR_m, npt.NDArray[np.int64]) +assert_type(AR_LIKE_O // AR_m, Any) + +assert_type(AR_O // AR_LIKE_b, Any) +assert_type(AR_O // AR_LIKE_u, Any) +assert_type(AR_O // AR_LIKE_i, Any) +assert_type(AR_O // AR_LIKE_f, Any) +assert_type(AR_O // AR_LIKE_m, Any) +assert_type(AR_O // AR_LIKE_M, Any) +assert_type(AR_O // AR_LIKE_O, Any) + +assert_type(AR_LIKE_b // AR_O, Any) +assert_type(AR_LIKE_u // AR_O, Any) +assert_type(AR_LIKE_i // AR_O, Any) +assert_type(AR_LIKE_f // AR_O, Any) +assert_type(AR_LIKE_m // AR_O, Any) +assert_type(AR_LIKE_M // AR_O, Any) +assert_type(AR_LIKE_O // AR_O, Any) + +# unary ops + +assert_type(-f16, np.floating[_128Bit]) +assert_type(-c16, np.complex128) +assert_type(-c8, np.complex64) +assert_type(-f8, np.float64) +assert_type(-f4, np.float32) +assert_type(-i8, np.int64) +assert_type(-i4, np.int32) +assert_type(-u8, np.uint64) +assert_type(-u4, np.uint32) +assert_type(-m8, np.timedelta64) +assert_type(-m8_none, np.timedelta64[None]) +assert_type(-m8_int, np.timedelta64[int]) +assert_type(-m8_delta, np.timedelta64[dt.timedelta]) +assert_type(-AR_f, npt.NDArray[np.float64]) + +assert_type(+f16, np.floating[_128Bit]) +assert_type(+c16, np.complex128) +assert_type(+c8, np.complex64) +assert_type(+f8, np.float64) +assert_type(+f4, np.float32) +assert_type(+i8, np.int64) +assert_type(+i4, np.int32) +assert_type(+u8, np.uint64) +assert_type(+u4, np.uint32) +assert_type(+m8_none, np.timedelta64[None]) +assert_type(+m8_int, np.timedelta64[int]) +assert_type(+m8_delta, np.timedelta64[dt.timedelta]) +assert_type(+AR_f, npt.NDArray[np.float64]) + +assert_type(abs(f16), np.floating[_128Bit]) +assert_type(abs(c16), np.float64) +assert_type(abs(c8), np.float32) +assert_type(abs(f8), np.float64) +assert_type(abs(f4), np.float32) +assert_type(abs(i8), np.int64) +assert_type(abs(i4), np.int32) +assert_type(abs(u8), np.uint64) +assert_type(abs(u4), np.uint32) +assert_type(abs(m8), np.timedelta64) +assert_type(abs(m8_none), np.timedelta64[None]) +assert_type(abs(m8_int), np.timedelta64[int]) +assert_type(abs(m8_delta), np.timedelta64[dt.timedelta]) +assert_type(abs(b_), np.bool) +assert_type(abs(AR_O), npt.NDArray[np.object_]) + +# Time structures + +assert_type(M8 + m8, np.datetime64) +assert_type(M8 + i, np.datetime64) +assert_type(M8 + i8, np.datetime64) +assert_type(M8 - M8, np.timedelta64) +assert_type(M8 - i, np.datetime64) +assert_type(M8 - i8, np.datetime64) + +assert_type(M8_none + m8, np.datetime64[None]) +assert_type(M8_none + i, np.datetime64[None]) +assert_type(M8_none + i8, np.datetime64[None]) +assert_type(M8_none - M8, np.timedelta64[None]) +assert_type(M8_none - m8, np.datetime64[None]) +assert_type(M8_none - i, np.datetime64[None]) +assert_type(M8_none - i8, np.datetime64[None]) + +assert_type(m8 + m8, np.timedelta64) +assert_type(m8 + i, np.timedelta64) +assert_type(m8 + i8, np.timedelta64) +assert_type(m8 - m8, np.timedelta64) +assert_type(m8 - i, np.timedelta64) +assert_type(m8 - i8, np.timedelta64) +assert_type(m8 * f, np.timedelta64) +assert_type(m8 * f4, np.timedelta64) +assert_type(m8 * np.True_, np.timedelta64) +assert_type(m8 / f, np.timedelta64) +assert_type(m8 / f4, np.timedelta64) +assert_type(m8 / m8, np.float64) +assert_type(m8 // m8, np.int64) +assert_type(m8 % m8, np.timedelta64) +assert_type(divmod(m8, m8), tuple[np.int64, np.timedelta64]) + +assert_type(m8_none + m8, np.timedelta64[None]) +assert_type(m8_none + i, np.timedelta64[None]) +assert_type(m8_none + i8, np.timedelta64[None]) +assert_type(m8_none - i, np.timedelta64[None]) 
+assert_type(m8_none - i8, np.timedelta64[None]) + +assert_type(m8_int + i, np.timedelta64[int]) +assert_type(m8_int + m8_delta, np.timedelta64[int]) +assert_type(m8_int + m8, np.timedelta64[int | None]) +assert_type(m8_int - i, np.timedelta64[int]) +assert_type(m8_int - m8_delta, np.timedelta64[int]) +assert_type(m8_int - m8, np.timedelta64[int | None]) + +assert_type(m8_delta + date, dt.date) +assert_type(m8_delta + time, dt.datetime) +assert_type(m8_delta + delta, dt.timedelta) +assert_type(m8_delta - delta, dt.timedelta) +assert_type(m8_delta / delta, float) +assert_type(m8_delta // delta, int) +assert_type(m8_delta % delta, dt.timedelta) +assert_type(divmod(m8_delta, delta), tuple[int, dt.timedelta]) + +# boolean + +assert_type(b_ / b, np.float64) +assert_type(b_ / b_, np.float64) +assert_type(b_ / i, np.float64) +assert_type(b_ / i8, np.float64) +assert_type(b_ / i4, np.float64) +assert_type(b_ / u8, np.float64) +assert_type(b_ / u4, np.float64) +assert_type(b_ / f, np.float64) +assert_type(b_ / f16, np.floating[_128Bit]) +assert_type(b_ / f8, np.float64) +assert_type(b_ / f4, np.float32) +assert_type(b_ / c, np.complex128) +assert_type(b_ / c16, np.complex128) +assert_type(b_ / c8, np.complex64) + +assert_type(b / b_, np.float64) +assert_type(b_ / b_, np.float64) +assert_type(i / b_, np.float64) +assert_type(i8 / b_, np.float64) +assert_type(i4 / b_, np.float64) +assert_type(u8 / b_, np.float64) +assert_type(u4 / b_, np.float64) +assert_type(f / b_, np.float64) +assert_type(f16 / b_, np.floating[_128Bit]) +assert_type(f8 / b_, np.float64) +assert_type(f4 / b_, np.float32) +assert_type(c / b_, np.complex128) +assert_type(c16 / b_, np.complex128) +assert_type(c8 / b_, np.complex64) + +# Complex + +assert_type(c16 + f16, np.complexfloating) +assert_type(c16 + c16, np.complex128) +assert_type(c16 + f8, np.complex128) +assert_type(c16 + i8, np.complex128) +assert_type(c16 + c8, np.complex128) +assert_type(c16 + f4, np.complex128) +assert_type(c16 + i4, np.complex128) +assert_type(c16 + b_, np.complex128) +assert_type(c16 + b, np.complex128) +assert_type(c16 + c, np.complex128) +assert_type(c16 + f, np.complex128) +assert_type(c16 + AR_f, npt.NDArray[np.complex128]) + +assert_type(f16 + c16, np.complexfloating) +assert_type(c16 + c16, np.complex128) +assert_type(f8 + c16, np.complex128) +assert_type(i8 + c16, np.complex128) +assert_type(c8 + c16, np.complex128 | np.complex64) +assert_type(f4 + c16, np.complexfloating) +assert_type(i4 + c16, np.complex128) +assert_type(b_ + c16, np.complex128) +assert_type(b + c16, np.complex128) +assert_type(c + c16, np.complex128) +assert_type(f + c16, np.complex128) +assert_type(AR_f + c16, npt.NDArray[np.complex128]) + +assert_type(c8 + f16, np.complex64 | np.complexfloating[_128Bit, _128Bit]) +assert_type(c8 + c16, np.complex64 | np.complex128) +assert_type(c8 + f8, np.complex64 | np.complex128) +assert_type(c8 + i8, np.complex64 | np.complexfloating[_64Bit, _64Bit]) +assert_type(c8 + c8, np.complex64) +assert_type(c8 + f4, np.complex64) +assert_type(c8 + i4, np.complex64) +assert_type(c8 + b_, np.complex64) +assert_type(c8 + b, np.complex64) +assert_type(c8 + c, np.complex64 | np.complex128) +assert_type(c8 + f, np.complex64 | np.complex128) +assert_type(c8 + AR_f, npt.NDArray[np.complexfloating]) + +assert_type(f16 + c8, np.complexfloating[_128Bit, _128Bit] | np.complex64) +assert_type(c16 + c8, np.complex128) +assert_type(f8 + c8, np.complexfloating[_64Bit, _64Bit]) +assert_type(i8 + c8, np.complexfloating[_64Bit, _64Bit] | np.complex64) 
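+# Editorial note (not in the upstream file): the unions in the complex64 block
+# above reflect that the stubs cannot statically pick the promoted precision
+# when complex64 meets a 64-bit operand, so both plausible result types are kept.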
+assert_type(c8 + c8, np.complex64) +assert_type(f4 + c8, np.complex64) +assert_type(i4 + c8, np.complex64) +assert_type(b_ + c8, np.complex64) +assert_type(b + c8, np.complex64) +assert_type(c + c8, np.complex64 | np.complex128) +assert_type(f + c8, np.complex64 | np.complex128) +assert_type(AR_f + c8, npt.NDArray[np.complexfloating]) + +# Float + +assert_type(f8 + f16, np.floating) +assert_type(f8 + f8, np.float64) +assert_type(f8 + i8, np.float64) +assert_type(f8 + f4, np.float64) +assert_type(f8 + i4, np.float64) +assert_type(f8 + b_, np.float64) +assert_type(f8 + b, np.float64) +assert_type(f8 + c, np.float64 | np.complex128) +assert_type(f8 + f, np.float64) +assert_type(f8 + AR_f, npt.NDArray[np.float64]) + +assert_type(f16 + f8, np.floating) +assert_type(f8 + f8, np.float64) +assert_type(i8 + f8, np.float64) +assert_type(f4 + f8, np.floating) +assert_type(i4 + f8, np.float64) +assert_type(b_ + f8, np.float64) +assert_type(b + f8, np.float64) +assert_type(c + f8, np.complex128 | np.float64) +assert_type(f + f8, np.float64) +assert_type(AR_f + f8, npt.NDArray[np.float64]) + +assert_type(f4 + f16, np.floating) +assert_type(f4 + f8, np.floating) +assert_type(f4 + i8, np.floating) +assert_type(f4 + f4, np.float32) +assert_type(f4 + i4, np.floating) +assert_type(f4 + b_, np.float32) +assert_type(f4 + b, np.float32) +assert_type(f4 + c, np.complexfloating) +assert_type(f4 + f, np.float32) +assert_type(f4 + AR_f, npt.NDArray[np.float64]) + +assert_type(f16 + f4, np.floating) +assert_type(f8 + f4, np.float64) +assert_type(i8 + f4, np.floating) +assert_type(f4 + f4, np.float32) +assert_type(i4 + f4, np.floating) +assert_type(b_ + f4, np.float32) +assert_type(b + f4, np.float32) +assert_type(c + f4, np.complexfloating) +assert_type(f + f4, np.float32) +assert_type(AR_f + f4, npt.NDArray[np.float64]) + +# Int + +assert_type(i8 + i8, np.int64) +assert_type(i8 + u8, Any) +assert_type(i8 + i4, np.signedinteger) +assert_type(i8 + u4, Any) +assert_type(i8 + b_, np.int64) +assert_type(i8 + b, np.int64) +assert_type(i8 + c, np.complex128) +assert_type(i8 + f, np.float64) +assert_type(i8 + AR_f, npt.NDArray[np.float64]) + +assert_type(u8 + u8, np.uint64) +assert_type(u8 + i4, Any) +assert_type(u8 + u4, np.unsignedinteger) +assert_type(u8 + b_, np.uint64) +assert_type(u8 + b, np.uint64) +assert_type(u8 + c, np.complex128) +assert_type(u8 + f, np.float64) +assert_type(u8 + AR_f, npt.NDArray[np.float64]) + +assert_type(i8 + i8, np.int64) +assert_type(u8 + i8, Any) +assert_type(i4 + i8, np.signedinteger) +assert_type(u4 + i8, Any) +assert_type(b_ + i8, np.int64) +assert_type(b + i8, np.int64) +assert_type(c + i8, np.complex128) +assert_type(f + i8, np.float64) +assert_type(AR_f + i8, npt.NDArray[np.float64]) + +assert_type(u8 + u8, np.uint64) +assert_type(i4 + u8, Any) +assert_type(u4 + u8, np.unsignedinteger) +assert_type(b_ + u8, np.uint64) +assert_type(b + u8, np.uint64) +assert_type(c + u8, np.complex128) +assert_type(f + u8, np.float64) +assert_type(AR_f + u8, npt.NDArray[np.float64]) + +assert_type(i4 + i8, np.signedinteger) +assert_type(i4 + i4, np.int32) +assert_type(i4 + b_, np.int32) +assert_type(i4 + b, np.int32) +assert_type(i4 + AR_f, npt.NDArray[np.float64]) + +assert_type(u4 + i8, Any) +assert_type(u4 + i4, Any) +assert_type(u4 + u8, np.unsignedinteger) +assert_type(u4 + u4, np.uint32) +assert_type(u4 + b_, np.uint32) +assert_type(u4 + b, np.uint32) +assert_type(u4 + AR_f, npt.NDArray[np.float64]) + +assert_type(i8 + i4, np.signedinteger) +assert_type(i4 + i4, np.int32) +assert_type(b_ + i4, 
np.int32) +assert_type(b + i4, np.int32) +assert_type(AR_f + i4, npt.NDArray[np.float64]) + +assert_type(i8 + u4, Any) +assert_type(i4 + u4, Any) +assert_type(u8 + u4, np.unsignedinteger) +assert_type(u4 + u4, np.uint32) +assert_type(b_ + u4, np.uint32) +assert_type(b + u4, np.uint32) +assert_type(AR_f + u4, npt.NDArray[np.float64]) + +# Any + +assert_type(AR_Any + 2, npt.NDArray[Any]) + +# regression tests for https://github.com/numpy/numpy/issues/28805 + +assert_type(AR_floating + f, npt.NDArray[np.floating]) +assert_type(AR_floating - f, npt.NDArray[np.floating]) +assert_type(AR_floating * f, npt.NDArray[np.floating]) +assert_type(AR_floating ** f, npt.NDArray[np.floating]) +assert_type(AR_floating / f, npt.NDArray[np.floating]) +assert_type(AR_floating // f, npt.NDArray[np.floating]) +assert_type(AR_floating % f, npt.NDArray[np.floating]) +assert_type(divmod(AR_floating, f), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) + +assert_type(f + AR_floating, npt.NDArray[np.floating]) +assert_type(f - AR_floating, npt.NDArray[np.floating]) +assert_type(f * AR_floating, npt.NDArray[np.floating]) +assert_type(f ** AR_floating, npt.NDArray[np.floating]) +assert_type(f / AR_floating, npt.NDArray[np.floating]) +assert_type(f // AR_floating, npt.NDArray[np.floating]) +assert_type(f % AR_floating, npt.NDArray[np.floating]) +assert_type(divmod(f, AR_floating), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) + +# character-like + +assert_type(AR_S + b"", npt.NDArray[np.bytes_]) +assert_type(AR_S + [b""], npt.NDArray[np.bytes_]) +assert_type([b""] + AR_S, npt.NDArray[np.bytes_]) +assert_type(AR_S + AR_S, npt.NDArray[np.bytes_]) + +assert_type(AR_U + "", npt.NDArray[np.str_]) +assert_type(AR_U + [""], npt.NDArray[np.str_]) +assert_type("" + AR_U, npt.NDArray[np.str_]) +assert_type([""] + AR_U, npt.NDArray[np.str_]) +assert_type(AR_U + AR_U, npt.NDArray[np.str_]) + +assert_type(AR_T + "", np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_T + [""], np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type("" + AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type([""] + AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_T + AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_T + AR_U, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_U + AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) + +assert_type(AR_S * i, np.ndarray[tuple[Any, ...], np.dtype[np.bytes_]]) +assert_type(AR_S * AR_LIKE_i, np.ndarray[tuple[Any, ...], np.dtype[np.bytes_]]) +assert_type(AR_S * AR_i, np.ndarray[tuple[Any, ...], np.dtype[np.bytes_]]) +assert_type(i * AR_S, np.ndarray[tuple[Any, ...], np.dtype[np.bytes_]]) +# mypy incorrectly infers `AR_LIKE_i * AR_S` as `list[int]` +assert_type(AR_i * AR_S, np.ndarray[tuple[Any, ...], np.dtype[np.bytes_]]) + +assert_type(AR_U * i, np.ndarray[tuple[Any, ...], np.dtype[np.str_]]) +assert_type(AR_U * AR_LIKE_i, np.ndarray[tuple[Any, ...], np.dtype[np.str_]]) +assert_type(AR_U * AR_i, np.ndarray[tuple[Any, ...], np.dtype[np.str_]]) +assert_type(i * AR_U, np.ndarray[tuple[Any, ...], np.dtype[np.str_]]) +# mypy incorrectly infers `AR_LIKE_i * AR_U` as `list[int]` +assert_type(AR_i * AR_U, np.ndarray[tuple[Any, ...], np.dtype[np.str_]]) + +assert_type(AR_T * i, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_T * AR_LIKE_i, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +assert_type(AR_T * AR_i, np.ndarray[tuple[Any, ...], 
np.dtypes.StringDType]) +assert_type(i * AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) +# mypy incorrectly infers `AR_LIKE_i * AR_T` as `list[int]` +assert_type(AR_i * AR_T, np.ndarray[tuple[Any, ...], np.dtypes.StringDType]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_api_info.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_api_info.pyi new file mode 100644 index 0000000000000000000000000000000000000000..765f9eff516856657f4aaa8aebf22a4c2851a599 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_api_info.pyi @@ -0,0 +1,70 @@ +from typing import Literal, Never, assert_type + +import numpy as np + +info = np.__array_namespace_info__() + +assert_type(info.__module__, Literal["numpy"]) + +assert_type(info.default_device(), Literal["cpu"]) +assert_type(info.devices()[0], Literal["cpu"]) +assert_type(info.devices()[-1], Literal["cpu"]) + +assert_type(info.capabilities()["boolean indexing"], Literal[True]) +assert_type(info.capabilities()["data-dependent shapes"], Literal[True]) + +assert_type(info.default_dtypes()["real floating"], np.dtype[np.float64]) +assert_type(info.default_dtypes()["complex floating"], np.dtype[np.complex128]) +assert_type(info.default_dtypes()["integral"], np.dtype[np.int_]) +assert_type(info.default_dtypes()["indexing"], np.dtype[np.intp]) + +assert_type(info.dtypes()["bool"], np.dtype[np.bool]) +assert_type(info.dtypes()["int8"], np.dtype[np.int8]) +assert_type(info.dtypes()["uint8"], np.dtype[np.uint8]) +assert_type(info.dtypes()["float32"], np.dtype[np.float32]) +assert_type(info.dtypes()["complex64"], np.dtype[np.complex64]) + +assert_type(info.dtypes(kind="bool")["bool"], np.dtype[np.bool]) +assert_type(info.dtypes(kind="signed integer")["int64"], np.dtype[np.int64]) +assert_type(info.dtypes(kind="unsigned integer")["uint64"], np.dtype[np.uint64]) +assert_type(info.dtypes(kind="integral")["int32"], np.dtype[np.int32]) +assert_type(info.dtypes(kind="integral")["uint32"], np.dtype[np.uint32]) +assert_type(info.dtypes(kind="real floating")["float64"], np.dtype[np.float64]) +assert_type(info.dtypes(kind="complex floating")["complex128"], np.dtype[np.complex128]) +assert_type(info.dtypes(kind="numeric")["int16"], np.dtype[np.int16]) +assert_type(info.dtypes(kind="numeric")["uint16"], np.dtype[np.uint16]) +assert_type(info.dtypes(kind="numeric")["float64"], np.dtype[np.float64]) +assert_type(info.dtypes(kind="numeric")["complex128"], np.dtype[np.complex128]) + +assert_type(info.dtypes(kind=()), dict[Never, Never]) + +assert_type(info.dtypes(kind=("bool",))["bool"], np.dtype[np.bool]) +assert_type(info.dtypes(kind=("signed integer",))["int64"], np.dtype[np.int64]) +assert_type(info.dtypes(kind=("integral",))["uint32"], np.dtype[np.uint32]) +assert_type(info.dtypes(kind=("complex floating",))["complex128"], np.dtype[np.complex128]) +assert_type(info.dtypes(kind=("numeric",))["float64"], np.dtype[np.float64]) + +assert_type( + info.dtypes(kind=("signed integer", "unsigned integer"))["int8"], + np.dtype[np.int8], +) +assert_type( + info.dtypes(kind=("signed integer", "unsigned integer"))["uint8"], + np.dtype[np.uint8], +) +assert_type( + info.dtypes(kind=("integral", "real floating", "complex floating"))["int16"], + np.dtype[np.int16], +) +assert_type( + info.dtypes(kind=("integral", "real floating", "complex floating"))["uint16"], + np.dtype[np.uint16], +) +assert_type( + info.dtypes(kind=("integral", "real floating", "complex floating"))["float32"], + 
np.dtype[np.float32], +) +assert_type( + info.dtypes(kind=("integral", "real floating", "complex floating"))["complex64"], + np.dtype[np.complex64], +) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_constructors.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_constructors.pyi new file mode 100644 index 0000000000000000000000000000000000000000..49425bb890bcd559dd273ba11261b35832ef3029 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/array_constructors.pyi @@ -0,0 +1,249 @@ +import sys +from collections import deque +from pathlib import Path +from typing import Any, TypeVar, assert_type + +import numpy as np +import numpy.typing as npt + +_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True) + +class SubClass(npt.NDArray[_ScalarT_co]): ... + +i8: np.int64 + +A: npt.NDArray[np.float64] +B: SubClass[np.float64] +C: list[int] +D: SubClass[np.float64 | np.int64] + +mixed_shape: tuple[int, np.int64] + +def func(i: int, j: int, **kwargs: Any) -> SubClass[np.float64]: ... + +assert_type(np.empty_like(A), npt.NDArray[np.float64]) +assert_type(np.empty_like(B), SubClass[np.float64]) +assert_type(np.empty_like([1, 1.0]), npt.NDArray[Any]) +assert_type(np.empty_like(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.empty_like(A, dtype='c16'), npt.NDArray[Any]) + +assert_type(np.array(A), npt.NDArray[np.float64]) +assert_type(np.array(B), npt.NDArray[np.float64]) +assert_type(np.array([1, 1.0]), npt.NDArray[Any]) +assert_type(np.array(deque([1, 2, 3])), npt.NDArray[Any]) +assert_type(np.array(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.array(A, dtype='c16'), npt.NDArray[Any]) +assert_type(np.array(A, like=A), npt.NDArray[np.float64]) +assert_type(np.array(A, subok=True), npt.NDArray[np.float64]) +assert_type(np.array(B, subok=True), SubClass[np.float64]) +assert_type(np.array(B, subok=True, ndmin=0), SubClass[np.float64]) +assert_type(np.array(B, subok=True, ndmin=1), SubClass[np.float64]) +assert_type(np.array(D), npt.NDArray[np.float64 | np.int64]) +# https://github.com/numpy/numpy/issues/29245 +assert_type(np.array([], dtype=np.bool), npt.NDArray[np.bool]) + +assert_type(np.zeros([1, 5, 6]), npt.NDArray[np.float64]) +assert_type(np.zeros([1, 5, 6], dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.zeros([1, 5, 6], dtype='c16'), npt.NDArray[Any]) +assert_type(np.zeros(mixed_shape), npt.NDArray[np.float64]) + +assert_type(np.empty([1, 5, 6]), npt.NDArray[np.float64]) +assert_type(np.empty([1, 5, 6], dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.empty([1, 5, 6], dtype='c16'), npt.NDArray[Any]) +assert_type(np.empty(mixed_shape), npt.NDArray[np.float64]) + +assert_type(np.concatenate(A), npt.NDArray[np.float64]) +assert_type(np.concatenate([A, A]), npt.NDArray[Any]) # pyright correctly infers this as NDArray[float64] +assert_type(np.concatenate([[1], A]), npt.NDArray[Any]) +assert_type(np.concatenate([[1], [1]]), npt.NDArray[Any]) +assert_type(np.concatenate((A, A)), npt.NDArray[np.float64]) +assert_type(np.concatenate(([1], [1])), npt.NDArray[Any]) +assert_type(np.concatenate([1, 1.0]), npt.NDArray[Any]) +assert_type(np.concatenate(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.concatenate(A, dtype='c16'), npt.NDArray[Any]) +assert_type(np.concatenate([1, 1.0], out=A), npt.NDArray[np.float64]) + +assert_type(np.asarray(A), npt.NDArray[np.float64]) +assert_type(np.asarray(B), npt.NDArray[np.float64]) +assert_type(np.asarray([1, 1.0]), 
npt.NDArray[Any]) +assert_type(np.asarray(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.asarray(A, dtype='c16'), npt.NDArray[Any]) + +assert_type(np.asanyarray(A), npt.NDArray[np.float64]) +assert_type(np.asanyarray(B), SubClass[np.float64]) +assert_type(np.asanyarray([1, 1.0]), npt.NDArray[Any]) +assert_type(np.asanyarray(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.asanyarray(A, dtype='c16'), npt.NDArray[Any]) + +assert_type(np.ascontiguousarray(A), npt.NDArray[np.float64]) +assert_type(np.ascontiguousarray(B), npt.NDArray[np.float64]) +assert_type(np.ascontiguousarray([1, 1.0]), npt.NDArray[Any]) +assert_type(np.ascontiguousarray(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.ascontiguousarray(A, dtype='c16'), npt.NDArray[Any]) + +assert_type(np.asfortranarray(A), npt.NDArray[np.float64]) +assert_type(np.asfortranarray(B), npt.NDArray[np.float64]) +assert_type(np.asfortranarray([1, 1.0]), npt.NDArray[Any]) +assert_type(np.asfortranarray(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.asfortranarray(A, dtype='c16'), npt.NDArray[Any]) + +assert_type(np.fromstring("1 1 1", sep=" "), npt.NDArray[np.float64]) +assert_type(np.fromstring(b"1 1 1", sep=" "), npt.NDArray[np.float64]) +assert_type(np.fromstring("1 1 1", dtype=np.int64, sep=" "), npt.NDArray[np.int64]) +assert_type(np.fromstring(b"1 1 1", dtype=np.int64, sep=" "), npt.NDArray[np.int64]) +assert_type(np.fromstring("1 1 1", dtype="c16", sep=" "), npt.NDArray[Any]) +assert_type(np.fromstring(b"1 1 1", dtype="c16", sep=" "), npt.NDArray[Any]) + +assert_type(np.fromfile("test.txt", sep=" "), npt.NDArray[np.float64]) +assert_type(np.fromfile("test.txt", dtype=np.int64, sep=" "), npt.NDArray[np.int64]) +assert_type(np.fromfile("test.txt", dtype="c16", sep=" "), npt.NDArray[Any]) +with open("test.txt") as f: + assert_type(np.fromfile(f, sep=" "), npt.NDArray[np.float64]) + assert_type(np.fromfile(b"test.txt", sep=" "), npt.NDArray[np.float64]) + assert_type(np.fromfile(Path("test.txt"), sep=" "), npt.NDArray[np.float64]) + +assert_type(np.fromiter("12345", np.float64), npt.NDArray[np.float64]) +assert_type(np.fromiter("12345", float), npt.NDArray[Any]) + +assert_type(np.frombuffer(A), npt.NDArray[np.float64]) +assert_type(np.frombuffer(A, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.frombuffer(A, dtype="c16"), npt.NDArray[Any]) + +assert_type(np.arange(False, True), np.ndarray[tuple[int], np.dtype[np.signedinteger]]) +assert_type(np.arange(10), np.ndarray[tuple[int], np.dtype[np.signedinteger]]) +assert_type(np.arange(0, 10, step=2), np.ndarray[tuple[int], np.dtype[np.signedinteger]]) +assert_type(np.arange(10.0), np.ndarray[tuple[int], np.dtype[np.floating]]) +assert_type(np.arange(start=0, stop=10.0), np.ndarray[tuple[int], np.dtype[np.floating]]) +assert_type(np.arange(np.timedelta64(0)), np.ndarray[tuple[int], np.dtype[np.timedelta64]]) +assert_type(np.arange(0, np.timedelta64(10)), np.ndarray[tuple[int], np.dtype[np.timedelta64]]) +assert_type(np.arange(np.datetime64("0"), np.datetime64("10")), np.ndarray[tuple[int], np.dtype[np.datetime64]]) +assert_type(np.arange(10, dtype=np.float64), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(np.arange(0, 10, step=2, dtype=np.int16), np.ndarray[tuple[int], np.dtype[np.int16]]) +assert_type(np.arange(10, dtype=int), np.ndarray[tuple[int], np.dtype]) +assert_type(np.arange(0, 10, dtype="f8"), np.ndarray[tuple[int], np.dtype]) + +assert_type(np.require(A), npt.NDArray[np.float64]) +assert_type(np.require(B), 
SubClass[np.float64]) +assert_type(np.require(B, requirements=None), SubClass[np.float64]) +assert_type(np.require(B, dtype=int), npt.NDArray[Any]) +assert_type(np.require(B, requirements="E"), npt.NDArray[Any]) +assert_type(np.require(B, requirements=["ENSUREARRAY"]), npt.NDArray[Any]) +assert_type(np.require(B, requirements={"F", "E"}), npt.NDArray[Any]) +assert_type(np.require(B, requirements=["C", "OWNDATA"]), SubClass[np.float64]) +assert_type(np.require(B, requirements="W"), SubClass[np.float64]) +assert_type(np.require(B, requirements="A"), SubClass[np.float64]) +assert_type(np.require(C), npt.NDArray[Any]) + +assert_type(np.linspace(0, 10), npt.NDArray[np.float64]) +assert_type(np.linspace(0, 10j), npt.NDArray[np.complexfloating]) +assert_type(np.linspace(0, 10, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.linspace(0, 10, dtype=int), npt.NDArray[Any]) +assert_type(np.linspace(0, 10, retstep=True), tuple[npt.NDArray[np.float64], np.float64]) +assert_type(np.linspace(0j, 10, retstep=True), tuple[npt.NDArray[np.complexfloating], np.complexfloating]) +assert_type(np.linspace(0, 10, retstep=True, dtype=np.int64), tuple[npt.NDArray[np.int64], np.int64]) +assert_type(np.linspace(0j, 10, retstep=True, dtype=int), tuple[npt.NDArray[Any], Any]) + +assert_type(np.logspace(0, 10), npt.NDArray[np.float64]) +assert_type(np.logspace(0, 10j), npt.NDArray[np.complexfloating]) +assert_type(np.logspace(0, 10, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.logspace(0, 10, dtype=int), npt.NDArray[Any]) + +assert_type(np.geomspace(0, 10), npt.NDArray[np.float64]) +assert_type(np.geomspace(0, 10j), npt.NDArray[np.complexfloating]) +assert_type(np.geomspace(0, 10, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.geomspace(0, 10, dtype=int), npt.NDArray[Any]) + +assert_type(np.zeros_like(A), npt.NDArray[np.float64]) +assert_type(np.zeros_like(C), npt.NDArray[Any]) +assert_type(np.zeros_like(A, dtype=float), npt.NDArray[Any]) +assert_type(np.zeros_like(B), SubClass[np.float64]) +assert_type(np.zeros_like(B, dtype=np.int64), npt.NDArray[np.int64]) + +assert_type(np.ones_like(A), npt.NDArray[np.float64]) +assert_type(np.ones_like(C), npt.NDArray[Any]) +assert_type(np.ones_like(A, dtype=float), npt.NDArray[Any]) +assert_type(np.ones_like(B), SubClass[np.float64]) +assert_type(np.ones_like(B, dtype=np.int64), npt.NDArray[np.int64]) + +assert_type(np.full_like(A, i8), npt.NDArray[np.float64]) +assert_type(np.full_like(C, i8), npt.NDArray[Any]) +assert_type(np.full_like(A, i8, dtype=int), npt.NDArray[Any]) +assert_type(np.full_like(B, i8), SubClass[np.float64]) +assert_type(np.full_like(B, i8, dtype=np.int64), npt.NDArray[np.int64]) + +_size: int +_shape_0d: tuple[()] +_shape_1d: tuple[int] +_shape_2d: tuple[int, int] +_shape_nd: tuple[int, ...] 
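+# Editorial note (not in the upstream file): list-valued shapes like the one
+# below only match the general shape-like overloads, so the precise shape type
+# is lost and only the dtype is tracked.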
+_shape_like: list[int] + +assert_type(np.ones(_shape_0d), np.ndarray[tuple[()], np.dtype[np.float64]]) +assert_type(np.ones(_size), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(np.ones(_shape_2d), np.ndarray[tuple[int, int], np.dtype[np.float64]]) +assert_type(np.ones(_shape_nd), np.ndarray[tuple[int, ...], np.dtype[np.float64]]) +assert_type(np.ones(_shape_1d, dtype=np.int64), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(np.ones(_shape_like), npt.NDArray[np.float64]) +assert_type(np.ones(_shape_like, dtype=np.dtypes.Int64DType()), np.ndarray[Any, np.dtypes.Int64DType]) +assert_type(np.ones(_shape_like, dtype=int), npt.NDArray[Any]) +assert_type(np.ones(mixed_shape), npt.NDArray[np.float64]) + +assert_type(np.full(_size, i8), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(np.full(_shape_2d, i8), np.ndarray[tuple[int, int], np.dtype[np.int64]]) +assert_type(np.full(_shape_like, i8), npt.NDArray[np.int64]) +assert_type(np.full(_shape_like, 42), npt.NDArray[Any]) +assert_type(np.full(_size, i8, dtype=np.float64), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(np.full(_size, i8, dtype=float), np.ndarray[tuple[int], np.dtype]) +assert_type(np.full(_shape_like, 42, dtype=float), npt.NDArray[Any]) +assert_type(np.full(_shape_0d, i8, dtype=object), np.ndarray[tuple[()], np.dtype]) + +assert_type(np.indices([1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.indices([1, 2, 3], sparse=True), tuple[npt.NDArray[np.int_], ...]) + +assert_type(np.fromfunction(func, (3, 5)), SubClass[np.float64]) + +assert_type(np.identity(10), npt.NDArray[np.float64]) +assert_type(np.identity(10, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.identity(10, dtype=int), npt.NDArray[Any]) + +assert_type(np.atleast_1d(A), npt.NDArray[np.float64]) +assert_type(np.atleast_1d(C), npt.NDArray[Any]) +assert_type(np.atleast_1d(A, A), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]) +assert_type(np.atleast_1d(A, C), tuple[npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.atleast_1d(C, C), tuple[npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.atleast_1d(A, A, A), tuple[npt.NDArray[np.float64], ...]) +assert_type(np.atleast_1d(C, C, C), tuple[npt.NDArray[Any], ...]) + +assert_type(np.atleast_2d(A), npt.NDArray[np.float64]) +assert_type(np.atleast_2d(A, A), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]) +assert_type(np.atleast_2d(A, A, A), tuple[npt.NDArray[np.float64], ...]) + +assert_type(np.atleast_3d(A), npt.NDArray[np.float64]) +assert_type(np.atleast_3d(A, A), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]) +assert_type(np.atleast_3d(A, A, A), tuple[npt.NDArray[np.float64], ...]) + +assert_type(np.vstack([A, A]), npt.NDArray[np.float64]) +assert_type(np.vstack([A, A], dtype=np.float32), npt.NDArray[np.float32]) +assert_type(np.vstack([A, C]), npt.NDArray[Any]) +assert_type(np.vstack([C, C]), npt.NDArray[Any]) + +assert_type(np.hstack([A, A]), npt.NDArray[np.float64]) +assert_type(np.hstack([A, A], dtype=np.float32), npt.NDArray[np.float32]) + +assert_type(np.stack([A, A]), npt.NDArray[np.float64]) +assert_type(np.stack([A, A], dtype=np.float32), npt.NDArray[np.float32]) +assert_type(np.stack([A, C]), npt.NDArray[Any]) +assert_type(np.stack([C, C]), npt.NDArray[Any]) +assert_type(np.stack([A, A], axis=0), npt.NDArray[np.float64]) +assert_type(np.stack([A, A], out=B), SubClass[np.float64]) + +assert_type(np.block([[A, A], [A, A]]), npt.NDArray[Any]) # pyright correctly infers this as NDArray[float64] +assert_type(np.block(C), 
npt.NDArray[Any]) + +if sys.version_info >= (3, 12): + from collections.abc import Buffer + + def create_array(obj: npt.ArrayLike) -> npt.NDArray[Any]: ... + + buffer: Buffer + assert_type(create_array(buffer), npt.NDArray[Any]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraypad.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraypad.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c5a443d93fe3de2a02dd6294250df47fb834275b --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraypad.pyi @@ -0,0 +1,22 @@ +from collections.abc import Mapping +from typing import Any, SupportsIndex, assert_type + +import numpy as np +import numpy.typing as npt + +def mode_func( + ar: npt.NDArray[np.number], + width: tuple[int, int], + iaxis: SupportsIndex, + kwargs: Mapping[str, Any], +) -> None: ... + +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_LIKE: list[int] + +assert_type(np.pad(AR_i8, (2, 3), "constant"), npt.NDArray[np.int64]) +assert_type(np.pad(AR_LIKE, (2, 3), "constant"), npt.NDArray[Any]) + +assert_type(np.pad(AR_f8, (2, 3), mode_func), npt.NDArray[np.float64]) +assert_type(np.pad(AR_f8, (2, 3), mode_func, a=1, b=2), npt.NDArray[np.float64]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayprint.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayprint.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3b339edced32a905431bd90d577981e0b5d31bea --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayprint.pyi @@ -0,0 +1,25 @@ +import contextlib +from collections.abc import Callable +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy._core.arrayprint import _FormatOptions + +AR: npt.NDArray[np.int64] +func_float: Callable[[np.floating], str] +func_int: Callable[[np.integer], str] + +assert_type(np.get_printoptions(), _FormatOptions) +assert_type( + np.array2string(AR, formatter={'float_kind': func_float, 'int_kind': func_int}), + str, +) +assert_type(np.format_float_scientific(1.0), str) +assert_type(np.format_float_positional(1), str) +assert_type(np.array_repr(AR), str) +assert_type(np.array_str(AR), str) + +assert_type(np.printoptions(), contextlib._GeneratorContextManager[_FormatOptions]) +with np.printoptions() as dct: + assert_type(dct, _FormatOptions) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraysetops.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraysetops.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7e5ca5c5717baeb13a12af2cb8cfa42469c34fb7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arraysetops.pyi @@ -0,0 +1,74 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy.lib._arraysetops_impl import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, +) + +AR_b: npt.NDArray[np.bool] +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_M: npt.NDArray[np.datetime64] +AR_O: npt.NDArray[np.object_] + +AR_LIKE_f8: list[float] + +assert_type(np.ediff1d(AR_b), npt.NDArray[np.int8]) +assert_type(np.ediff1d(AR_i8, to_end=[1, 2, 3]), npt.NDArray[np.int64]) +assert_type(np.ediff1d(AR_M), npt.NDArray[np.timedelta64]) +assert_type(np.ediff1d(AR_O), npt.NDArray[np.object_]) +assert_type(np.ediff1d(AR_LIKE_f8, 
to_begin=[1, 1.5]), npt.NDArray[Any]) + +assert_type(np.intersect1d(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.intersect1d(AR_M, AR_M, assume_unique=True), npt.NDArray[np.datetime64]) +assert_type(np.intersect1d(AR_f8, AR_i8), npt.NDArray[Any]) +assert_type( + np.intersect1d(AR_f8, AR_f8, return_indices=True), + tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp]], +) + +assert_type(np.setxor1d(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.setxor1d(AR_M, AR_M, assume_unique=True), npt.NDArray[np.datetime64]) +assert_type(np.setxor1d(AR_f8, AR_i8), npt.NDArray[Any]) + +assert_type(np.isin(AR_i8, AR_i8), npt.NDArray[np.bool]) +assert_type(np.isin(AR_M, AR_M, assume_unique=True), npt.NDArray[np.bool]) +assert_type(np.isin(AR_f8, AR_i8), npt.NDArray[np.bool]) +assert_type(np.isin(AR_f8, AR_LIKE_f8, invert=True), npt.NDArray[np.bool]) + +assert_type(np.union1d(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.union1d(AR_M, AR_M), npt.NDArray[np.datetime64]) +assert_type(np.union1d(AR_f8, AR_i8), npt.NDArray[Any]) + +assert_type(np.setdiff1d(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.setdiff1d(AR_M, AR_M, assume_unique=True), npt.NDArray[np.datetime64]) +assert_type(np.setdiff1d(AR_f8, AR_i8), npt.NDArray[Any]) + +assert_type(np.unique(AR_f8), npt.NDArray[np.float64]) +assert_type(np.unique(AR_LIKE_f8, axis=0), npt.NDArray[Any]) +assert_type(np.unique(AR_f8, return_index=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_index=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_inverse=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_inverse=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_index=True, return_inverse=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_index=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_index=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]]) + +assert_type(np.unique_all(AR_f8), UniqueAllResult[np.float64]) +assert_type(np.unique_all(AR_LIKE_f8), UniqueAllResult[Any]) +assert_type(np.unique_counts(AR_f8), UniqueCountsResult[np.float64]) +assert_type(np.unique_counts(AR_LIKE_f8), UniqueCountsResult[Any]) 
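+# Editorial note (not in the upstream file): the unique_* helpers return typed
+# result records; array-like (list) input degrades the element type parameter to Any.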
+assert_type(np.unique_inverse(AR_f8), UniqueInverseResult[np.float64]) +assert_type(np.unique_inverse(AR_LIKE_f8), UniqueInverseResult[Any]) +assert_type(np.unique_values(AR_f8), npt.NDArray[np.float64]) +assert_type(np.unique_values(AR_LIKE_f8), npt.NDArray[Any]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayterator.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayterator.pyi new file mode 100644 index 0000000000000000000000000000000000000000..470160c24de33c4069e7f0fc36596d5cbf95f6de --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/arrayterator.pyi @@ -0,0 +1,27 @@ +from collections.abc import Generator +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_i8: npt.NDArray[np.int64] +ar_iter = np.lib.Arrayterator(AR_i8) + +assert_type(ar_iter.var, npt.NDArray[np.int64]) +assert_type(ar_iter.buf_size, int | None) +assert_type(ar_iter.start, list[int]) +assert_type(ar_iter.stop, list[int]) +assert_type(ar_iter.step, list[int]) +assert_type(ar_iter.shape, tuple[Any, ...]) +assert_type(ar_iter.flat, Generator[np.int64, None, None]) + +assert_type(ar_iter.__array__(), npt.NDArray[np.int64]) + +for i in ar_iter: + assert_type(i, npt.NDArray[np.int64]) + +assert_type(ar_iter[0], np.lib.Arrayterator[tuple[Any, ...], np.dtype[np.int64]]) +assert_type(ar_iter[...], np.lib.Arrayterator[tuple[Any, ...], np.dtype[np.int64]]) +assert_type(ar_iter[:], np.lib.Arrayterator[tuple[Any, ...], np.dtype[np.int64]]) +assert_type(ar_iter[0, 0, 0], np.lib.Arrayterator[tuple[Any, ...], np.dtype[np.int64]]) +assert_type(ar_iter[..., 0, :], np.lib.Arrayterator[tuple[Any, ...], np.dtype[np.int64]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/bitwise_ops.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/bitwise_ops.pyi new file mode 100644 index 0000000000000000000000000000000000000000..cd56caad1527e6e3210444db37087b50a6fee808 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/bitwise_ops.pyi @@ -0,0 +1,167 @@ +from typing import Any, TypeAlias, assert_type +from typing import Literal as L + +import numpy as np +import numpy.typing as npt + +FalseType: TypeAlias = L[False] +TrueType: TypeAlias = L[True] + +i4: np.int32 +i8: np.int64 + +u4: np.uint32 +u8: np.uint64 + +b_: np.bool[bool] +b0_: np.bool[FalseType] +b1_: np.bool[TrueType] + +b: bool +b0: FalseType +b1: TrueType + +i: int + +AR: npt.NDArray[np.int32] + +assert_type(i8 << i8, np.int64) +assert_type(i8 >> i8, np.int64) +assert_type(i8 | i8, np.int64) +assert_type(i8 ^ i8, np.int64) +assert_type(i8 & i8, np.int64) + +assert_type(i8 << AR, npt.NDArray[np.signedinteger]) +assert_type(i8 >> AR, npt.NDArray[np.signedinteger]) +assert_type(i8 | AR, npt.NDArray[np.signedinteger]) +assert_type(i8 ^ AR, npt.NDArray[np.signedinteger]) +assert_type(i8 & AR, npt.NDArray[np.signedinteger]) + +assert_type(i4 << i4, np.int32) +assert_type(i4 >> i4, np.int32) +assert_type(i4 | i4, np.int32) +assert_type(i4 ^ i4, np.int32) +assert_type(i4 & i4, np.int32) + +assert_type(i8 << i4, np.signedinteger) +assert_type(i8 >> i4, np.signedinteger) +assert_type(i8 | i4, np.signedinteger) +assert_type(i8 ^ i4, np.signedinteger) +assert_type(i8 & i4, np.signedinteger) + +assert_type(i8 << b_, np.int64) +assert_type(i8 >> b_, np.int64) +assert_type(i8 | b_, np.int64) +assert_type(i8 ^ b_, np.int64) +assert_type(i8 & b_, np.int64) + +assert_type(i8 << b, np.int64) 
+assert_type(i8 >> b, np.int64) +assert_type(i8 | b, np.int64) +assert_type(i8 ^ b, np.int64) +assert_type(i8 & b, np.int64) + +assert_type(u8 << u8, np.uint64) +assert_type(u8 >> u8, np.uint64) +assert_type(u8 | u8, np.uint64) +assert_type(u8 ^ u8, np.uint64) +assert_type(u8 & u8, np.uint64) + +assert_type(u8 << AR, npt.NDArray[np.signedinteger]) +assert_type(u8 >> AR, npt.NDArray[np.signedinteger]) +assert_type(u8 | AR, npt.NDArray[np.signedinteger]) +assert_type(u8 ^ AR, npt.NDArray[np.signedinteger]) +assert_type(u8 & AR, npt.NDArray[np.signedinteger]) + +assert_type(u4 << u4, np.uint32) +assert_type(u4 >> u4, np.uint32) +assert_type(u4 | u4, np.uint32) +assert_type(u4 ^ u4, np.uint32) +assert_type(u4 & u4, np.uint32) + +assert_type(u4 << i4, np.signedinteger) +assert_type(u4 >> i4, np.signedinteger) +assert_type(u4 | i4, np.signedinteger) +assert_type(u4 ^ i4, np.signedinteger) +assert_type(u4 & i4, np.signedinteger) + +assert_type(u4 << i, np.uint32) +assert_type(u4 >> i, np.uint32) +assert_type(u4 | i, np.uint32) +assert_type(u4 ^ i, np.uint32) +assert_type(u4 & i, np.uint32) + +assert_type(u8 << b_, np.uint64) +assert_type(u8 >> b_, np.uint64) +assert_type(u8 | b_, np.uint64) +assert_type(u8 ^ b_, np.uint64) +assert_type(u8 & b_, np.uint64) + +assert_type(u8 << b, np.uint64) +assert_type(u8 >> b, np.uint64) +assert_type(u8 | b, np.uint64) +assert_type(u8 ^ b, np.uint64) +assert_type(u8 & b, np.uint64) + +assert_type(b_ << b_, np.int8) +assert_type(b_ >> b_, np.int8) +assert_type(b_ | b_, np.bool) +assert_type(b_ ^ b_, np.bool) +assert_type(b_ & b_, np.bool) + +assert_type(b_ << AR, npt.NDArray[np.signedinteger]) +assert_type(b_ >> AR, npt.NDArray[np.signedinteger]) +assert_type(b_ | AR, npt.NDArray[np.signedinteger]) +assert_type(b_ ^ AR, npt.NDArray[np.signedinteger]) +assert_type(b_ & AR, npt.NDArray[np.signedinteger]) + +assert_type(b_ << b, np.int8) +assert_type(b_ >> b, np.int8) +assert_type(b_ | b, np.bool) +assert_type(b_ ^ b, np.bool) +assert_type(b_ & b, np.bool) + +assert_type(b_ << i, np.int_) +assert_type(b_ >> i, np.int_) +assert_type(b_ | i, np.bool | np.int_) +assert_type(b_ ^ i, np.bool | np.int_) +assert_type(b_ & i, np.bool | np.int_) + +assert_type(~i8, np.int64) +assert_type(~i4, np.int32) +assert_type(~u8, np.uint64) +assert_type(~u4, np.uint32) +assert_type(~b_, np.bool) +assert_type(~b0_, np.bool[TrueType]) +assert_type(~b1_, np.bool[FalseType]) +assert_type(~AR, npt.NDArray[np.int32]) + +assert_type(b_ | b0_, np.bool) +assert_type(b0_ | b_, np.bool) +assert_type(b_ | b1_, np.bool[TrueType]) +assert_type(b1_ | b_, np.bool[TrueType]) + +assert_type(b_ ^ b0_, np.bool) +assert_type(b0_ ^ b_, np.bool) +assert_type(b_ ^ b1_, np.bool) +assert_type(b1_ ^ b_, np.bool) + +assert_type(b_ & b0_, np.bool[FalseType]) +assert_type(b0_ & b_, np.bool[FalseType]) +assert_type(b_ & b1_, np.bool) +assert_type(b1_ & b_, np.bool) + +assert_type(b0_ | b0_, np.bool[FalseType]) +assert_type(b0_ | b1_, np.bool[TrueType]) +assert_type(b1_ | b0_, np.bool[TrueType]) +assert_type(b1_ | b1_, np.bool[TrueType]) + +assert_type(b0_ ^ b0_, np.bool[FalseType]) +assert_type(b0_ ^ b1_, np.bool[TrueType]) +assert_type(b1_ ^ b0_, np.bool[TrueType]) +assert_type(b1_ ^ b1_, np.bool[FalseType]) + +assert_type(b0_ & b0_, np.bool[FalseType]) +assert_type(b0_ & b1_, np.bool[FalseType]) +assert_type(b1_ & b0_, np.bool[FalseType]) +assert_type(b1_ & b1_, np.bool[TrueType]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/char.pyi 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/char.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5c6af73888d0f9af658e9010bdcb2b197d9c68b0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/char.pyi @@ -0,0 +1,224 @@ +from typing import TypeAlias, assert_type + +import numpy as np +import numpy._typing as np_t +import numpy.typing as npt + +AR_T_alias: TypeAlias = np.ndarray[np_t._AnyShape, np.dtypes.StringDType] +AR_TU_alias: TypeAlias = AR_T_alias | npt.NDArray[np.str_] + +AR_U: npt.NDArray[np.str_] +AR_S: npt.NDArray[np.bytes_] +AR_T: AR_T_alias + +assert_type(np.char.equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.not_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.not_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.not_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.greater_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.greater_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.greater_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.less_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.less_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.less_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.greater(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.greater(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.greater(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.less(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.char.less(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.char.less(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.multiply(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.char.multiply(AR_S, [5, 4, 3]), npt.NDArray[np.bytes_]) +assert_type(np.char.multiply(AR_T, 5), AR_T_alias) + +assert_type(np.char.mod(AR_U, "test"), npt.NDArray[np.str_]) +assert_type(np.char.mod(AR_S, "test"), npt.NDArray[np.bytes_]) +assert_type(np.char.mod(AR_T, "test"), AR_T_alias) + +assert_type(np.char.capitalize(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.capitalize(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.char.capitalize(AR_T), AR_T_alias) + +assert_type(np.char.center(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.char.center(AR_S, [2, 3, 4], b"a"), npt.NDArray[np.bytes_]) +assert_type(np.char.center(AR_T, 5), AR_T_alias) + +assert_type(np.char.encode(AR_U), npt.NDArray[np.bytes_]) +assert_type(np.char.encode(AR_T), npt.NDArray[np.bytes_]) +assert_type(np.char.decode(AR_S), npt.NDArray[np.str_]) + +assert_type(np.char.expandtabs(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.expandtabs(AR_S, tabsize=4), npt.NDArray[np.bytes_]) +assert_type(np.char.expandtabs(AR_T), AR_T_alias) + +assert_type(np.char.join(AR_U, "_"), npt.NDArray[np.str_]) +assert_type(np.char.join(AR_S, [b"_", b""]), npt.NDArray[np.bytes_]) +assert_type(np.char.join(AR_T, "_"), AR_TU_alias) + +assert_type(np.char.ljust(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.char.ljust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.char.ljust(AR_T, 5), AR_T_alias) +assert_type(np.char.ljust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_TU_alias) + +assert_type(np.char.rjust(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.char.rjust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) 
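+# Editorial note (not in the upstream file): for StringDType input the result
+# stays AR_T_alias, widening to AR_TU_alias when str arguments may yield a
+# fixed-width str_ array instead.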
+assert_type(np.char.rjust(AR_T, 5), AR_T_alias) +assert_type(np.char.rjust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_TU_alias) + +assert_type(np.char.lstrip(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.lstrip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.char.lstrip(AR_T), AR_T_alias) +assert_type(np.char.lstrip(AR_T, "_"), AR_TU_alias) + +assert_type(np.char.rstrip(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.rstrip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.char.rstrip(AR_T), AR_T_alias) +assert_type(np.char.rstrip(AR_T, "_"), AR_TU_alias) + +assert_type(np.char.strip(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.strip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.char.strip(AR_T), AR_T_alias) +assert_type(np.char.strip(AR_T, "_"), AR_TU_alias) + +assert_type(np.char.count(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.count(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.char.count(AR_T, AR_T, start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.count(AR_T, ["a", "b", "c"], end=9), npt.NDArray[np.int_]) + +assert_type(np.char.partition(AR_U, "\n"), npt.NDArray[np.str_]) +assert_type(np.char.partition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.char.partition(AR_T, "\n"), AR_TU_alias) + +assert_type(np.char.rpartition(AR_U, "\n"), npt.NDArray[np.str_]) +assert_type(np.char.rpartition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.char.rpartition(AR_T, "\n"), AR_TU_alias) + +assert_type(np.char.replace(AR_U, "_", "-"), npt.NDArray[np.str_]) +assert_type(np.char.replace(AR_S, [b"_", b""], [b"a", b"b"]), npt.NDArray[np.bytes_]) +assert_type(np.char.replace(AR_T, "_", "_"), AR_TU_alias) + +assert_type(np.char.split(AR_U, "_"), npt.NDArray[np.object_]) +assert_type(np.char.split(AR_S, maxsplit=[1, 2, 3]), npt.NDArray[np.object_]) +assert_type(np.char.split(AR_T, "_"), npt.NDArray[np.object_]) + +assert_type(np.char.rsplit(AR_U, "_"), npt.NDArray[np.object_]) +assert_type(np.char.rsplit(AR_S, maxsplit=[1, 2, 3]), npt.NDArray[np.object_]) +assert_type(np.char.rsplit(AR_T, "_"), npt.NDArray[np.object_]) + +assert_type(np.char.splitlines(AR_U), npt.NDArray[np.object_]) +assert_type(np.char.splitlines(AR_S, keepends=[True, True, False]), npt.NDArray[np.object_]) +assert_type(np.char.splitlines(AR_T), npt.NDArray[np.object_]) + +assert_type(np.char.lower(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.lower(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.char.lower(AR_T), AR_T_alias) + +assert_type(np.char.upper(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.upper(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.char.upper(AR_T), AR_T_alias) + +assert_type(np.char.swapcase(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.swapcase(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.char.swapcase(AR_T), AR_T_alias) + +assert_type(np.char.title(AR_U), npt.NDArray[np.str_]) +assert_type(np.char.title(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.char.title(AR_T), AR_T_alias) + +assert_type(np.char.zfill(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.char.zfill(AR_S, [2, 3, 4]), npt.NDArray[np.bytes_]) +assert_type(np.char.zfill(AR_T, 5), AR_T_alias) + +assert_type(np.char.endswith(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) +assert_type(np.char.endswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) +assert_type(np.char.endswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) + +assert_type(np.char.startswith(AR_U, "a", start=[1, 2, 3]), 
npt.NDArray[np.bool]) +assert_type(np.char.startswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) +assert_type(np.char.startswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) + +assert_type(np.char.find(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.find(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.char.find(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.char.rfind(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.rfind(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.char.rfind(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.char.index(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.index(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.char.index(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.char.rindex(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.char.rindex(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.char.rindex(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.char.isalpha(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isalpha(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.isalpha(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isalnum(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isalnum(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.isalnum(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isdecimal(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isdecimal(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isdigit(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isdigit(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.isdigit(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.islower(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.islower(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.islower(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isnumeric(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isnumeric(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isspace(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isspace(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.isspace(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.istitle(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.istitle(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.istitle(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.isupper(AR_U), npt.NDArray[np.bool]) +assert_type(np.char.isupper(AR_S), npt.NDArray[np.bool]) +assert_type(np.char.isupper(AR_T), npt.NDArray[np.bool]) + +assert_type(np.char.str_len(AR_U), npt.NDArray[np.int_]) +assert_type(np.char.str_len(AR_S), npt.NDArray[np.int_]) +assert_type(np.char.str_len(AR_T), npt.NDArray[np.int_]) + +assert_type(np.char.translate(AR_U, ""), npt.NDArray[np.str_]) +assert_type(np.char.translate(AR_S, ""), npt.NDArray[np.bytes_]) +assert_type(np.char.translate(AR_T, ""), AR_T_alias) + +assert_type(np.char.array(AR_U), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.array(AR_S, order="K"), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.array("bob", copy=True), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.array(b"bob", itemsize=5), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.array(1, unicode=False), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.array(1, unicode=True), 
np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.array(1), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]] | np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.array(AR_U, unicode=False), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.array(AR_S, unicode=True), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) + +assert_type(np.char.asarray(AR_U), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.asarray(AR_S, order="K"), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.asarray("bob"), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.asarray(b"bob", itemsize=5), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.asarray(1, unicode=False), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.asarray(1, unicode=True), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) +assert_type(np.char.asarray(1), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]] | np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.asarray(AR_U, unicode=False), np.char.chararray[np_t._AnyShape, np.dtype[np.bytes_]]) +assert_type(np.char.asarray(AR_S, unicode=True), np.char.chararray[np_t._AnyShape, np.dtype[np.str_]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/chararray.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/chararray.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b5f4392b75c8743682ef7b7d2d4139825d3a90dd --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/chararray.pyi @@ -0,0 +1,137 @@ +from typing import Any, TypeAlias, assert_type + +import numpy as np +import numpy.typing as npt + +_BytesCharArray: TypeAlias = np.char.chararray[tuple[Any, ...], np.dtype[np.bytes_]] +_StrCharArray: TypeAlias = np.char.chararray[tuple[Any, ...], np.dtype[np.str_]] + +AR_U: _StrCharArray +AR_S: _BytesCharArray + +assert_type(AR_U == AR_U, npt.NDArray[np.bool]) +assert_type(AR_S == AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U != AR_U, npt.NDArray[np.bool]) +assert_type(AR_S != AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U >= AR_U, npt.NDArray[np.bool]) +assert_type(AR_S >= AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U <= AR_U, npt.NDArray[np.bool]) +assert_type(AR_S <= AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U > AR_U, npt.NDArray[np.bool]) +assert_type(AR_S > AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U < AR_U, npt.NDArray[np.bool]) +assert_type(AR_S < AR_S, npt.NDArray[np.bool]) + +assert_type(AR_U * 5, _StrCharArray) +assert_type(AR_S * [5], _BytesCharArray) + +assert_type(AR_U % "test", _StrCharArray) +assert_type(AR_S % b"test", _BytesCharArray) + +assert_type(AR_U.capitalize(), _StrCharArray) +assert_type(AR_S.capitalize(), _BytesCharArray) + +assert_type(AR_U.center(5), _StrCharArray) +assert_type(AR_S.center([2, 3, 4], b"a"), _BytesCharArray) + +assert_type(AR_U.encode(), _BytesCharArray) +assert_type(AR_S.decode(), _StrCharArray) + +assert_type(AR_U.expandtabs(), _StrCharArray) +assert_type(AR_S.expandtabs(tabsize=4), _BytesCharArray) + +assert_type(AR_U.join("_"), _StrCharArray) +assert_type(AR_S.join([b"_", b""]), _BytesCharArray) + +assert_type(AR_U.ljust(5), _StrCharArray) +assert_type(AR_S.ljust([4, 3, 1], fillchar=[b"a", b"b", b"c"]), _BytesCharArray) +assert_type(AR_U.rjust(5), _StrCharArray) +assert_type(AR_S.rjust([4, 3, 1], 
fillchar=[b"a", b"b", b"c"]), _BytesCharArray) + +assert_type(AR_U.lstrip(), _StrCharArray) +assert_type(AR_S.lstrip(chars=b"_"), _BytesCharArray) +assert_type(AR_U.rstrip(), _StrCharArray) +assert_type(AR_S.rstrip(chars=b"_"), _BytesCharArray) +assert_type(AR_U.strip(), _StrCharArray) +assert_type(AR_S.strip(chars=b"_"), _BytesCharArray) + +assert_type(AR_U.partition("\n"), _StrCharArray) +assert_type(AR_S.partition([b"a", b"b", b"c"]), _BytesCharArray) +assert_type(AR_U.rpartition("\n"), _StrCharArray) +assert_type(AR_S.rpartition([b"a", b"b", b"c"]), _BytesCharArray) + +assert_type(AR_U.replace("_", "-"), _StrCharArray) +assert_type(AR_S.replace([b"_", b""], [b"a", b"b"]), _BytesCharArray) + +assert_type(AR_U.split("_"), npt.NDArray[np.object_]) +assert_type(AR_S.split(maxsplit=[1, 2, 3]), npt.NDArray[np.object_]) +assert_type(AR_U.rsplit("_"), npt.NDArray[np.object_]) +assert_type(AR_S.rsplit(maxsplit=[1, 2, 3]), npt.NDArray[np.object_]) + +assert_type(AR_U.splitlines(), npt.NDArray[np.object_]) +assert_type(AR_S.splitlines(keepends=[True, True, False]), npt.NDArray[np.object_]) + +assert_type(AR_U.swapcase(), _StrCharArray) +assert_type(AR_S.swapcase(), _BytesCharArray) + +assert_type(AR_U.title(), _StrCharArray) +assert_type(AR_S.title(), _BytesCharArray) + +assert_type(AR_U.upper(), _StrCharArray) +assert_type(AR_S.upper(), _BytesCharArray) + +assert_type(AR_U.zfill(5), _StrCharArray) +assert_type(AR_S.zfill([2, 3, 4]), _BytesCharArray) + +assert_type(AR_U.count("a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(AR_S.count([b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) + +assert_type(AR_U.endswith("a", start=[1, 2, 3]), npt.NDArray[np.bool]) +assert_type(AR_S.endswith([b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) +assert_type(AR_U.startswith("a", start=[1, 2, 3]), npt.NDArray[np.bool]) +assert_type(AR_S.startswith([b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) + +assert_type(AR_U.find("a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(AR_S.find([b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(AR_U.rfind("a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(AR_S.rfind([b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) + +assert_type(AR_U.index("a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(AR_S.index([b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(AR_U.rindex("a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(AR_S.rindex([b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) + +assert_type(AR_U.isalpha(), npt.NDArray[np.bool]) +assert_type(AR_S.isalpha(), npt.NDArray[np.bool]) + +assert_type(AR_U.isalnum(), npt.NDArray[np.bool]) +assert_type(AR_S.isalnum(), npt.NDArray[np.bool]) + +assert_type(AR_U.isdecimal(), npt.NDArray[np.bool]) +assert_type(AR_S.isdecimal(), npt.NDArray[np.bool]) + +assert_type(AR_U.isdigit(), npt.NDArray[np.bool]) +assert_type(AR_S.isdigit(), npt.NDArray[np.bool]) + +assert_type(AR_U.islower(), npt.NDArray[np.bool]) +assert_type(AR_S.islower(), npt.NDArray[np.bool]) + +assert_type(AR_U.isnumeric(), npt.NDArray[np.bool]) +assert_type(AR_S.isnumeric(), npt.NDArray[np.bool]) + +assert_type(AR_U.isspace(), npt.NDArray[np.bool]) +assert_type(AR_S.isspace(), npt.NDArray[np.bool]) + +assert_type(AR_U.istitle(), npt.NDArray[np.bool]) +assert_type(AR_S.istitle(), npt.NDArray[np.bool]) + +assert_type(AR_U.isupper(), npt.NDArray[np.bool]) +assert_type(AR_S.isupper(), npt.NDArray[np.bool]) + +assert_type(AR_U.__array_finalize__(object()), None) +assert_type(AR_S.__array_finalize__(object()), None) diff 
--git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/comparisons.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/comparisons.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2165d17fce349a8e62ce260bcdae0c1a72786c85 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/comparisons.pyi @@ -0,0 +1,264 @@ +import decimal +import fractions +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +c16 = np.complex128() +f8 = np.float64() +i8 = np.int64() +u8 = np.uint64() + +c8 = np.complex64() +f4 = np.float32() +i4 = np.int32() +u4 = np.uint32() + +dt = np.datetime64(0, "D") +td = np.timedelta64(0, "D") + +b_ = np.bool() + +b = bool() +c = complex() +f = float() +i = int() + +AR = np.array([0], dtype=np.int64) +AR.setflags(write=False) + +SEQ = (0, 1, 2, 3, 4) + +# object-like comparisons + +assert_type(i8 > fractions.Fraction(1, 5), np.bool) +assert_type(i8 > [fractions.Fraction(1, 5)], npt.NDArray[np.bool]) +assert_type(i8 > decimal.Decimal("1.5"), np.bool) +assert_type(i8 > [decimal.Decimal("1.5")], npt.NDArray[np.bool]) + +# Time structures + +assert_type(dt > dt, np.bool) + +assert_type(td > td, np.bool) +assert_type(td > i, np.bool) +assert_type(td > i4, np.bool) +assert_type(td > i8, np.bool) + +assert_type(td > AR, npt.NDArray[np.bool]) +assert_type(td > SEQ, npt.NDArray[np.bool]) +assert_type(AR > SEQ, npt.NDArray[np.bool]) +assert_type(AR > td, npt.NDArray[np.bool]) +assert_type(SEQ > td, npt.NDArray[np.bool]) +assert_type(SEQ > AR, npt.NDArray[np.bool]) + +# boolean + +assert_type(b_ > b, np.bool) +assert_type(b_ > b_, np.bool) +assert_type(b_ > i, np.bool) +assert_type(b_ > i8, np.bool) +assert_type(b_ > i4, np.bool) +assert_type(b_ > u8, np.bool) +assert_type(b_ > u4, np.bool) +assert_type(b_ > f, np.bool) +assert_type(b_ > f8, np.bool) +assert_type(b_ > f4, np.bool) +assert_type(b_ > c, np.bool) +assert_type(b_ > c16, np.bool) +assert_type(b_ > c8, np.bool) +assert_type(b_ > AR, npt.NDArray[np.bool]) +assert_type(b_ > SEQ, npt.NDArray[np.bool]) + +# Complex + +assert_type(c16 > c16, np.bool) +assert_type(c16 > f8, np.bool) +assert_type(c16 > i8, np.bool) +assert_type(c16 > c8, np.bool) +assert_type(c16 > f4, np.bool) +assert_type(c16 > i4, np.bool) +assert_type(c16 > b_, np.bool) +assert_type(c16 > b, np.bool) +assert_type(c16 > c, np.bool) +assert_type(c16 > f, np.bool) +assert_type(c16 > i, np.bool) +assert_type(c16 > AR, npt.NDArray[np.bool]) +assert_type(c16 > SEQ, npt.NDArray[np.bool]) + +assert_type(c16 > c16, np.bool) +assert_type(f8 > c16, np.bool) +assert_type(i8 > c16, np.bool) +assert_type(c8 > c16, np.bool) +assert_type(f4 > c16, np.bool) +assert_type(i4 > c16, np.bool) +assert_type(b_ > c16, np.bool) +assert_type(b > c16, np.bool) +assert_type(c > c16, np.bool) +assert_type(f > c16, np.bool) +assert_type(i > c16, np.bool) +assert_type(AR > c16, npt.NDArray[np.bool]) +assert_type(SEQ > c16, npt.NDArray[np.bool]) + +assert_type(c8 > c16, np.bool) +assert_type(c8 > f8, np.bool) +assert_type(c8 > i8, np.bool) +assert_type(c8 > c8, np.bool) +assert_type(c8 > f4, np.bool) +assert_type(c8 > i4, np.bool) +assert_type(c8 > b_, np.bool) +assert_type(c8 > b, np.bool) +assert_type(c8 > c, np.bool) +assert_type(c8 > f, np.bool) +assert_type(c8 > i, np.bool) +assert_type(c8 > AR, npt.NDArray[np.bool]) +assert_type(c8 > SEQ, npt.NDArray[np.bool]) + +assert_type(c16 > c8, np.bool) +assert_type(f8 > c8, np.bool) +assert_type(i8 > c8, np.bool) 
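+# scalar-scalar comparisons reveal np.bool throughout; the AR and SEQ operands broadcast to npt.NDArray[np.bool]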
+assert_type(c8 > c8, np.bool) +assert_type(f4 > c8, np.bool) +assert_type(i4 > c8, np.bool) +assert_type(b_ > c8, np.bool) +assert_type(b > c8, np.bool) +assert_type(c > c8, np.bool) +assert_type(f > c8, np.bool) +assert_type(i > c8, np.bool) +assert_type(AR > c8, npt.NDArray[np.bool]) +assert_type(SEQ > c8, npt.NDArray[np.bool]) + +# Float + +assert_type(f8 > f8, np.bool) +assert_type(f8 > i8, np.bool) +assert_type(f8 > f4, np.bool) +assert_type(f8 > i4, np.bool) +assert_type(f8 > b_, np.bool) +assert_type(f8 > b, np.bool) +assert_type(f8 > c, np.bool) +assert_type(f8 > f, np.bool) +assert_type(f8 > i, np.bool) +assert_type(f8 > AR, npt.NDArray[np.bool]) +assert_type(f8 > SEQ, npt.NDArray[np.bool]) + +assert_type(f8 > f8, np.bool) +assert_type(i8 > f8, np.bool) +assert_type(f4 > f8, np.bool) +assert_type(i4 > f8, np.bool) +assert_type(b_ > f8, np.bool) +assert_type(b > f8, np.bool) +assert_type(c > f8, np.bool) +assert_type(f > f8, np.bool) +assert_type(i > f8, np.bool) +assert_type(AR > f8, npt.NDArray[np.bool]) +assert_type(SEQ > f8, npt.NDArray[np.bool]) + +assert_type(f4 > f8, np.bool) +assert_type(f4 > i8, np.bool) +assert_type(f4 > f4, np.bool) +assert_type(f4 > i4, np.bool) +assert_type(f4 > b_, np.bool) +assert_type(f4 > b, np.bool) +assert_type(f4 > c, np.bool) +assert_type(f4 > f, np.bool) +assert_type(f4 > i, np.bool) +assert_type(f4 > AR, npt.NDArray[np.bool]) +assert_type(f4 > SEQ, npt.NDArray[np.bool]) + +assert_type(f8 > f4, np.bool) +assert_type(i8 > f4, np.bool) +assert_type(f4 > f4, np.bool) +assert_type(i4 > f4, np.bool) +assert_type(b_ > f4, np.bool) +assert_type(b > f4, np.bool) +assert_type(c > f4, np.bool) +assert_type(f > f4, np.bool) +assert_type(i > f4, np.bool) +assert_type(AR > f4, npt.NDArray[np.bool]) +assert_type(SEQ > f4, npt.NDArray[np.bool]) + +# Int + +assert_type(i8 > i8, np.bool) +assert_type(i8 > u8, np.bool) +assert_type(i8 > i4, np.bool) +assert_type(i8 > u4, np.bool) +assert_type(i8 > b_, np.bool) +assert_type(i8 > b, np.bool) +assert_type(i8 > c, np.bool) +assert_type(i8 > f, np.bool) +assert_type(i8 > i, np.bool) +assert_type(i8 > AR, npt.NDArray[np.bool]) +assert_type(i8 > SEQ, npt.NDArray[np.bool]) + +assert_type(u8 > u8, np.bool) +assert_type(u8 > i4, np.bool) +assert_type(u8 > u4, np.bool) +assert_type(u8 > b_, np.bool) +assert_type(u8 > b, np.bool) +assert_type(u8 > c, np.bool) +assert_type(u8 > f, np.bool) +assert_type(u8 > i, np.bool) +assert_type(u8 > AR, npt.NDArray[np.bool]) +assert_type(u8 > SEQ, npt.NDArray[np.bool]) + +assert_type(i8 > i8, np.bool) +assert_type(u8 > i8, np.bool) +assert_type(i4 > i8, np.bool) +assert_type(u4 > i8, np.bool) +assert_type(b_ > i8, np.bool) +assert_type(b > i8, np.bool) +assert_type(c > i8, np.bool) +assert_type(f > i8, np.bool) +assert_type(i > i8, np.bool) +assert_type(AR > i8, npt.NDArray[np.bool]) +assert_type(SEQ > i8, npt.NDArray[np.bool]) + +assert_type(u8 > u8, np.bool) +assert_type(i4 > u8, np.bool) +assert_type(u4 > u8, np.bool) +assert_type(b_ > u8, np.bool) +assert_type(b > u8, np.bool) +assert_type(c > u8, np.bool) +assert_type(f > u8, np.bool) +assert_type(i > u8, np.bool) +assert_type(AR > u8, npt.NDArray[np.bool]) +assert_type(SEQ > u8, npt.NDArray[np.bool]) + +assert_type(i4 > i8, np.bool) +assert_type(i4 > i4, np.bool) +assert_type(i4 > i, np.bool) +assert_type(i4 > b_, np.bool) +assert_type(i4 > b, np.bool) +assert_type(i4 > AR, npt.NDArray[np.bool]) +assert_type(i4 > SEQ, npt.NDArray[np.bool]) + +assert_type(u4 > i8, np.bool) +assert_type(u4 > i4, np.bool) +assert_type(u4 > u8, 
np.bool) +assert_type(u4 > u4, np.bool) +assert_type(u4 > i, np.bool) +assert_type(u4 > b_, np.bool) +assert_type(u4 > b, np.bool) +assert_type(u4 > AR, npt.NDArray[np.bool]) +assert_type(u4 > SEQ, npt.NDArray[np.bool]) + +assert_type(i8 > i4, np.bool) +assert_type(i4 > i4, np.bool) +assert_type(i > i4, np.bool) +assert_type(b_ > i4, np.bool) +assert_type(b > i4, np.bool) +assert_type(AR > i4, npt.NDArray[np.bool]) +assert_type(SEQ > i4, npt.NDArray[np.bool]) + +assert_type(i8 > u4, np.bool) +assert_type(i4 > u4, np.bool) +assert_type(u8 > u4, np.bool) +assert_type(u4 > u4, np.bool) +assert_type(b_ > u4, np.bool) +assert_type(b > u4, np.bool) +assert_type(i > u4, np.bool) +assert_type(AR > u4, npt.NDArray[np.bool]) +assert_type(SEQ > u4, npt.NDArray[np.bool]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/constants.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/constants.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d4474f46ce7e19571125dc45669692a1c501a5b9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/constants.pyi @@ -0,0 +1,14 @@ +from typing import Literal, assert_type + +import numpy as np + +assert_type(np.e, float) +assert_type(np.euler_gamma, float) +assert_type(np.inf, float) +assert_type(np.nan, float) +assert_type(np.pi, float) + +assert_type(np.little_endian, bool) + +assert_type(np.True_, np.bool[Literal[True]]) +assert_type(np.False_, np.bool[Literal[False]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ctypeslib.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ctypeslib.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0564d725cf62f2419460b0f7cdabfd890947438e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ctypeslib.pyi @@ -0,0 +1,81 @@ +import ctypes as ct +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy import ctypeslib + +AR_bool: npt.NDArray[np.bool] +AR_ubyte: npt.NDArray[np.ubyte] +AR_ushort: npt.NDArray[np.ushort] +AR_uintc: npt.NDArray[np.uintc] +AR_ulong: npt.NDArray[np.ulong] +AR_ulonglong: npt.NDArray[np.ulonglong] +AR_byte: npt.NDArray[np.byte] +AR_short: npt.NDArray[np.short] +AR_intc: npt.NDArray[np.intc] +AR_long: npt.NDArray[np.long] +AR_longlong: npt.NDArray[np.longlong] +AR_single: npt.NDArray[np.single] +AR_double: npt.NDArray[np.double] +AR_longdouble: npt.NDArray[np.longdouble] +AR_void: npt.NDArray[np.void] + +pointer: ct._Pointer[Any] + +assert_type(np.ctypeslib.c_intp(), ctypeslib.c_intp) + +assert_type(np.ctypeslib.ndpointer(), type[ctypeslib._ndptr[None]]) +assert_type(np.ctypeslib.ndpointer(dtype=np.float64), type[ctypeslib._ndptr[np.dtype[np.float64]]]) +assert_type(np.ctypeslib.ndpointer(dtype=float), type[ctypeslib._ndptr[np.dtype]]) +assert_type(np.ctypeslib.ndpointer(shape=(10, 3)), type[ctypeslib._ndptr[None]]) +assert_type(np.ctypeslib.ndpointer(np.int64, shape=(10, 3)), type[ctypeslib._concrete_ndptr[np.dtype[np.int64]]]) +assert_type(np.ctypeslib.ndpointer(int, shape=(1,)), type[np.ctypeslib._concrete_ndptr[np.dtype]]) + +assert_type(np.ctypeslib.as_ctypes_type(np.bool), type[ct.c_bool]) +assert_type(np.ctypeslib.as_ctypes_type(np.ubyte), type[ct.c_ubyte]) +assert_type(np.ctypeslib.as_ctypes_type(np.ushort), type[ct.c_ushort]) +assert_type(np.ctypeslib.as_ctypes_type(np.uintc), type[ct.c_uint]) +assert_type(np.ctypeslib.as_ctypes_type(np.byte), 
type[ct.c_byte]) +assert_type(np.ctypeslib.as_ctypes_type(np.short), type[ct.c_short]) +assert_type(np.ctypeslib.as_ctypes_type(np.intc), type[ct.c_int]) +assert_type(np.ctypeslib.as_ctypes_type(np.single), type[ct.c_float]) +assert_type(np.ctypeslib.as_ctypes_type(np.double), type[ct.c_double]) +assert_type(np.ctypeslib.as_ctypes_type(ct.c_double), type[ct.c_double]) +assert_type(np.ctypeslib.as_ctypes_type("q"), type[ct.c_longlong]) +assert_type(np.ctypeslib.as_ctypes_type([("i8", np.int64), ("f8", np.float64)]), type[Any]) +assert_type(np.ctypeslib.as_ctypes_type("i8"), type[Any]) +assert_type(np.ctypeslib.as_ctypes_type("f8"), type[Any]) + +assert_type(np.ctypeslib.as_ctypes(AR_bool.take(0)), ct.c_bool) +assert_type(np.ctypeslib.as_ctypes(AR_ubyte.take(0)), ct.c_ubyte) +assert_type(np.ctypeslib.as_ctypes(AR_ushort.take(0)), ct.c_ushort) +assert_type(np.ctypeslib.as_ctypes(AR_uintc.take(0)), ct.c_uint) + +assert_type(np.ctypeslib.as_ctypes(AR_byte.take(0)), ct.c_byte) +assert_type(np.ctypeslib.as_ctypes(AR_short.take(0)), ct.c_short) +assert_type(np.ctypeslib.as_ctypes(AR_intc.take(0)), ct.c_int) +assert_type(np.ctypeslib.as_ctypes(AR_single.take(0)), ct.c_float) +assert_type(np.ctypeslib.as_ctypes(AR_double.take(0)), ct.c_double) +assert_type(np.ctypeslib.as_ctypes(AR_void.take(0)), Any) +assert_type(np.ctypeslib.as_ctypes(AR_bool), ct.Array[ct.c_bool]) +assert_type(np.ctypeslib.as_ctypes(AR_ubyte), ct.Array[ct.c_ubyte]) +assert_type(np.ctypeslib.as_ctypes(AR_ushort), ct.Array[ct.c_ushort]) +assert_type(np.ctypeslib.as_ctypes(AR_uintc), ct.Array[ct.c_uint]) +assert_type(np.ctypeslib.as_ctypes(AR_byte), ct.Array[ct.c_byte]) +assert_type(np.ctypeslib.as_ctypes(AR_short), ct.Array[ct.c_short]) +assert_type(np.ctypeslib.as_ctypes(AR_intc), ct.Array[ct.c_int]) +assert_type(np.ctypeslib.as_ctypes(AR_single), ct.Array[ct.c_float]) +assert_type(np.ctypeslib.as_ctypes(AR_double), ct.Array[ct.c_double]) +assert_type(np.ctypeslib.as_ctypes(AR_void), ct.Array[Any]) + +assert_type(np.ctypeslib.as_array(AR_ubyte), npt.NDArray[np.ubyte]) +assert_type(np.ctypeslib.as_array(1), npt.NDArray[Any]) +assert_type(np.ctypeslib.as_array(pointer), npt.NDArray[Any]) + +assert_type(np.ctypeslib.as_ctypes_type(np.long), type[ct.c_long]) +assert_type(np.ctypeslib.as_ctypes_type(np.ulong), type[ct.c_ulong]) +assert_type(np.ctypeslib.as_ctypes(AR_ulong), ct.Array[ct.c_ulong]) +assert_type(np.ctypeslib.as_ctypes(AR_long), ct.Array[ct.c_long]) +assert_type(np.ctypeslib.as_ctypes(AR_long.take(0)), ct.c_long) +assert_type(np.ctypeslib.as_ctypes(AR_ulong.take(0)), ct.c_ulong) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/datasource.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/datasource.pyi new file mode 100644 index 0000000000000000000000000000000000000000..9f017911a1ff89dd1f80c059fb13976bc60cd12b --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/datasource.pyi @@ -0,0 +1,23 @@ +from pathlib import Path +from typing import IO, Any, assert_type + +import numpy as np + +path1: Path +path2: str + +d1 = np.lib.npyio.DataSource(path1) +d2 = np.lib.npyio.DataSource(path2) +d3 = np.lib.npyio.DataSource(None) + +assert_type(d1.abspath("..."), str) +assert_type(d2.abspath("..."), str) +assert_type(d3.abspath("..."), str) + +assert_type(d1.exists("..."), bool) +assert_type(d2.exists("..."), bool) +assert_type(d3.exists("..."), bool) + +assert_type(d1.open("...", "r"), IO[Any]) +assert_type(d2.open("...", encoding="utf8"), IO[Any]) 
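+# DataSource.open reveals IO[Any] regardless of the mode, encoding, or newline arguments used here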
+assert_type(d3.open("...", newline="\n"), IO[Any]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/dtype.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/dtype.pyi new file mode 100644 index 0000000000000000000000000000000000000000..721d2708737f7f59ccfc34b691e584cb944dd847 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/dtype.pyi @@ -0,0 +1,136 @@ +import ctypes as ct +import datetime as dt +from decimal import Decimal +from fractions import Fraction +from typing import Any, Literal, LiteralString, TypeAlias, assert_type + +import numpy as np +from numpy.dtypes import StringDType + +# a combination of likely `object` dtype-like candidates (no `_co`) +_PyObjectLike: TypeAlias = Decimal | Fraction | dt.datetime | dt.timedelta + +dtype_U: np.dtype[np.str_] +dtype_V: np.dtype[np.void] +dtype_i8: np.dtype[np.int64] + +py_int_co: type[int] +py_float_co: type[float] +py_complex_co: type[complex] +py_object: type[_PyObjectLike] +py_character: type[str | bytes] +py_flexible: type[str | bytes | memoryview] + +ct_floating: type[ct.c_float | ct.c_double | ct.c_longdouble] +ct_number: type[ct.c_uint8 | ct.c_float] +ct_generic: type[ct.c_bool | ct.c_char] + +# char-code literals, one per abstract scalar category (the codes are representative samples) +cs_integer: Literal["u1", "h", "l", "q"] +cs_number: Literal["u1", "f8", "c16"] +cs_flex: Literal["S", "U", "V"] +cs_generic: Literal["H", "U", "h", "|M8[Y]", "?"] + +dt_inexact: np.dtype[np.inexact] +dt_string: StringDType + +assert_type(np.dtype(np.float64), np.dtype[np.float64]) +assert_type(np.dtype(np.float64, metadata={"test": "test"}), np.dtype[np.float64]) +assert_type(np.dtype(np.int64), np.dtype[np.int64]) + +# String aliases +assert_type(np.dtype("float64"), np.dtype[np.float64]) +assert_type(np.dtype("float32"), np.dtype[np.float32]) +assert_type(np.dtype("int64"), np.dtype[np.int64]) +assert_type(np.dtype("int32"), np.dtype[np.int32]) +assert_type(np.dtype("bool"), np.dtype[np.bool]) +assert_type(np.dtype("bytes"), np.dtype[np.bytes_]) +assert_type(np.dtype("str"), np.dtype[np.str_]) + +# Python types +assert_type(np.dtype(bool), np.dtype[np.bool]) +assert_type(np.dtype(py_int_co), np.dtype[np.int_ | np.bool]) +assert_type(np.dtype(int), np.dtype[np.int_ | np.bool]) +assert_type(np.dtype(py_float_co), np.dtype[np.float64 | np.int_ | np.bool]) +assert_type(np.dtype(float), np.dtype[np.float64 | np.int_ | np.bool]) +assert_type(np.dtype(py_complex_co), np.dtype[np.complex128 | np.float64 | np.int_ | np.bool]) +assert_type(np.dtype(complex), np.dtype[np.complex128 | np.float64 | np.int_ | np.bool]) +assert_type(np.dtype(py_object), np.dtype[np.object_]) +assert_type(np.dtype(str), np.dtype[np.str_]) +assert_type(np.dtype(bytes), np.dtype[np.bytes_]) +assert_type(np.dtype(py_character), np.dtype[np.character]) +assert_type(np.dtype(memoryview), np.dtype[np.void]) +assert_type(np.dtype(py_flexible), np.dtype[np.flexible]) + +assert_type(np.dtype(list), np.dtype[np.object_]) +assert_type(np.dtype(dt.datetime), np.dtype[np.object_]) +assert_type(np.dtype(dt.timedelta), np.dtype[np.object_]) +assert_type(np.dtype(Decimal), np.dtype[np.object_]) +assert_type(np.dtype(Fraction), np.dtype[np.object_]) + +# char-codes +assert_type(np.dtype("?"), np.dtype[np.bool]) +assert_type(np.dtype("|b1"), np.dtype[np.bool]) +assert_type(np.dtype("u1"), np.dtype[np.uint8]) +assert_type(np.dtype("l"), np.dtype[np.long]) +assert_type(np.dtype("longlong"), np.dtype[np.longlong]) +assert_type(np.dtype(">g"), np.dtype[np.longdouble]) +assert_type(np.dtype(cs_integer), np.dtype[np.integer]) +assert_type(np.dtype(cs_number), np.dtype[np.number])
+assert_type(np.dtype(cs_flex), np.dtype[np.flexible]) +assert_type(np.dtype(cs_generic), np.dtype[np.generic]) + +# ctypes +assert_type(np.dtype(ct.c_double), np.dtype[np.double]) +assert_type(np.dtype(ct.c_longlong), np.dtype[np.longlong]) +assert_type(np.dtype(ct.c_uint32), np.dtype[np.uint32]) +assert_type(np.dtype(ct.c_bool), np.dtype[np.bool]) +assert_type(np.dtype(ct.c_char), np.dtype[np.bytes_]) +assert_type(np.dtype(ct.py_object), np.dtype[np.object_]) + +# Special case for None +assert_type(np.dtype(None), np.dtype[np.float64]) + +# Dtypes of dtypes +assert_type(np.dtype(np.dtype(np.float64)), np.dtype[np.float64]) +assert_type(np.dtype(dt_inexact), np.dtype[np.inexact]) + +# Parameterized dtypes +assert_type(np.dtype("S8"), np.dtype) + +# Void +assert_type(np.dtype(("U", 10)), np.dtype[np.void]) + +# StringDType +assert_type(np.dtype(dt_string), StringDType) +assert_type(np.dtype("T"), StringDType) +assert_type(np.dtype("=T"), StringDType) +assert_type(np.dtype("|T"), StringDType) + +# Methods and attributes +assert_type(dtype_U.base, np.dtype) +assert_type(dtype_U.subdtype, tuple[np.dtype, tuple[Any, ...]] | None) +assert_type(dtype_U.newbyteorder(), np.dtype[np.str_]) +assert_type(dtype_U.type, type[np.str_]) +assert_type(dtype_U.name, LiteralString) +assert_type(dtype_U.names, tuple[str, ...] | None) + +assert_type(dtype_U * 0, np.dtype[np.str_]) +assert_type(dtype_U * 1, np.dtype[np.str_]) +assert_type(dtype_U * 2, np.dtype[np.str_]) + +assert_type(dtype_i8 * 0, np.dtype[np.void]) +assert_type(dtype_i8 * 1, np.dtype[np.int64]) +assert_type(dtype_i8 * 2, np.dtype[np.void]) + +assert_type(0 * dtype_U, np.dtype[np.str_]) +assert_type(1 * dtype_U, np.dtype[np.str_]) +assert_type(2 * dtype_U, np.dtype[np.str_]) + +assert_type(0 * dtype_i8, np.dtype) +assert_type(1 * dtype_i8, np.dtype) +assert_type(2 * dtype_i8, np.dtype) + +assert_type(dtype_V["f0"], np.dtype) +assert_type(dtype_V[0], np.dtype) +assert_type(dtype_V[["f0", "f1"]], np.dtype[np.void]) +assert_type(dtype_V[["f0"]], np.dtype[np.void]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/einsumfunc.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/einsumfunc.pyi new file mode 100644 index 0000000000000000000000000000000000000000..cc58f006e249bb5fd520fa45bbe1f1690e4c203d --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/einsumfunc.pyi @@ -0,0 +1,39 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_LIKE_b: list[bool] +AR_LIKE_u: list[np.uint32] +AR_LIKE_i: list[int] +AR_LIKE_f: list[float] +AR_LIKE_c: list[complex] +AR_LIKE_U: list[str] +AR_o: npt.NDArray[np.object_] + +OUT_f: npt.NDArray[np.float64] + +assert_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b), Any) +assert_type(np.einsum("i,i->i", AR_o, AR_o), Any) +assert_type(np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u), Any) +assert_type(np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i), Any) +assert_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f), Any) +assert_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c), Any) +assert_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i), Any) +assert_type(np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c), Any) + +assert_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c, out=OUT_f), npt.NDArray[np.float64]) +assert_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe", out=OUT_f), npt.NDArray[np.float64]) +assert_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16"), Any)
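+# the variants without out= fall back to Any; a typed out= argument pins the revealed type, as in the checks above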
+assert_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe"), Any) + +assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i), tuple[list[Any], str]) +assert_type(np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c), tuple[list[Any], str]) + +assert_type(np.einsum([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), Any) +assert_type(np.einsum_path([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), tuple[list[Any], str]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/emath.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/emath.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1d7bff893e73445dec3f34e5078ebe19a7d893ed --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/emath.pyi @@ -0,0 +1,54 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +f8: np.float64 +c16: np.complex128 + +assert_type(np.emath.sqrt(f8), Any) +assert_type(np.emath.sqrt(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.sqrt(c16), np.complexfloating) +assert_type(np.emath.sqrt(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.log(f8), Any) +assert_type(np.emath.log(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.log(c16), np.complexfloating) +assert_type(np.emath.log(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.log10(f8), Any) +assert_type(np.emath.log10(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.log10(c16), np.complexfloating) +assert_type(np.emath.log10(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.log2(f8), Any) +assert_type(np.emath.log2(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.log2(c16), np.complexfloating) +assert_type(np.emath.log2(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.logn(f8, 2), Any) +assert_type(np.emath.logn(AR_f8, 4), npt.NDArray[Any]) +assert_type(np.emath.logn(f8, 1j), np.complexfloating) +assert_type(np.emath.logn(AR_c16, 1.5), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.power(f8, 2), Any) +assert_type(np.emath.power(AR_f8, 4), npt.NDArray[Any]) +assert_type(np.emath.power(f8, 2j), np.complexfloating) +assert_type(np.emath.power(AR_c16, 1.5), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.arccos(f8), Any) +assert_type(np.emath.arccos(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.arccos(c16), np.complexfloating) +assert_type(np.emath.arccos(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.arcsin(f8), Any) +assert_type(np.emath.arcsin(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.arcsin(c16), np.complexfloating) +assert_type(np.emath.arcsin(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.emath.arctanh(f8), Any) +assert_type(np.emath.arctanh(AR_f8), npt.NDArray[Any]) +assert_type(np.emath.arctanh(c16), np.complexfloating) +assert_type(np.emath.arctanh(AR_c16), npt.NDArray[np.complexfloating]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fft.pyi 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fft.pyi new file mode 100644 index 0000000000000000000000000000000000000000..dacd2b89777c1453088f2fedee82996bfb1e66ca --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fft.pyi @@ -0,0 +1,37 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_LIKE_f8: list[float] + +assert_type(np.fft.fftshift(AR_f8), npt.NDArray[np.float64]) +assert_type(np.fft.fftshift(AR_LIKE_f8, axes=0), npt.NDArray[Any]) + +assert_type(np.fft.ifftshift(AR_f8), npt.NDArray[np.float64]) +assert_type(np.fft.ifftshift(AR_LIKE_f8, axes=0), npt.NDArray[Any]) + +assert_type(np.fft.fftfreq(5, AR_f8), npt.NDArray[np.floating]) +assert_type(np.fft.fftfreq(np.int64(), AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.fft.rfftfreq(5, AR_f8), npt.NDArray[np.floating]) +assert_type(np.fft.rfftfreq(np.int64(), AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.fft.fft(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.ifft(AR_f8, axis=1), npt.NDArray[np.complex128]) +assert_type(np.fft.rfft(AR_f8, n=None), npt.NDArray[np.complex128]) +assert_type(np.fft.irfft(AR_f8, norm="ortho"), npt.NDArray[np.float64]) +assert_type(np.fft.hfft(AR_f8, n=2), npt.NDArray[np.float64]) +assert_type(np.fft.ihfft(AR_f8), npt.NDArray[np.complex128]) + +assert_type(np.fft.fftn(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.ifftn(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.rfftn(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.irfftn(AR_f8), npt.NDArray[np.float64]) + +assert_type(np.fft.rfft2(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.ifft2(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.fft2(AR_f8), npt.NDArray[np.complex128]) +assert_type(np.fft.irfft2(AR_f8), npt.NDArray[np.float64]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/flatiter.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/flatiter.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e188d30fe79f8d883ee45457d5860271b9809ed4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/flatiter.pyi @@ -0,0 +1,47 @@ +from typing import Literal, TypeAlias, assert_type + +import numpy as np +import numpy.typing as npt + +a: np.flatiter[npt.NDArray[np.str_]] +a_1d: np.flatiter[np.ndarray[tuple[int], np.dtype[np.bytes_]]] + +Size: TypeAlias = Literal[42] +a_1d_fixed: np.flatiter[np.ndarray[tuple[Size], np.dtype[np.object_]]] + +assert_type(a.base, npt.NDArray[np.str_]) +assert_type(a.copy(), npt.NDArray[np.str_]) +assert_type(a.coords, tuple[int, ...]) +assert_type(a.index, int) +assert_type(iter(a), np.flatiter[npt.NDArray[np.str_]]) +assert_type(next(a), np.str_) +assert_type(a[0], np.str_) +assert_type(a[[0, 1, 2]], npt.NDArray[np.str_]) +assert_type(a[...], npt.NDArray[np.str_]) +assert_type(a[:], npt.NDArray[np.str_]) +assert_type(a[(...,)], npt.NDArray[np.str_]) +assert_type(a[(0,)], np.str_) + +assert_type(a.__array__(), npt.NDArray[np.str_]) +assert_type(a.__array__(np.dtype(np.float64)), npt.NDArray[np.float64]) +assert_type( + a_1d.__array__(), + np.ndarray[tuple[int], np.dtype[np.bytes_]], +) +assert_type( + a_1d.__array__(np.dtype(np.float64)), + np.ndarray[tuple[int], np.dtype[np.float64]], +) +assert_type( + a_1d_fixed.__array__(), + np.ndarray[tuple[Size], np.dtype[np.object_]], +) +assert_type( +
a_1d_fixed.__array__(np.dtype(np.float64)), + np.ndarray[tuple[Size], np.dtype[np.float64]], +) + +a[0] = "a" +a[:5] = "a" +a[...] = "a" +a[(...,)] = "a" diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fromnumeric.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fromnumeric.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5438e001a13f0b6c0202232a5a34f1314000e568 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/fromnumeric.pyi @@ -0,0 +1,347 @@ +"""Tests for :mod:`_core.fromnumeric`.""" + +from typing import Any, assert_type +from typing import Literal as L + +import numpy as np +import numpy.typing as npt + +class NDArraySubclass(npt.NDArray[np.complex128]): ... + +AR_b: npt.NDArray[np.bool] +AR_f4: npt.NDArray[np.float32] +AR_c16: npt.NDArray[np.complex128] +AR_u8: npt.NDArray[np.uint64] +AR_i8: npt.NDArray[np.int64] +AR_O: npt.NDArray[np.object_] +AR_subclass: NDArraySubclass +AR_m: npt.NDArray[np.timedelta64] +AR_0d: np.ndarray[tuple[()]] +AR_1d: np.ndarray[tuple[int]] +AR_nd: np.ndarray + +b: np.bool +f4: np.float32 +i8: np.int64 +f: float + +# integer-dtype subclass for argmin/argmax +class NDArrayIntSubclass(npt.NDArray[np.intp]): ... +AR_sub_i: NDArrayIntSubclass + +assert_type(np.take(b, 0), np.bool) +assert_type(np.take(f4, 0), np.float32) +assert_type(np.take(f, 0), Any) +assert_type(np.take(AR_b, 0), np.bool) +assert_type(np.take(AR_f4, 0), np.float32) +assert_type(np.take(AR_b, [0]), npt.NDArray[np.bool]) +assert_type(np.take(AR_f4, [0]), npt.NDArray[np.float32]) +assert_type(np.take([1], [0]), npt.NDArray[Any]) +assert_type(np.take(AR_f4, [0], out=AR_subclass), NDArraySubclass) + +assert_type(np.reshape(b, 1), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.reshape(f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.reshape(f, 1), np.ndarray[tuple[int], np.dtype]) +assert_type(np.reshape(AR_b, 1), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.reshape(AR_f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]]) + +assert_type(np.choose(1, [True, True]), Any) +assert_type(np.choose([1], [True, True]), npt.NDArray[Any]) +assert_type(np.choose([1], AR_b), npt.NDArray[np.bool]) +assert_type(np.choose([1], AR_b, out=AR_f4), npt.NDArray[np.float32]) + +assert_type(np.repeat(b, 1), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.repeat(b, 1, axis=0), npt.NDArray[np.bool]) +assert_type(np.repeat(f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.repeat(f, 1), np.ndarray[tuple[int], np.dtype[Any]]) +assert_type(np.repeat(AR_b, 1), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.repeat(AR_f4, 1), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.repeat(AR_f4, 1, axis=0), npt.NDArray[np.float32]) + +# TODO: array_bdd tests for np.put() + +assert_type(np.swapaxes([[0, 1]], 0, 0), npt.NDArray[Any]) +assert_type(np.swapaxes(AR_b, 0, 0), npt.NDArray[np.bool]) +assert_type(np.swapaxes(AR_f4, 0, 0), npt.NDArray[np.float32]) + +assert_type(np.transpose(b), npt.NDArray[np.bool]) +assert_type(np.transpose(f4), npt.NDArray[np.float32]) +assert_type(np.transpose(f), npt.NDArray[Any]) +assert_type(np.transpose(AR_b), npt.NDArray[np.bool]) +assert_type(np.transpose(AR_f4), npt.NDArray[np.float32]) + +assert_type(np.partition(b, 0, axis=None), npt.NDArray[np.bool]) +assert_type(np.partition(f4, 0, axis=None), npt.NDArray[np.float32]) +assert_type(np.partition(f, 0, axis=None), npt.NDArray[Any])
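+# the axis=None variants above flatten their input first, which is why even 0-d scalars reveal array types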
+assert_type(np.partition(AR_b, 0), npt.NDArray[np.bool]) +assert_type(np.partition(AR_f4, 0), npt.NDArray[np.float32]) + +assert_type(np.argpartition(b, 0), npt.NDArray[np.intp]) +assert_type(np.argpartition(f4, 0), npt.NDArray[np.intp]) +assert_type(np.argpartition(f, 0), npt.NDArray[np.intp]) +assert_type(np.argpartition(AR_b, 0), npt.NDArray[np.intp]) +assert_type(np.argpartition(AR_f4, 0), npt.NDArray[np.intp]) + +assert_type(np.sort([2, 1], 0), npt.NDArray[Any]) +assert_type(np.sort(AR_b, 0), npt.NDArray[np.bool]) +assert_type(np.sort(AR_f4, 0), npt.NDArray[np.float32]) + +assert_type(np.argsort(AR_b, 0), npt.NDArray[np.intp]) +assert_type(np.argsort(AR_f4, 0), npt.NDArray[np.intp]) + +assert_type(np.argmax(AR_b), np.intp) +assert_type(np.argmax(AR_f4), np.intp) +assert_type(np.argmax(AR_b, axis=0), Any) +assert_type(np.argmax(AR_f4, axis=0), Any) +assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArrayIntSubclass) + +assert_type(np.argmin(AR_b), np.intp) +assert_type(np.argmin(AR_f4), np.intp) +assert_type(np.argmin(AR_b, axis=0), Any) +assert_type(np.argmin(AR_f4, axis=0), Any) +assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArrayIntSubclass) + +assert_type(np.searchsorted(AR_b[0], 0), np.intp) +assert_type(np.searchsorted(AR_f4[0], 0), np.intp) +assert_type(np.searchsorted(AR_b[0], [0]), npt.NDArray[np.intp]) +assert_type(np.searchsorted(AR_f4[0], [0]), npt.NDArray[np.intp]) + +assert_type(np.resize(b, (5, 5)), np.ndarray[tuple[int, int], np.dtype[np.bool]]) +assert_type(np.resize(f4, (5, 5)), np.ndarray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.resize(f, (5, 5)), np.ndarray[tuple[int, int], np.dtype]) +assert_type(np.resize(AR_b, (5, 5)), np.ndarray[tuple[int, int], np.dtype[np.bool]]) +assert_type(np.resize(AR_f4, (5, 5)), np.ndarray[tuple[int, int], np.dtype[np.float32]]) + +assert_type(np.squeeze(b), np.bool) +assert_type(np.squeeze(f4), np.float32) +assert_type(np.squeeze(f), npt.NDArray[Any]) +assert_type(np.squeeze(AR_b), npt.NDArray[np.bool]) +assert_type(np.squeeze(AR_f4), npt.NDArray[np.float32]) + +assert_type(np.diagonal(AR_b), npt.NDArray[np.bool]) +assert_type(np.diagonal(AR_f4), npt.NDArray[np.float32]) + +assert_type(np.trace(AR_b), Any) +assert_type(np.trace(AR_f4), Any) +assert_type(np.trace(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.ravel(b), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.ravel(f4), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.ravel(f), np.ndarray[tuple[int], np.dtype[np.float64 | np.int_ | np.bool]]) +assert_type(np.ravel(AR_b), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(np.ravel(AR_f4), np.ndarray[tuple[int], np.dtype[np.float32]]) + +assert_type(np.nonzero(AR_b), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.nonzero(AR_f4), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.nonzero(AR_1d), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.nonzero(AR_nd), tuple[npt.NDArray[np.intp], ...]) + +assert_type(np.shape(b), tuple[()]) +assert_type(np.shape(f), tuple[()]) +assert_type(np.shape([1]), tuple[int]) +assert_type(np.shape([[2]]), tuple[int, int]) +assert_type(np.shape([[[3]]]), tuple[Any, ...]) +assert_type(np.shape(AR_b), tuple[Any, ...]) +assert_type(np.shape(AR_nd), tuple[Any, ...]) +# these fail on mypy, but it works as expected with pyright/pylance +# assert_type(np.shape(AR_0d), tuple[()]) +# assert_type(np.shape(AR_1d), tuple[int]) +# assert_type(np.shape(AR_2d), tuple[int, int]) + +assert_type(np.compress([True], b), npt.NDArray[np.bool]) 
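+# np.compress keeps the element type of typed inputs; the plain Python float below falls back to NDArray[Any]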
+assert_type(np.compress([True], f4), npt.NDArray[np.float32]) +assert_type(np.compress([True], f), npt.NDArray[Any]) +assert_type(np.compress([True], AR_b), npt.NDArray[np.bool]) +assert_type(np.compress([True], AR_f4), npt.NDArray[np.float32]) + +assert_type(np.clip(b, 0, 1.0), np.bool) +assert_type(np.clip(f4, -1, 1), np.float32) +assert_type(np.clip(f, 0, 1), Any) +assert_type(np.clip(AR_b, 0, 1), npt.NDArray[np.bool]) +assert_type(np.clip(AR_f4, 0, 1), npt.NDArray[np.float32]) +assert_type(np.clip([0], 0, 1), npt.NDArray[Any]) +assert_type(np.clip(AR_b, 0, 1, out=AR_subclass), NDArraySubclass) + +assert_type(np.sum(b), np.bool) +assert_type(np.sum(f4), np.float32) +assert_type(np.sum(f), Any) +assert_type(np.sum(AR_b), np.bool) +assert_type(np.sum(AR_f4), np.float32) +assert_type(np.sum(AR_b, axis=0), Any) +assert_type(np.sum(AR_f4, axis=0), Any) +assert_type(np.sum(AR_f4, out=AR_subclass), NDArraySubclass) +assert_type(np.sum(AR_f4, dtype=np.float64), np.float64) +assert_type(np.sum(AR_f4, None, np.float64), np.float64) +assert_type(np.sum(AR_f4, dtype=np.float64, keepdims=False), np.float64) +assert_type(np.sum(AR_f4, None, np.float64, keepdims=False), np.float64) +assert_type(np.sum(AR_f4, dtype=np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64]) +assert_type(np.sum(AR_f4, None, np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64]) + +assert_type(np.all(b), np.bool) +assert_type(np.all(f4), np.bool) +assert_type(np.all(f), np.bool) +assert_type(np.all(AR_b), np.bool) +assert_type(np.all(AR_f4), np.bool) +assert_type(np.all(AR_b, axis=0), Any) +assert_type(np.all(AR_f4, axis=0), Any) +assert_type(np.all(AR_b, keepdims=True), Any) +assert_type(np.all(AR_f4, keepdims=True), Any) +assert_type(np.all(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.any(b), np.bool) +assert_type(np.any(f4), np.bool) +assert_type(np.any(f), np.bool) +assert_type(np.any(AR_b), np.bool) +assert_type(np.any(AR_f4), np.bool) +assert_type(np.any(AR_b, axis=0), Any) +assert_type(np.any(AR_f4, axis=0), Any) +assert_type(np.any(AR_b, keepdims=True), Any) +assert_type(np.any(AR_f4, keepdims=True), Any) +assert_type(np.any(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.cumsum(b), npt.NDArray[np.bool]) +assert_type(np.cumsum(f4), npt.NDArray[np.float32]) +assert_type(np.cumsum(f), npt.NDArray[Any]) +assert_type(np.cumsum(AR_b), npt.NDArray[np.bool]) +assert_type(np.cumsum(AR_f4), npt.NDArray[np.float32]) +assert_type(np.cumsum(f, dtype=float), npt.NDArray[Any]) +assert_type(np.cumsum(f, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.cumsum(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.cumulative_sum(b), npt.NDArray[np.bool]) +assert_type(np.cumulative_sum(f4), npt.NDArray[np.float32]) +assert_type(np.cumulative_sum(f), npt.NDArray[Any]) +assert_type(np.cumulative_sum(AR_b), npt.NDArray[np.bool]) +assert_type(np.cumulative_sum(AR_f4), npt.NDArray[np.float32]) +assert_type(np.cumulative_sum(f, dtype=float), npt.NDArray[Any]) +assert_type(np.cumulative_sum(f, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.cumulative_sum(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.ptp(b), np.bool) +assert_type(np.ptp(f4), np.float32) +assert_type(np.ptp(f), Any) +assert_type(np.ptp(AR_b), np.bool) +assert_type(np.ptp(AR_f4), np.float32) +assert_type(np.ptp(AR_b, axis=0), Any) +assert_type(np.ptp(AR_f4, axis=0), Any) +assert_type(np.ptp(AR_b, keepdims=True), Any) +assert_type(np.ptp(AR_f4, keepdims=True), Any) 
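+# as with the reductions above, an out= argument makes np.ptp reveal the exact array subclass passed in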
+assert_type(np.ptp(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.amax(b), np.bool) +assert_type(np.amax(f4), np.float32) +assert_type(np.amax(f), Any) +assert_type(np.amax(AR_b), np.bool) +assert_type(np.amax(AR_f4), np.float32) +assert_type(np.amax(AR_b, axis=0), Any) +assert_type(np.amax(AR_f4, axis=0), Any) +assert_type(np.amax(AR_b, keepdims=True), Any) +assert_type(np.amax(AR_f4, keepdims=True), Any) +assert_type(np.amax(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.amin(b), np.bool) +assert_type(np.amin(f4), np.float32) +assert_type(np.amin(f), Any) +assert_type(np.amin(AR_b), np.bool) +assert_type(np.amin(AR_f4), np.float32) +assert_type(np.amin(AR_b, axis=0), Any) +assert_type(np.amin(AR_f4, axis=0), Any) +assert_type(np.amin(AR_b, keepdims=True), Any) +assert_type(np.amin(AR_f4, keepdims=True), Any) +assert_type(np.amin(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.prod(AR_b), np.int_) +assert_type(np.prod(AR_u8), np.uint64) +assert_type(np.prod(AR_i8), np.int64) +assert_type(np.prod(AR_f4), np.floating) +assert_type(np.prod(AR_c16), np.complexfloating) +assert_type(np.prod(AR_O), Any) +assert_type(np.prod(AR_f4, axis=0), Any) +assert_type(np.prod(AR_f4, keepdims=True), Any) +assert_type(np.prod(AR_f4, dtype=np.float64), np.float64) +assert_type(np.prod(AR_f4, dtype=float), Any) +assert_type(np.prod(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.cumprod(AR_b), npt.NDArray[np.int_]) +assert_type(np.cumprod(AR_u8), npt.NDArray[np.uint64]) +assert_type(np.cumprod(AR_i8), npt.NDArray[np.int64]) +assert_type(np.cumprod(AR_f4), npt.NDArray[np.floating]) +assert_type(np.cumprod(AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.cumprod(AR_O), npt.NDArray[np.object_]) +assert_type(np.cumprod(AR_f4, axis=0), npt.NDArray[np.floating]) +assert_type(np.cumprod(AR_f4, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.cumprod(AR_f4, dtype=float), npt.NDArray[Any]) +assert_type(np.cumprod(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.cumulative_prod(AR_b), npt.NDArray[np.int_]) +assert_type(np.cumulative_prod(AR_u8), npt.NDArray[np.uint64]) +assert_type(np.cumulative_prod(AR_i8), npt.NDArray[np.int64]) +assert_type(np.cumulative_prod(AR_f4), npt.NDArray[np.floating]) +assert_type(np.cumulative_prod(AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.cumulative_prod(AR_O), npt.NDArray[np.object_]) +assert_type(np.cumulative_prod(AR_f4, axis=0), npt.NDArray[np.floating]) +assert_type(np.cumulative_prod(AR_f4, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.cumulative_prod(AR_f4, dtype=float), npt.NDArray[Any]) +assert_type(np.cumulative_prod(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.ndim(b), int) +assert_type(np.ndim(f4), int) +assert_type(np.ndim(f), int) +assert_type(np.ndim(AR_b), int) +assert_type(np.ndim(AR_f4), int) + +assert_type(np.size(b), int) +assert_type(np.size(f4), int) +assert_type(np.size(f), int) +assert_type(np.size(AR_b), int) +assert_type(np.size(AR_f4), int) + +assert_type(np.around(b), np.float16) +assert_type(np.around(f), Any) +assert_type(np.around(i8), np.int64) +assert_type(np.around(f4), np.float32) +assert_type(np.around(AR_b), npt.NDArray[np.float16]) +assert_type(np.around(AR_i8), npt.NDArray[np.int64]) +assert_type(np.around(AR_f4), npt.NDArray[np.float32]) +assert_type(np.around([1.5]), npt.NDArray[Any]) +assert_type(np.around(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.mean(AR_b), np.floating) +assert_type(np.mean(AR_i8), 
np.floating) +assert_type(np.mean(AR_f4), np.floating) +assert_type(np.mean(AR_m), np.timedelta64) +assert_type(np.mean(AR_c16), np.complexfloating) +assert_type(np.mean(AR_O), Any) +assert_type(np.mean(AR_f4, axis=0), Any) +assert_type(np.mean(AR_f4, keepdims=True), Any) +assert_type(np.mean(AR_f4, dtype=float), Any) +assert_type(np.mean(AR_f4, dtype=np.float64), np.float64) +assert_type(np.mean(AR_f4, out=AR_subclass), NDArraySubclass) +assert_type(np.mean(AR_f4, dtype=np.float64), np.float64) +assert_type(np.mean(AR_f4, None, np.float64), np.float64) +assert_type(np.mean(AR_f4, dtype=np.float64, keepdims=False), np.float64) +assert_type(np.mean(AR_f4, None, np.float64, keepdims=False), np.float64) +assert_type(np.mean(AR_f4, dtype=np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64]) +assert_type(np.mean(AR_f4, None, np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64]) + +assert_type(np.std(AR_b), np.floating) +assert_type(np.std(AR_i8), np.floating) +assert_type(np.std(AR_f4), np.floating) +assert_type(np.std(AR_c16), np.floating) +assert_type(np.std(AR_O), Any) +assert_type(np.std(AR_f4, axis=0), Any) +assert_type(np.std(AR_f4, keepdims=True), Any) +assert_type(np.std(AR_f4, dtype=float), Any) +assert_type(np.std(AR_f4, dtype=np.float64), np.float64) +assert_type(np.std(AR_f4, out=AR_subclass), NDArraySubclass) + +assert_type(np.var(AR_b), np.floating) +assert_type(np.var(AR_i8), np.floating) +assert_type(np.var(AR_f4), np.floating) +assert_type(np.var(AR_c16), np.floating) +assert_type(np.var(AR_O), Any) +assert_type(np.var(AR_f4, axis=0), Any) +assert_type(np.var(AR_f4, keepdims=True), Any) +assert_type(np.var(AR_f4, dtype=float), Any) +assert_type(np.var(AR_f4, dtype=np.float64), np.float64) +assert_type(np.var(AR_f4, out=AR_subclass), NDArraySubclass) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/getlimits.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/getlimits.pyi new file mode 100644 index 0000000000000000000000000000000000000000..825daba43064b33a5280f8a335c49d5b6e728cdc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/getlimits.pyi @@ -0,0 +1,51 @@ +from typing import Any, LiteralString, assert_type + +import numpy as np +from numpy._typing import _64Bit + +f: float +f8: np.float64 +c8: np.complex64 + +i: int +i8: np.int64 +u4: np.uint32 + +finfo_f8: np.finfo[np.float64] +iinfo_i8: np.iinfo[np.int64] + +assert_type(np.finfo(f), np.finfo[np.float64]) +assert_type(np.finfo(f8), np.finfo[np.floating[_64Bit]]) +assert_type(np.finfo(c8), np.finfo[np.float32]) +assert_type(np.finfo('f2'), np.finfo[np.floating]) + +assert_type(finfo_f8.dtype, np.dtype[np.float64]) +assert_type(finfo_f8.bits, int) +assert_type(finfo_f8.eps, np.float64) +assert_type(finfo_f8.epsneg, np.float64) +assert_type(finfo_f8.iexp, int) +assert_type(finfo_f8.machep, int) +assert_type(finfo_f8.max, np.float64) +assert_type(finfo_f8.maxexp, int) +assert_type(finfo_f8.min, np.float64) +assert_type(finfo_f8.minexp, int) +assert_type(finfo_f8.negep, int) +assert_type(finfo_f8.nexp, int) +assert_type(finfo_f8.nmant, int) +assert_type(finfo_f8.precision, int) +assert_type(finfo_f8.resolution, np.float64) +assert_type(finfo_f8.tiny, np.float64) +assert_type(finfo_f8.smallest_normal, np.float64) +assert_type(finfo_f8.smallest_subnormal, np.float64) + +assert_type(np.iinfo(i), np.iinfo[np.int_]) +assert_type(np.iinfo(i8), np.iinfo[np.int64]) +assert_type(np.iinfo(u4), np.iinfo[np.uint32]) 
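+# dtype character codes cannot be narrowed statically, so the 'i2' form below falls back to np.iinfo[Any]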
+assert_type(np.iinfo('i2'), np.iinfo[Any]) + +assert_type(iinfo_i8.dtype, np.dtype[np.int64]) +assert_type(iinfo_i8.kind, LiteralString) +assert_type(iinfo_i8.bits, int) +assert_type(iinfo_i8.key, LiteralString) +assert_type(iinfo_i8.min, int) +assert_type(iinfo_i8.max, int) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/histograms.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/histograms.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c1c63d59cb8805237e1d4907f2e90ef6f6ee4cfd --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/histograms.pyi @@ -0,0 +1,25 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] + +assert_type(np.histogram_bin_edges(AR_i8, bins="auto"), npt.NDArray[Any]) +assert_type(np.histogram_bin_edges(AR_i8, bins="rice", range=(0, 3)), npt.NDArray[Any]) +assert_type(np.histogram_bin_edges(AR_i8, bins="scott", weights=AR_f8), npt.NDArray[Any]) + +assert_type(np.histogram(AR_i8, bins="auto"), tuple[npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.histogram(AR_i8, bins="rice", range=(0, 3)), tuple[npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.histogram(AR_i8, bins="scott", weights=AR_f8), tuple[npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.histogram(AR_f8, bins=1, density=True), tuple[npt.NDArray[Any], npt.NDArray[Any]]) + +assert_type(np.histogramdd(AR_i8, bins=[1]), + tuple[npt.NDArray[Any], tuple[npt.NDArray[Any], ...]]) +assert_type(np.histogramdd(AR_i8, range=[(0, 3)]), + tuple[npt.NDArray[Any], tuple[npt.NDArray[Any], ...]]) +assert_type(np.histogramdd(AR_i8, weights=AR_f8), + tuple[npt.NDArray[Any], tuple[npt.NDArray[Any], ...]]) +assert_type(np.histogramdd(AR_f8, density=True), + tuple[npt.NDArray[Any], tuple[npt.NDArray[Any], ...]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/index_tricks.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/index_tricks.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f6067c3bed6bfd0adecab400a1028f40af479b24 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/index_tricks.pyi @@ -0,0 +1,70 @@ +from types import EllipsisType +from typing import Any, Literal, assert_type + +import numpy as np +import numpy.typing as npt + +AR_LIKE_b: list[bool] +AR_LIKE_i: list[int] +AR_LIKE_f: list[float] +AR_LIKE_U: list[str] +AR_LIKE_O: list[object] + +AR_i8: npt.NDArray[np.int64] +AR_O: npt.NDArray[np.object_] + +assert_type(np.ndenumerate(AR_i8), np.ndenumerate[np.int64]) +assert_type(np.ndenumerate(AR_LIKE_f), np.ndenumerate[np.float64]) +assert_type(np.ndenumerate(AR_LIKE_U), np.ndenumerate[np.str_]) +assert_type(np.ndenumerate(AR_LIKE_O), np.ndenumerate[Any]) + +assert_type(next(np.ndenumerate(AR_i8)), tuple[tuple[Any, ...], np.int64]) +assert_type(next(np.ndenumerate(AR_LIKE_f)), tuple[tuple[Any, ...], np.float64]) +assert_type(next(np.ndenumerate(AR_LIKE_U)), tuple[tuple[Any, ...], np.str_]) +assert_type(next(np.ndenumerate(AR_LIKE_O)), tuple[tuple[Any, ...], Any]) + +assert_type(iter(np.ndenumerate(AR_i8)), np.ndenumerate[np.int64]) +assert_type(iter(np.ndenumerate(AR_LIKE_f)), np.ndenumerate[np.float64]) +assert_type(iter(np.ndenumerate(AR_LIKE_U)), np.ndenumerate[np.str_]) +assert_type(iter(np.ndenumerate(AR_LIKE_O)), np.ndenumerate[Any]) + +assert_type(np.ndindex(1, 2, 3), np.ndindex) 
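+# np.ndindex accepts the dimensions either as separate arguments or as a single shape tuple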
+assert_type(np.ndindex((1, 2, 3)), np.ndindex) +assert_type(iter(np.ndindex(1, 2, 3)), np.ndindex) +assert_type(next(np.ndindex(1, 2, 3)), tuple[Any, ...]) + +assert_type(np.unravel_index([22, 41, 37], (7, 6)), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.unravel_index([31, 41, 13], (7, 6), order="F"), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.unravel_index(1621, (6, 7, 8, 9)), tuple[np.intp, ...]) + +assert_type(np.ravel_multi_index([[1]], (7, 6)), npt.NDArray[np.intp]) +assert_type(np.ravel_multi_index(AR_LIKE_i, (7, 6)), np.intp) +assert_type(np.ravel_multi_index(AR_LIKE_i, (7, 6), order="F"), np.intp) +assert_type(np.ravel_multi_index(AR_LIKE_i, (4, 6), mode="clip"), np.intp) +assert_type(np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=("clip", "wrap")), np.intp) +assert_type(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)), np.intp) + +assert_type(np.mgrid[1:1:2], npt.NDArray[Any]) +assert_type(np.mgrid[1:1:2, None:10], npt.NDArray[Any]) + +assert_type(np.ogrid[1:1:2], tuple[npt.NDArray[Any], ...]) +assert_type(np.ogrid[1:1:2, None:10], tuple[npt.NDArray[Any], ...]) + +assert_type(np.index_exp[0:1], tuple[slice[int, int, None]]) +assert_type(np.index_exp[0:1, None:3], tuple[slice[int, int, None], slice[None, int, None]]) +assert_type(np.index_exp[0, 0:1, ..., [0, 1, 3]], tuple[Literal[0], slice[int, int, None], EllipsisType, list[int]]) + +assert_type(np.s_[0:1], slice[int, int, None]) +assert_type(np.s_[0:1, None:3], tuple[slice[int, int, None], slice[None, int, None]]) +assert_type(np.s_[0, 0:1, ..., [0, 1, 3]], tuple[Literal[0], slice[int, int, None], EllipsisType, list[int]]) + +assert_type(np.ix_(AR_LIKE_b), tuple[npt.NDArray[np.bool], ...]) +assert_type(np.ix_(AR_LIKE_i, AR_LIKE_f), tuple[npt.NDArray[np.float64], ...]) +assert_type(np.ix_(AR_i8), tuple[npt.NDArray[np.int64], ...]) + +assert_type(np.fill_diagonal(AR_i8, 5), None) + +assert_type(np.diag_indices(4), tuple[npt.NDArray[np.int_], ...]) +assert_type(np.diag_indices(2, 3), tuple[npt.NDArray[np.int_], ...]) + +assert_type(np.diag_indices_from(AR_i8), tuple[npt.NDArray[np.int_], ...]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_function_base.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_function_base.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3ce8d375201bfd5c6d77668c59e9284eddc0c89f --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_function_base.pyi @@ -0,0 +1,213 @@ +from collections.abc import Callable +from fractions import Fraction +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +vectorized_func: np.vectorize + +f8: np.float64 +AR_LIKE_f8: list[float] +AR_LIKE_c16: list[complex] +AR_LIKE_O: list[Fraction] + +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] +AR_O: npt.NDArray[np.object_] +AR_b: npt.NDArray[np.bool] +AR_U: npt.NDArray[np.str_] +CHAR_AR_U: np.char.chararray[tuple[Any, ...], np.dtype[np.str_]] + +AR_b_list: list[npt.NDArray[np.bool]] + +def func( + a: npt.NDArray[Any], + posarg: bool = ..., + /, + arg: int = ..., + *, + kwarg: str = ..., +) -> npt.NDArray[Any]: ... 
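A quick sketch of the unravel_index/ravel_multi_index round trip asserted above (indices borrowed from the numpy docs example): scalar flat indices yield one intp per axis, array inputs yield one intp array per axis.

import numpy as np

print(np.unravel_index(1621, (6, 7, 8, 9)))              # (3, 1, 4, 1): one intp scalar per axis
print(np.unravel_index([22, 41, 37], (7, 6)))            # a tuple of intp arrays
print(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)))  # 1621: the inverse mapping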
+ +assert_type(vectorized_func.pyfunc, Callable[..., Any]) +assert_type(vectorized_func.cache, bool) +assert_type(vectorized_func.signature, str | None) +assert_type(vectorized_func.otypes, str | None) +assert_type(vectorized_func.excluded, set[int | str]) +assert_type(vectorized_func.__doc__, str | None) +assert_type(vectorized_func([1]), Any) +assert_type(np.vectorize(int), np.vectorize) +assert_type( + np.vectorize(int, otypes="i", doc="doc", excluded=(), cache=True, signature=None), + np.vectorize, +) + +assert_type(np.rot90(AR_f8, k=2), npt.NDArray[np.float64]) +assert_type(np.rot90(AR_LIKE_f8, axes=(0, 1)), npt.NDArray[Any]) + +assert_type(np.flip(f8), np.float64) +assert_type(np.flip(1.0), Any) +assert_type(np.flip(AR_f8, axis=(0, 1)), npt.NDArray[np.float64]) +assert_type(np.flip(AR_LIKE_f8, axis=0), npt.NDArray[Any]) + +assert_type(np.iterable(1), bool) +assert_type(np.iterable([1]), bool) + +assert_type(np.average(AR_f8), np.floating) +assert_type(np.average(AR_f8, weights=AR_c16), np.complexfloating) +assert_type(np.average(AR_O), Any) +assert_type(np.average(AR_f8, returned=True), tuple[np.floating, np.floating]) +assert_type(np.average(AR_f8, weights=AR_c16, returned=True), tuple[np.complexfloating, np.complexfloating]) +assert_type(np.average(AR_O, returned=True), tuple[Any, Any]) +assert_type(np.average(AR_f8, axis=0), Any) +assert_type(np.average(AR_f8, axis=0, returned=True), tuple[Any, Any]) + +assert_type(np.asarray_chkfinite(AR_f8), npt.NDArray[np.float64]) +assert_type(np.asarray_chkfinite(AR_LIKE_f8), npt.NDArray[Any]) +assert_type(np.asarray_chkfinite(AR_f8, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.asarray_chkfinite(AR_f8, dtype=float), npt.NDArray[Any]) + +assert_type(np.piecewise(AR_f8, AR_b, [func]), npt.NDArray[np.float64]) +assert_type(np.piecewise(AR_f8, AR_b_list, [func]), npt.NDArray[np.float64]) +assert_type(np.piecewise(AR_f8, AR_b_list, [func], True, -1, kwarg=''), npt.NDArray[np.float64]) +assert_type(np.piecewise(AR_f8, AR_b_list, [func], True, arg=-1, kwarg=''), npt.NDArray[np.float64]) +assert_type(np.piecewise(AR_LIKE_f8, AR_b_list, [func]), npt.NDArray[Any]) + +assert_type(np.select([AR_f8], [AR_f8]), npt.NDArray[Any]) + +assert_type(np.copy(AR_LIKE_f8), npt.NDArray[Any]) +assert_type(np.copy(AR_U), npt.NDArray[np.str_]) +assert_type(np.copy(CHAR_AR_U), np.ndarray[Any, Any]) # pyright correctly infers `NDArray[str_]` +assert_type(np.copy(CHAR_AR_U, "K", subok=True), np.char.chararray[tuple[Any, ...], np.dtype[np.str_]]) +assert_type(np.copy(CHAR_AR_U, subok=True), np.char.chararray[tuple[Any, ...], np.dtype[np.str_]]) + +assert_type(np.gradient(AR_f8, axis=None), Any) +assert_type(np.gradient(AR_LIKE_f8, edge_order=2), Any) + +assert_type(np.diff("bob", n=0), str) +assert_type(np.diff(AR_f8, axis=0), npt.NDArray[Any]) +assert_type(np.diff(AR_LIKE_f8, prepend=1.5), npt.NDArray[Any]) + +assert_type(np.interp(1, [1], AR_f8), np.float64) +assert_type(np.interp(1, [1], [1]), np.float64) +assert_type(np.interp(1, [1], AR_c16), np.complex128) +assert_type(np.interp(1, [1], [1j]), np.complex128) # pyright correctly infers `complex128 | float64` +assert_type(np.interp([1], [1], AR_f8), npt.NDArray[np.float64]) +assert_type(np.interp([1], [1], [1]), npt.NDArray[np.float64]) +assert_type(np.interp([1], [1], AR_c16), npt.NDArray[np.complex128]) +assert_type(np.interp([1], [1], [1j]), npt.NDArray[np.complex128]) # pyright correctly infers `NDArray[complex128 | float64]` + +assert_type(np.angle(f8), np.floating) +assert_type(np.angle(AR_f8), 
npt.NDArray[np.floating]) +assert_type(np.angle(AR_c16, deg=True), npt.NDArray[np.floating]) +assert_type(np.angle(AR_O), npt.NDArray[np.object_]) + +assert_type(np.unwrap(AR_f8), npt.NDArray[np.floating]) +assert_type(np.unwrap(AR_O), npt.NDArray[np.object_]) + +assert_type(np.sort_complex(AR_f8), npt.NDArray[np.complexfloating]) + +assert_type(np.trim_zeros(AR_f8), npt.NDArray[np.float64]) +assert_type(np.trim_zeros(AR_LIKE_f8), list[float]) + +assert_type(np.extract(AR_i8, AR_f8), npt.NDArray[np.float64]) +assert_type(np.extract(AR_i8, AR_LIKE_f8), npt.NDArray[Any]) + +assert_type(np.place(AR_f8, mask=AR_i8, vals=5.0), None) + +assert_type(np.cov(AR_f8, bias=True), npt.NDArray[np.floating]) +assert_type(np.cov(AR_f8, AR_c16, ddof=1), npt.NDArray[np.complexfloating]) +assert_type(np.cov(AR_f8, aweights=AR_f8, dtype=np.float32), npt.NDArray[np.float32]) +assert_type(np.cov(AR_f8, fweights=AR_f8, dtype=float), npt.NDArray[Any]) + +assert_type(np.corrcoef(AR_f8, rowvar=True), npt.NDArray[np.floating]) +assert_type(np.corrcoef(AR_f8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.corrcoef(AR_f8, dtype=np.float32), npt.NDArray[np.float32]) +assert_type(np.corrcoef(AR_f8, dtype=float), npt.NDArray[Any]) + +assert_type(np.blackman(5), npt.NDArray[np.floating]) +assert_type(np.bartlett(6), npt.NDArray[np.floating]) +assert_type(np.hanning(4.5), npt.NDArray[np.floating]) +assert_type(np.hamming(0), npt.NDArray[np.floating]) +assert_type(np.i0(AR_i8), npt.NDArray[np.floating]) +assert_type(np.kaiser(4, 5.9), npt.NDArray[np.floating]) + +assert_type(np.sinc(1.0), np.floating) +assert_type(np.sinc(1j), np.complexfloating) +assert_type(np.sinc(AR_f8), npt.NDArray[np.floating]) +assert_type(np.sinc(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.median(AR_f8, keepdims=False), np.floating) +assert_type(np.median(AR_c16, overwrite_input=True), np.complexfloating) +assert_type(np.median(AR_m), np.timedelta64) +assert_type(np.median(AR_O), Any) +assert_type(np.median(AR_f8, keepdims=True), Any) +assert_type(np.median(AR_c16, axis=0), Any) +assert_type(np.median(AR_LIKE_f8, out=AR_c16), npt.NDArray[np.complex128]) + +assert_type(np.percentile(AR_f8, 50), np.floating) +assert_type(np.percentile(AR_c16, 50), np.complexfloating) +assert_type(np.percentile(AR_m, 50), np.timedelta64) +assert_type(np.percentile(AR_M, 50, overwrite_input=True), np.datetime64) +assert_type(np.percentile(AR_O, 50), Any) +assert_type(np.percentile(AR_f8, [50]), npt.NDArray[np.floating]) +assert_type(np.percentile(AR_c16, [50]), npt.NDArray[np.complexfloating]) +assert_type(np.percentile(AR_m, [50]), npt.NDArray[np.timedelta64]) +assert_type(np.percentile(AR_M, [50], method="nearest"), npt.NDArray[np.datetime64]) +assert_type(np.percentile(AR_O, [50]), npt.NDArray[np.object_]) +assert_type(np.percentile(AR_f8, [50], keepdims=True), Any) +assert_type(np.percentile(AR_f8, [50], axis=[1]), Any) +assert_type(np.percentile(AR_f8, [50], out=AR_c16), npt.NDArray[np.complex128]) + +assert_type(np.quantile(AR_f8, 0.5), np.floating) +assert_type(np.quantile(AR_c16, 0.5), np.complexfloating) +assert_type(np.quantile(AR_m, 0.5), np.timedelta64) +assert_type(np.quantile(AR_M, 0.5, overwrite_input=True), np.datetime64) +assert_type(np.quantile(AR_O, 0.5), Any) +assert_type(np.quantile(AR_f8, [0.5]), npt.NDArray[np.floating]) +assert_type(np.quantile(AR_c16, [0.5]), npt.NDArray[np.complexfloating]) +assert_type(np.quantile(AR_m, [0.5]), npt.NDArray[np.timedelta64]) +assert_type(np.quantile(AR_M, [0.5], method="nearest"), 
npt.NDArray[np.datetime64]) +assert_type(np.quantile(AR_O, [0.5]), npt.NDArray[np.object_]) +assert_type(np.quantile(AR_f8, [0.5], keepdims=True), Any) +assert_type(np.quantile(AR_f8, [0.5], axis=[1]), Any) +assert_type(np.quantile(AR_f8, [0.5], out=AR_c16), npt.NDArray[np.complex128]) + +assert_type(np.trapezoid(AR_LIKE_f8), np.float64) +assert_type(np.trapezoid(AR_LIKE_f8, AR_LIKE_f8), np.float64) +assert_type(np.trapezoid(AR_LIKE_c16), np.complex128) +assert_type(np.trapezoid(AR_LIKE_c16, AR_LIKE_f8), np.complex128) +assert_type(np.trapezoid(AR_LIKE_f8, AR_LIKE_c16), np.complex128) +assert_type(np.trapezoid(AR_LIKE_O), float) +assert_type(np.trapezoid(AR_LIKE_O, AR_LIKE_f8), float) +assert_type(np.trapezoid(AR_f8), np.float64 | npt.NDArray[np.float64]) +assert_type(np.trapezoid(AR_f8, AR_f8), np.float64 | npt.NDArray[np.float64]) +assert_type(np.trapezoid(AR_c16), np.complex128 | npt.NDArray[np.complex128]) +assert_type(np.trapezoid(AR_c16, AR_c16), np.complex128 | npt.NDArray[np.complex128]) +assert_type(np.trapezoid(AR_m), np.timedelta64 | npt.NDArray[np.timedelta64]) +assert_type(np.trapezoid(AR_O), float | npt.NDArray[np.object_]) +assert_type(np.trapezoid(AR_O, AR_LIKE_f8), float | npt.NDArray[np.object_]) + +assert_type(np.meshgrid(), tuple[()]) +assert_type(np.meshgrid(AR_c16, indexing="ij"), tuple[npt.NDArray[np.complex128]]) +assert_type(np.meshgrid(AR_i8, AR_f8, copy=False), tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]) +assert_type(np.meshgrid(AR_LIKE_f8, AR_f8), tuple[npt.NDArray[Any], npt.NDArray[np.float64]]) +assert_type(np.meshgrid(AR_LIKE_f8, AR_i8, AR_c16), tuple[npt.NDArray[Any], npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.meshgrid(AR_f8, AR_f8, AR_f8, AR_f8), tuple[npt.NDArray[Any], npt.NDArray[Any], npt.NDArray[Any], npt.NDArray[Any]]) +assert_type(np.meshgrid(*AR_LIKE_f8), tuple[npt.NDArray[Any], ...]) + +assert_type(np.delete(AR_f8, np.s_[:5]), npt.NDArray[np.float64]) +assert_type(np.delete(AR_LIKE_f8, [0, 4, 9], axis=0), npt.NDArray[Any]) + +assert_type(np.insert(AR_f8, np.s_[:5], 5), npt.NDArray[np.float64]) +assert_type(np.insert(AR_LIKE_f8, [0, 4, 9], [0.5, 9.2, 7], axis=0), npt.NDArray[Any]) + +assert_type(np.append(AR_f8, 5), npt.NDArray[Any]) +assert_type(np.append(AR_LIKE_f8, 1j, axis=0), npt.NDArray[Any]) + +assert_type(np.digitize(4.5, [1]), np.intp) +assert_type(np.digitize(AR_f8, [1, 2, 3]), npt.NDArray[np.intp]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_polynomial.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_polynomial.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8b0a9f3d22e739d24bba24eb677f235745e6ae86 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_polynomial.pyi @@ -0,0 +1,144 @@ +from collections.abc import Iterator +from typing import Any, NoReturn, assert_type + +import numpy as np +import numpy.typing as npt + +AR_b: npt.NDArray[np.bool] +AR_u4: npt.NDArray[np.uint32] +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_O: npt.NDArray[np.object_] + +poly_obj: np.poly1d + +assert_type(poly_obj.variable, str) +assert_type(poly_obj.order, int) +assert_type(poly_obj.o, int) +assert_type(poly_obj.roots, npt.NDArray[Any]) +assert_type(poly_obj.r, npt.NDArray[Any]) +assert_type(poly_obj.coeffs, npt.NDArray[Any]) +assert_type(poly_obj.c, npt.NDArray[Any]) +assert_type(poly_obj.coef, npt.NDArray[Any]) +assert_type(poly_obj.coefficients, 
npt.NDArray[Any]) +assert_type(poly_obj.__hash__, None) + +assert_type(poly_obj(1), Any) +assert_type(poly_obj([1]), npt.NDArray[Any]) +assert_type(poly_obj(poly_obj), np.poly1d) + +assert_type(len(poly_obj), int) +assert_type(-poly_obj, np.poly1d) +assert_type(+poly_obj, np.poly1d) + +assert_type(poly_obj * 5, np.poly1d) +assert_type(5 * poly_obj, np.poly1d) +assert_type(poly_obj + 5, np.poly1d) +assert_type(5 + poly_obj, np.poly1d) +assert_type(poly_obj - 5, np.poly1d) +assert_type(5 - poly_obj, np.poly1d) +assert_type(poly_obj**1, np.poly1d) +assert_type(poly_obj**1.0, np.poly1d) +assert_type(poly_obj / 5, np.poly1d) +assert_type(5 / poly_obj, np.poly1d) + +assert_type(poly_obj[0], Any) +poly_obj[0] = 5 +assert_type(iter(poly_obj), Iterator[Any]) +assert_type(poly_obj.deriv(), np.poly1d) +assert_type(poly_obj.integ(), np.poly1d) + +assert_type(np.poly(poly_obj), npt.NDArray[np.floating]) +assert_type(np.poly(AR_f8), npt.NDArray[np.floating]) +assert_type(np.poly(AR_c16), npt.NDArray[np.floating]) + +assert_type(np.polyint(poly_obj), np.poly1d) +assert_type(np.polyint(AR_f8), npt.NDArray[np.floating]) +assert_type(np.polyint(AR_f8, k=AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polyint(AR_O, m=2), npt.NDArray[np.object_]) + +assert_type(np.polyder(poly_obj), np.poly1d) +assert_type(np.polyder(AR_f8), npt.NDArray[np.floating]) +assert_type(np.polyder(AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polyder(AR_O, m=2), npt.NDArray[np.object_]) + +assert_type(np.polyfit(AR_f8, AR_f8, 2), npt.NDArray[np.float64]) +assert_type( + np.polyfit(AR_f8, AR_i8, 1, full=True), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.int32], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type( + np.polyfit(AR_u4, AR_f8, 1.0, cov="unscaled"), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type(np.polyfit(AR_c16, AR_f8, 2), npt.NDArray[np.complex128]) +assert_type( + np.polyfit(AR_f8, AR_c16, 1, full=True), + tuple[ + npt.NDArray[np.complex128], + npt.NDArray[np.float64], + npt.NDArray[np.int32], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type( + np.polyfit(AR_u4, AR_c16, 1.0, cov=True), + tuple[ + npt.NDArray[np.complex128], + npt.NDArray[np.complex128], + ], +) + +assert_type(np.polyval(AR_b, AR_b), npt.NDArray[np.int64]) +assert_type(np.polyval(AR_u4, AR_b), npt.NDArray[np.unsignedinteger]) +assert_type(np.polyval(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.polyval(AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(np.polyval(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polyval(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.polyadd(poly_obj, AR_i8), np.poly1d) +assert_type(np.polyadd(AR_f8, poly_obj), np.poly1d) +assert_type(np.polyadd(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.polyadd(AR_u4, AR_b), npt.NDArray[np.unsignedinteger]) +assert_type(np.polyadd(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.polyadd(AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(np.polyadd(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polyadd(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.polysub(poly_obj, AR_i8), np.poly1d) +assert_type(np.polysub(AR_f8, poly_obj), np.poly1d) +assert_type(np.polysub(AR_b, AR_b), NoReturn) +assert_type(np.polysub(AR_u4, AR_b), npt.NDArray[np.unsignedinteger]) +assert_type(np.polysub(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.polysub(AR_f8, 
AR_i8), npt.NDArray[np.floating]) +assert_type(np.polysub(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polysub(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.polymul(poly_obj, AR_i8), np.poly1d) +assert_type(np.polymul(AR_f8, poly_obj), np.poly1d) +assert_type(np.polymul(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.polymul(AR_u4, AR_b), npt.NDArray[np.unsignedinteger]) +assert_type(np.polymul(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.polymul(AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(np.polymul(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.polymul(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.polydiv(poly_obj, AR_i8), tuple[np.poly1d, np.poly1d]) +assert_type(np.polydiv(AR_f8, poly_obj), tuple[np.poly1d, np.poly1d]) +assert_type(np.polydiv(AR_b, AR_b), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) +assert_type(np.polydiv(AR_u4, AR_b), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) +assert_type(np.polydiv(AR_i8, AR_i8), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) +assert_type(np.polydiv(AR_f8, AR_i8), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]) +assert_type(np.polydiv(AR_i8, AR_c16), tuple[npt.NDArray[np.complexfloating], npt.NDArray[np.complexfloating]]) +assert_type(np.polydiv(AR_O, AR_O), tuple[npt.NDArray[Any], npt.NDArray[Any]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_utils.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_utils.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c9470e00a3590c00594d87856a232575bc9ff63e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_utils.pyi @@ -0,0 +1,17 @@ +from io import StringIO +from typing import assert_type + +import numpy as np +import numpy.lib.array_utils as array_utils +import numpy.typing as npt + +AR: npt.NDArray[np.float64] +AR_DICT: dict[str, npt.NDArray[np.float64]] +FILE: StringIO + +def func(a: int) -> bool: ... 
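To see the promotion rules behind the polymul/polydiv assertions above at runtime, a short sketch with made-up coefficients:

import numpy as np

# (x**2 - 1) / (x - 1) == x + 1 with zero remainder; polydiv always promotes
# to floating output, hence the NDArray[floating] pairs in the stubs above.
q, r = np.polydiv(np.array([1, 0, -1]), np.array([1, -1]))
print(q, r)                          # [1. 1.] [0.]
print(np.polymul([1, 1], [1, -1]))   # [ 1  0 -1]: integer inputs stay integer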
+ +assert_type(array_utils.byte_bounds(AR), tuple[int, int]) +assert_type(array_utils.byte_bounds(np.float64()), tuple[int, int]) + +assert_type(np.info(1, output=FILE), None) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_version.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_version.pyi new file mode 100644 index 0000000000000000000000000000000000000000..03735375ae3e8945c5f9d8f1cab7c90be60f8cf6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/lib_version.pyi @@ -0,0 +1,20 @@ +from typing import assert_type + +from numpy.lib import NumpyVersion + +version = NumpyVersion("1.8.0") + +assert_type(version.vstring, str) +assert_type(version.version, str) +assert_type(version.major, int) +assert_type(version.minor, int) +assert_type(version.bugfix, int) +assert_type(version.pre_release, str) +assert_type(version.is_devversion, bool) + +assert_type(version == version, bool) +assert_type(version != version, bool) +assert_type(version < "1.8.0", bool) +assert_type(version <= version, bool) +assert_type(version > version, bool) +assert_type(version >= "1.8.0", bool) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/linalg.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/linalg.pyi new file mode 100644 index 0000000000000000000000000000000000000000..fbaac3cfab0be0ef5800221c429e43aa182ec8d9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/linalg.pyi @@ -0,0 +1,132 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy.linalg._linalg import ( + EighResult, + EigResult, + QRResult, + SlogdetResult, + SVDResult, +) + +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_O: npt.NDArray[np.object_] +AR_m: npt.NDArray[np.timedelta64] +AR_S: npt.NDArray[np.str_] +AR_b: npt.NDArray[np.bool] + +assert_type(np.linalg.tensorsolve(AR_i8, AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.tensorsolve(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.tensorsolve(AR_c16, AR_f8), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.solve(AR_i8, AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.solve(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.solve(AR_c16, AR_f8), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.tensorinv(AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.tensorinv(AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.tensorinv(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.inv(AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.inv(AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.inv(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.matrix_power(AR_i8, -1), npt.NDArray[Any]) +assert_type(np.linalg.matrix_power(AR_f8, 0), npt.NDArray[Any]) +assert_type(np.linalg.matrix_power(AR_c16, 1), npt.NDArray[Any]) +assert_type(np.linalg.matrix_power(AR_O, 2), npt.NDArray[Any]) + +assert_type(np.linalg.cholesky(AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.cholesky(AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.cholesky(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.outer(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.linalg.outer(AR_f8, AR_f8), npt.NDArray[np.float64]) +assert_type(np.linalg.outer(AR_c16, AR_c16), npt.NDArray[np.complex128]) 
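The NumpyVersion comparisons asserted just above are typically used for feature gating; a brief usage sketch (the version string compared against is an arbitrary example):

import numpy as np
from numpy.lib import NumpyVersion

v = NumpyVersion(np.__version__)
print(v.major, v.minor, v.bugfix)      # parsed int components
print(v >= "1.25.0", v.is_devversion)  # plain bools, as the stubs assert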
+assert_type(np.linalg.outer(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.linalg.outer(AR_O, AR_O), npt.NDArray[np.object_]) +assert_type(np.linalg.outer(AR_i8, AR_m), npt.NDArray[np.timedelta64]) + +assert_type(np.linalg.qr(AR_i8), QRResult) +assert_type(np.linalg.qr(AR_f8), QRResult) +assert_type(np.linalg.qr(AR_c16), QRResult) + +assert_type(np.linalg.eigvals(AR_i8), npt.NDArray[np.float64] | npt.NDArray[np.complex128]) +assert_type(np.linalg.eigvals(AR_f8), npt.NDArray[np.floating] | npt.NDArray[np.complexfloating]) +assert_type(np.linalg.eigvals(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.eigvalsh(AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.eigvalsh(AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.eigvalsh(AR_c16), npt.NDArray[np.floating]) + +assert_type(np.linalg.eig(AR_i8), EigResult) +assert_type(np.linalg.eig(AR_f8), EigResult) +assert_type(np.linalg.eig(AR_c16), EigResult) + +assert_type(np.linalg.eigh(AR_i8), EighResult) +assert_type(np.linalg.eigh(AR_f8), EighResult) +assert_type(np.linalg.eigh(AR_c16), EighResult) + +assert_type(np.linalg.svd(AR_i8), SVDResult) +assert_type(np.linalg.svd(AR_f8), SVDResult) +assert_type(np.linalg.svd(AR_c16), SVDResult) +assert_type(np.linalg.svd(AR_i8, compute_uv=False), npt.NDArray[np.float64]) +assert_type(np.linalg.svd(AR_f8, compute_uv=False), npt.NDArray[np.floating]) +assert_type(np.linalg.svd(AR_c16, compute_uv=False), npt.NDArray[np.floating]) + +assert_type(np.linalg.cond(AR_i8), Any) +assert_type(np.linalg.cond(AR_f8), Any) +assert_type(np.linalg.cond(AR_c16), Any) + +assert_type(np.linalg.matrix_rank(AR_i8), Any) +assert_type(np.linalg.matrix_rank(AR_f8), Any) +assert_type(np.linalg.matrix_rank(AR_c16), Any) + +assert_type(np.linalg.pinv(AR_i8), npt.NDArray[np.float64]) +assert_type(np.linalg.pinv(AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.pinv(AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.slogdet(AR_i8), SlogdetResult) +assert_type(np.linalg.slogdet(AR_f8), SlogdetResult) +assert_type(np.linalg.slogdet(AR_c16), SlogdetResult) + +assert_type(np.linalg.det(AR_i8), Any) +assert_type(np.linalg.det(AR_f8), Any) +assert_type(np.linalg.det(AR_c16), Any) + +assert_type(np.linalg.lstsq(AR_i8, AR_i8), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], np.int32, npt.NDArray[np.float64]]) +assert_type(np.linalg.lstsq(AR_i8, AR_f8), tuple[npt.NDArray[np.floating], npt.NDArray[np.floating], np.int32, npt.NDArray[np.floating]]) +assert_type(np.linalg.lstsq(AR_f8, AR_c16), tuple[npt.NDArray[np.complexfloating], npt.NDArray[np.floating], np.int32, npt.NDArray[np.floating]]) + +assert_type(np.linalg.norm(AR_i8), np.floating) +assert_type(np.linalg.norm(AR_f8), np.floating) +assert_type(np.linalg.norm(AR_c16), np.floating) +assert_type(np.linalg.norm(AR_S), np.floating) +assert_type(np.linalg.norm(AR_f8, axis=0), Any) + +assert_type(np.linalg.matrix_norm(AR_i8), np.floating) +assert_type(np.linalg.matrix_norm(AR_f8), np.floating) +assert_type(np.linalg.matrix_norm(AR_c16), np.floating) +assert_type(np.linalg.matrix_norm(AR_S), np.floating) + +assert_type(np.linalg.vector_norm(AR_i8), np.floating) +assert_type(np.linalg.vector_norm(AR_f8), np.floating) +assert_type(np.linalg.vector_norm(AR_c16), np.floating) +assert_type(np.linalg.vector_norm(AR_S), np.floating) + +assert_type(np.linalg.multi_dot([AR_i8, AR_i8]), Any) +assert_type(np.linalg.multi_dot([AR_i8, AR_f8]), Any) +assert_type(np.linalg.multi_dot([AR_f8, AR_c16]), Any) 
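Why np.linalg.eigvals is asserted as a real-or-complex union while eigvalsh is always real: a runtime sketch with assumed matrices. eigvals returns a real array when every eigenvalue is real and a complex one otherwise, whereas Hermitian input guarantees a real spectrum.

import numpy as np

rot = np.array([[0.0, -1.0], [1.0, 0.0]])                # 90-degree rotation
print(np.linalg.eigvals(rot).dtype)                      # complex128: eigenvalues are +/- 1j
print(np.linalg.eigvals(np.eye(2)).dtype)                # float64: all-real spectrum
print(np.linalg.eigvalsh(np.eye(2)).dtype)               # float64, guaranteed for Hermitian input
print(np.linalg.svd(np.eye(2), compute_uv=False).dtype)  # float64: singular values only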
+assert_type(np.linalg.multi_dot([AR_O, AR_O]), Any) +assert_type(np.linalg.multi_dot([AR_m, AR_m]), Any) + +assert_type(np.linalg.cross(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.linalg.cross(AR_f8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.linalg.cross(AR_c16, AR_c16), npt.NDArray[np.complexfloating]) + +assert_type(np.linalg.matmul(AR_i8, AR_i8), npt.NDArray[np.int64]) +assert_type(np.linalg.matmul(AR_f8, AR_f8), npt.NDArray[np.float64]) +assert_type(np.linalg.matmul(AR_c16, AR_c16), npt.NDArray[np.complex128]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ma.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ma.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2c65534ecb04fcb1b6e43dcceb3cc4767922bf4e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ma.pyi @@ -0,0 +1,369 @@ +from typing import Any, Literal, TypeAlias, TypeVar, assert_type + +import numpy as np +from numpy import dtype, generic +from numpy._typing import NDArray, _AnyShape + +_ScalarT = TypeVar("_ScalarT", bound=generic) +MaskedArray: TypeAlias = np.ma.MaskedArray[_AnyShape, dtype[_ScalarT]] +_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]] + +class MaskedArraySubclass(MaskedArray[np.complex128]): ... + +AR_b: NDArray[np.bool] +AR_f4: NDArray[np.float32] +AR_dt64: NDArray[np.datetime64] +AR_td64: NDArray[np.timedelta64] +AR_o: NDArray[np.timedelta64] + +MAR_c16: MaskedArray[np.complex128] +MAR_b: MaskedArray[np.bool] +MAR_f4: MaskedArray[np.float32] +MAR_f8: MaskedArray[np.float64] +MAR_i8: MaskedArray[np.int64] +MAR_dt64: MaskedArray[np.datetime64] +MAR_td64: MaskedArray[np.timedelta64] +MAR_o: MaskedArray[np.object_] +MAR_s: MaskedArray[np.str_] +MAR_byte: MaskedArray[np.bytes_] +MAR_V: MaskedArray[np.void] + +MAR_subclass: MaskedArraySubclass + +MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype] +MAR_2d_f4: np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]] + +b: np.bool +f4: np.float32 +f: float + +assert_type(MAR_1d.shape, tuple[int]) + +assert_type(MAR_f4.dtype, np.dtype[np.float32]) + +assert_type(int(MAR_i8), int) +assert_type(float(MAR_f4), float) + +assert_type(np.ma.min(MAR_b), np.bool) +assert_type(np.ma.min(MAR_f4), np.float32) +assert_type(np.ma.min(MAR_b, axis=0), Any) +assert_type(np.ma.min(MAR_f4, axis=0), Any) +assert_type(np.ma.min(MAR_b, keepdims=True), Any) +assert_type(np.ma.min(MAR_f4, keepdims=True), Any) +assert_type(np.ma.min(MAR_f4, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.min(MAR_f4, 0, MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.min(MAR_f4, None, MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.min(), np.bool) +assert_type(MAR_f4.min(), np.float32) +assert_type(MAR_b.min(axis=0), Any) +assert_type(MAR_f4.min(axis=0), Any) +assert_type(MAR_b.min(keepdims=True), Any) +assert_type(MAR_f4.min(keepdims=True), Any) +assert_type(MAR_f4.min(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.min(0, MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.min(None, MAR_subclass), MaskedArraySubclass) + +assert_type(np.ma.max(MAR_b), np.bool) +assert_type(np.ma.max(MAR_f4), np.float32) +assert_type(np.ma.max(MAR_b, axis=0), Any) +assert_type(np.ma.max(MAR_f4, axis=0), Any) +assert_type(np.ma.max(MAR_b, keepdims=True), Any) +assert_type(np.ma.max(MAR_f4, keepdims=True), Any) +assert_type(np.ma.max(MAR_f4, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.max(MAR_f4, 0, MAR_subclass), 
MaskedArraySubclass) +assert_type(np.ma.max(MAR_f4, None, MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.max(), np.bool) +assert_type(MAR_f4.max(), np.float32) +assert_type(MAR_b.max(axis=0), Any) +assert_type(MAR_f4.max(axis=0), Any) +assert_type(MAR_b.max(keepdims=True), Any) +assert_type(MAR_f4.max(keepdims=True), Any) +assert_type(MAR_f4.max(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.max(0, MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.max(None, MAR_subclass), MaskedArraySubclass) + +assert_type(np.ma.ptp(MAR_b), np.bool) +assert_type(np.ma.ptp(MAR_f4), np.float32) +assert_type(np.ma.ptp(MAR_b, axis=0), Any) +assert_type(np.ma.ptp(MAR_f4, axis=0), Any) +assert_type(np.ma.ptp(MAR_b, keepdims=True), Any) +assert_type(np.ma.ptp(MAR_f4, keepdims=True), Any) +assert_type(np.ma.ptp(MAR_f4, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.ptp(MAR_f4, 0, MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.ptp(MAR_f4, None, MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.ptp(), np.bool) +assert_type(MAR_f4.ptp(), np.float32) +assert_type(MAR_b.ptp(axis=0), Any) +assert_type(MAR_f4.ptp(axis=0), Any) +assert_type(MAR_b.ptp(keepdims=True), Any) +assert_type(MAR_f4.ptp(keepdims=True), Any) +assert_type(MAR_f4.ptp(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.ptp(0, MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.ptp(None, MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.argmin(), np.intp) +assert_type(MAR_f4.argmin(), np.intp) +assert_type(MAR_f4.argmax(fill_value=6.28318, keepdims=False), np.intp) +assert_type(MAR_b.argmin(axis=0), Any) +assert_type(MAR_f4.argmin(axis=0), Any) +assert_type(MAR_b.argmin(keepdims=True), Any) +assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(np.ma.argmin(MAR_b), np.intp) +assert_type(np.ma.argmin(MAR_f4), np.intp) +assert_type(np.ma.argmin(MAR_f4, fill_value=6.28318, keepdims=False), np.intp) +assert_type(np.ma.argmin(MAR_b, axis=0), Any) +assert_type(np.ma.argmin(MAR_f4, axis=0), Any) +assert_type(np.ma.argmin(MAR_b, keepdims=True), Any) +assert_type(np.ma.argmin(MAR_f4, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.argmin(MAR_f4, None, None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.argmax(), np.intp) +assert_type(MAR_f4.argmax(), np.intp) +assert_type(MAR_f4.argmax(fill_value=6.28318, keepdims=False), np.intp) +assert_type(MAR_b.argmax(axis=0), Any) +assert_type(MAR_f4.argmax(axis=0), Any) +assert_type(MAR_b.argmax(keepdims=True), Any) +assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(np.ma.argmax(MAR_b), np.intp) +assert_type(np.ma.argmax(MAR_f4), np.intp) +assert_type(np.ma.argmax(MAR_f4, fill_value=6.28318, keepdims=False), np.intp) +assert_type(np.ma.argmax(MAR_b, axis=0), Any) +assert_type(np.ma.argmax(MAR_f4, axis=0), Any) +assert_type(np.ma.argmax(MAR_b, keepdims=True), Any) +assert_type(np.ma.argmax(MAR_f4, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.argmax(MAR_f4, None, None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.all(), np.bool) +assert_type(MAR_f4.all(), np.bool) +assert_type(MAR_f4.all(keepdims=False), np.bool) +assert_type(MAR_b.all(axis=0), np.bool | MaskedArray[np.bool]) +assert_type(MAR_b.all(axis=0, keepdims=True), MaskedArray[np.bool]) +assert_type(MAR_b.all(0, None, True), 
MaskedArray[np.bool]) +assert_type(MAR_f4.all(axis=0), np.bool | MaskedArray[np.bool]) +assert_type(MAR_b.all(keepdims=True), MaskedArray[np.bool]) +assert_type(MAR_f4.all(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.all(None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_b.any(), np.bool) +assert_type(MAR_f4.any(), np.bool) +assert_type(MAR_f4.any(keepdims=False), np.bool) +assert_type(MAR_b.any(axis=0), np.bool | MaskedArray[np.bool]) +assert_type(MAR_b.any(axis=0, keepdims=True), MaskedArray[np.bool]) +assert_type(MAR_b.any(0, None, True), MaskedArray[np.bool]) +assert_type(MAR_f4.any(axis=0), np.bool | MaskedArray[np.bool]) +assert_type(MAR_b.any(keepdims=True), MaskedArray[np.bool]) +assert_type(MAR_f4.any(out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f4.any(None, out=MAR_subclass), MaskedArraySubclass) + +assert_type(MAR_f4.sort(), None) +assert_type(MAR_f4.sort(axis=0, kind='quicksort', order='K', endwith=False, fill_value=42., stable=False), None) + +assert_type(np.ma.sort(MAR_f4), MaskedArray[np.float32]) +assert_type(np.ma.sort(MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.sort([[0, 1], [2, 3]]), NDArray[Any]) +assert_type(np.ma.sort(AR_f4), NDArray[np.float32]) + +assert_type(MAR_f8.take(0), np.float64) +assert_type(MAR_1d.take(0), Any) +assert_type(MAR_f8.take([0]), MaskedArray[np.float64]) +assert_type(MAR_f8.take(0, out=MAR_subclass), MaskedArraySubclass) +assert_type(MAR_f8.take([0], out=MAR_subclass), MaskedArraySubclass) + +assert_type(np.ma.take(f, 0), Any) +assert_type(np.ma.take(f4, 0), np.float32) +assert_type(np.ma.take(MAR_f8, 0), np.float64) +assert_type(np.ma.take(AR_f4, 0), np.float32) +assert_type(np.ma.take(MAR_1d, 0), Any) +assert_type(np.ma.take(MAR_f8, [0]), MaskedArray[np.float64]) +assert_type(np.ma.take(AR_f4, [0]), MaskedArray[np.float32]) +assert_type(np.ma.take(MAR_f8, 0, out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.take(MAR_f8, [0], out=MAR_subclass), MaskedArraySubclass) +assert_type(np.ma.take([1], [0]), MaskedArray[Any]) +assert_type(np.ma.take(np.eye(2), 1, axis=0), MaskedArray[np.float64]) + +assert_type(MAR_f4.partition(1), None) +assert_type(MAR_V.partition(1, axis=0, kind='introselect', order='K'), None) + +assert_type(MAR_f4.argpartition(1), MaskedArray[np.intp]) +assert_type(MAR_1d.argpartition(1, axis=0, kind='introselect', order='K'), MaskedArray[np.intp]) + +assert_type(np.ma.ndim(f4), int) +assert_type(np.ma.ndim(MAR_b), int) +assert_type(np.ma.ndim(AR_f4), int) + +assert_type(np.ma.size(b), int) +assert_type(np.ma.size(MAR_f4, axis=0), int) +assert_type(np.ma.size(AR_f4), int) + +assert_type(np.ma.is_masked(MAR_f4), bool) + +assert_type(MAR_f4.ids(), tuple[int, int]) + +assert_type(MAR_f4.iscontiguous(), bool) + +assert_type(MAR_f4 >= 3, MaskedArray[np.bool]) +assert_type(MAR_i8 >= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_b >= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_td64 >= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_dt64 >= AR_dt64, MaskedArray[np.bool]) +assert_type(MAR_o >= AR_o, MaskedArray[np.bool]) +assert_type(MAR_1d >= 0, MaskedArray[np.bool]) +assert_type(MAR_s >= MAR_s, MaskedArray[np.bool]) +assert_type(MAR_byte >= MAR_byte, MaskedArray[np.bool]) + +assert_type(MAR_f4 > 3, MaskedArray[np.bool]) +assert_type(MAR_i8 > AR_td64, MaskedArray[np.bool]) +assert_type(MAR_b > AR_td64, MaskedArray[np.bool]) +assert_type(MAR_td64 > AR_td64, MaskedArray[np.bool]) +assert_type(MAR_dt64 > AR_dt64, MaskedArray[np.bool]) +assert_type(MAR_o > AR_o, 
MaskedArray[np.bool]) +assert_type(MAR_1d > 0, MaskedArray[np.bool]) +assert_type(MAR_s > MAR_s, MaskedArray[np.bool]) +assert_type(MAR_byte > MAR_byte, MaskedArray[np.bool]) + +assert_type(MAR_f4 <= 3, MaskedArray[np.bool]) +assert_type(MAR_i8 <= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_b <= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_td64 <= AR_td64, MaskedArray[np.bool]) +assert_type(MAR_dt64 <= AR_dt64, MaskedArray[np.bool]) +assert_type(MAR_o <= AR_o, MaskedArray[np.bool]) +assert_type(MAR_1d <= 0, MaskedArray[np.bool]) +assert_type(MAR_s <= MAR_s, MaskedArray[np.bool]) +assert_type(MAR_byte <= MAR_byte, MaskedArray[np.bool]) + +assert_type(MAR_f4 < 3, MaskedArray[np.bool]) +assert_type(MAR_i8 < AR_td64, MaskedArray[np.bool]) +assert_type(MAR_b < AR_td64, MaskedArray[np.bool]) +assert_type(MAR_td64 < AR_td64, MaskedArray[np.bool]) +assert_type(MAR_dt64 < AR_dt64, MaskedArray[np.bool]) +assert_type(MAR_o < AR_o, MaskedArray[np.bool]) +assert_type(MAR_1d < 0, MaskedArray[np.bool]) +assert_type(MAR_s < MAR_s, MaskedArray[np.bool]) +assert_type(MAR_byte < MAR_byte, MaskedArray[np.bool]) + +assert_type(MAR_byte.count(), int) +assert_type(MAR_f4.count(axis=None), int) +assert_type(MAR_f4.count(axis=0), NDArray[np.int_]) +assert_type(MAR_b.count(axis=(0,1)), NDArray[np.int_]) +assert_type(MAR_o.count(keepdims=True), NDArray[np.int_]) +assert_type(MAR_o.count(axis=None, keepdims=True), NDArray[np.int_]) +assert_type(MAR_o.count(None, True), NDArray[np.int_]) + +assert_type(np.ma.count(MAR_byte), int) +assert_type(np.ma.count(MAR_byte, axis=None), int) +assert_type(np.ma.count(MAR_f4, axis=0), NDArray[np.int_]) +assert_type(np.ma.count(MAR_b, axis=(0,1)), NDArray[np.int_]) +assert_type(np.ma.count(MAR_o, keepdims=True), NDArray[np.int_]) +assert_type(np.ma.count(MAR_o, axis=None, keepdims=True), NDArray[np.int_]) +assert_type(np.ma.count(MAR_o, None, True), NDArray[np.int_]) + +assert_type(MAR_f4.compressed(), np.ndarray[tuple[int], np.dtype[np.float32]]) + +assert_type(np.ma.compressed(MAR_i8), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(np.ma.compressed([[1,2,3]]), np.ndarray[tuple[int], np.dtype]) + +assert_type(MAR_f4.put([0,4,8], [10,20,30]), None) +assert_type(MAR_f4.put(4, 999), None) +assert_type(MAR_f4.put(4, 999, mode='clip'), None) + +assert_type(np.ma.put(MAR_f4, [0,4,8], [10,20,30]), None) +assert_type(np.ma.put(MAR_f4, 4, 999), None) +assert_type(np.ma.put(MAR_f4, 4, 999, mode='clip'), None) + +assert_type(np.ma.putmask(MAR_f4, [True, False], [0, 1]), None) +assert_type(np.ma.putmask(MAR_f4, np.False_, [0, 1]), None) + +assert_type(MAR_f4.filled(float('nan')), NDArray[np.float32]) +assert_type(MAR_i8.filled(), NDArray[np.int64]) +assert_type(MAR_1d.filled(), np.ndarray[tuple[int], np.dtype]) + +assert_type(np.ma.filled(MAR_f4, float('nan')), NDArray[np.float32]) +assert_type(np.ma.filled([[1,2,3]]), NDArray[Any]) +# PyRight detects this one correctly, but mypy doesn't.
+# https://github.com/numpy/numpy/pull/28742#discussion_r2048968375 +assert_type(np.ma.filled(MAR_1d), np.ndarray[tuple[int], np.dtype]) # type: ignore[assert-type] + +assert_type(MAR_b.repeat(3), np.ma.MaskedArray[tuple[int], np.dtype[np.bool]]) +assert_type(MAR_2d_f4.repeat(MAR_i8), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]]) +assert_type(MAR_2d_f4.repeat(MAR_i8, axis=None), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]]) +assert_type(MAR_2d_f4.repeat(MAR_i8, axis=0), MaskedArray[np.float32]) + +assert_type(np.ma.allequal(AR_f4, MAR_f4), bool) +assert_type(np.ma.allequal(AR_f4, MAR_f4, fill_value=False), bool) + +assert_type(np.ma.allclose(AR_f4, MAR_f4), bool) +assert_type(np.ma.allclose(AR_f4, MAR_f4, masked_equal=False), bool) +assert_type(np.ma.allclose(AR_f4, MAR_f4, rtol=.4, atol=.3), bool) + +assert_type(MAR_2d_f4.ravel(), np.ma.MaskedArray[tuple[int], np.dtype[np.float32]]) +assert_type(MAR_1d.ravel(order='A'), np.ma.MaskedArray[tuple[int], np.dtype[Any]]) + +assert_type(np.ma.getmask(MAR_f4), NDArray[np.bool] | np.bool) +# PyRight detects this one correctly, but mypy doesn't: +# `Revealed type is "Union[numpy.ndarray[Any, Any], numpy.bool[Any]]"` +assert_type(np.ma.getmask(MAR_1d), np.ndarray[tuple[int], np.dtype[np.bool]] | np.bool) # type: ignore[assert-type] +assert_type(np.ma.getmask(MAR_2d_f4), np.ndarray[tuple[int, int], np.dtype[np.bool]] | np.bool) +assert_type(np.ma.getmask([1,2]), NDArray[np.bool] | np.bool) +assert_type(np.ma.getmask(np.int64(1)), np.bool) + +assert_type(np.ma.is_mask(MAR_1d), bool) +assert_type(np.ma.is_mask(AR_b), bool) + +def func(x: object) -> None: + if np.ma.is_mask(x): + assert_type(x, NDArray[np.bool]) + else: + assert_type(x, object) + +assert_type(MAR_2d_f4.mT, np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) + +assert_type(MAR_c16.real, MaskedArray[np.float64]) +assert_type(MAR_c16.imag, MaskedArray[np.float64]) + +assert_type(MAR_2d_f4.baseclass, type[NDArray[Any]]) + +assert_type(MAR_b.swapaxes(0, 1), MaskedArray[np.bool]) +assert_type(MAR_2d_f4.swapaxes(1, 0), MaskedArray[np.float32]) + +assert_type(np.ma.nomask, np.bool[Literal[False]]) +assert_type(np.ma.MaskType, type[np.bool]) + +assert_type(MAR_1d.__setmask__([True, False]), None) +assert_type(MAR_1d.__setmask__(np.False_), None) + +assert_type(MAR_2d_f4.harden_mask(), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(MAR_i8.harden_mask(), MaskedArray[np.int64]) +assert_type(MAR_2d_f4.soften_mask(), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(MAR_i8.soften_mask(), MaskedArray[np.int64]) +assert_type(MAR_f4.unshare_mask(), MaskedArray[np.float32]) +assert_type(MAR_b.shrink_mask(), MaskedArray[np.bool_]) + +assert_type(MAR_i8.hardmask, bool) +assert_type(MAR_i8.sharedmask, bool) + +assert_type(MAR_b.transpose(), MaskedArray[np.bool]) +assert_type(MAR_2d_f4.transpose(), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(MAR_2d_f4.transpose(1, 0), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(MAR_2d_f4.transpose((1, 0)), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(MAR_b.T, MaskedArray[np.bool]) +assert_type(MAR_2d_f4.T, np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) + +assert_type(MAR_2d_f4.nonzero(), tuple[_Array1D[np.intp], *tuple[_Array1D[np.intp], ...]]) +assert_type(MAR_2d_f4.nonzero()[0], _Array1D[np.intp]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/matrix.pyi 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/matrix.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1a7285d428ccbccacfb72acea284f876cf084d6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/matrix.pyi @@ -0,0 +1,73 @@ +from typing import Any, TypeAlias, assert_type + +import numpy as np +import numpy.typing as npt + +_Shape2D: TypeAlias = tuple[int, int] + +mat: np.matrix[_Shape2D, np.dtype[np.int64]] +ar_f8: npt.NDArray[np.float64] +ar_ip: npt.NDArray[np.intp] + +assert_type(mat * 5, np.matrix[_Shape2D, Any]) +assert_type(5 * mat, np.matrix[_Shape2D, Any]) +mat *= 5 + +assert_type(mat**5, np.matrix[_Shape2D, Any]) +mat **= 5 + +assert_type(mat.sum(), Any) +assert_type(mat.mean(), Any) +assert_type(mat.std(), Any) +assert_type(mat.var(), Any) +assert_type(mat.prod(), Any) +assert_type(mat.any(), np.bool) +assert_type(mat.all(), np.bool) +assert_type(mat.max(), np.int64) +assert_type(mat.min(), np.int64) +assert_type(mat.argmax(), np.intp) +assert_type(mat.argmin(), np.intp) +assert_type(mat.ptp(), np.int64) + +assert_type(mat.sum(axis=0), np.matrix[_Shape2D, Any]) +assert_type(mat.mean(axis=0), np.matrix[_Shape2D, Any]) +assert_type(mat.std(axis=0), np.matrix[_Shape2D, Any]) +assert_type(mat.var(axis=0), np.matrix[_Shape2D, Any]) +assert_type(mat.prod(axis=0), np.matrix[_Shape2D, Any]) +assert_type(mat.any(axis=0), np.matrix[_Shape2D, np.dtype[np.bool]]) +assert_type(mat.all(axis=0), np.matrix[_Shape2D, np.dtype[np.bool]]) +assert_type(mat.max(axis=0), np.matrix[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.min(axis=0), np.matrix[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.argmax(axis=0), np.matrix[_Shape2D, np.dtype[np.intp]]) +assert_type(mat.argmin(axis=0), np.matrix[_Shape2D, np.dtype[np.intp]]) +assert_type(mat.ptp(axis=0), np.matrix[_Shape2D, np.dtype[np.int64]]) + +assert_type(mat.sum(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.mean(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.std(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.var(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.prod(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.any(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.all(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.max(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.min(out=ar_f8), npt.NDArray[np.float64]) +assert_type(mat.argmax(out=ar_ip), npt.NDArray[np.intp]) +assert_type(mat.argmin(out=ar_ip), npt.NDArray[np.intp]) +assert_type(mat.ptp(out=ar_f8), npt.NDArray[np.float64]) + +assert_type(mat.T, np.matrix[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.I, np.matrix[_Shape2D, Any]) +assert_type(mat.A, np.ndarray[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.A1, npt.NDArray[np.int64]) +assert_type(mat.H, np.matrix[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.getT(), np.matrix[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.getI(), np.matrix[_Shape2D, Any]) +assert_type(mat.getA(), np.ndarray[_Shape2D, np.dtype[np.int64]]) +assert_type(mat.getA1(), npt.NDArray[np.int64]) +assert_type(mat.getH(), np.matrix[_Shape2D, np.dtype[np.int64]]) + +assert_type(np.bmat(ar_f8), np.matrix[_Shape2D, Any]) +assert_type(np.bmat([[0, 1, 2]]), np.matrix[_Shape2D, Any]) +assert_type(np.bmat("mat"), np.matrix[_Shape2D, Any]) + +assert_type(np.asmatrix(ar_f8, dtype=np.int64), np.matrix[_Shape2D, Any]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/memmap.pyi 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/memmap.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f3e20ed2d5e7df74e3e2fa66d743430fece8cb62 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/memmap.pyi @@ -0,0 +1,19 @@ +from typing import Any, assert_type + +import numpy as np + +memmap_obj: np.memmap[Any, np.dtype[np.str_]] + +assert_type(np.memmap.__array_priority__, float) +assert_type(memmap_obj.__array_priority__, float) +assert_type(memmap_obj.filename, str | None) +assert_type(memmap_obj.offset, int) +assert_type(memmap_obj.mode, str) +assert_type(memmap_obj.flush(), None) + +assert_type(np.memmap("file.txt", offset=5), np.memmap[Any, np.dtype[np.uint8]]) +assert_type(np.memmap(b"file.txt", dtype=np.float64, shape=(10, 3)), np.memmap[Any, np.dtype[np.float64]]) +with open("file.txt", "rb") as f: + assert_type(np.memmap(f, dtype=float, order="K"), np.memmap[Any, np.dtype]) + +assert_type(memmap_obj.__array_finalize__(object()), None) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/mod.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/mod.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ce74557d39d6f630e049e29cbc8d588cfb57e11b --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/mod.pyi @@ -0,0 +1,179 @@ +import datetime as dt +from typing import Literal as L +from typing import assert_type + +import numpy as np +import numpy.typing as npt + +f8: np.float64 +i8: np.int64 +u8: np.uint64 + +f4: np.float32 +i4: np.int32 +u4: np.uint32 + +m: np.timedelta64 +m_nat: np.timedelta64[None] +m_int0: np.timedelta64[L[0]] +m_int: np.timedelta64[int] +m_td: np.timedelta64[dt.timedelta] + +b_: np.bool + +b: bool +i: int +f: float + +AR_b: npt.NDArray[np.bool] +AR_m: npt.NDArray[np.timedelta64] + +# Time structures + +assert_type(m % m, np.timedelta64) +assert_type(m % m_nat, np.timedelta64[None]) +assert_type(m % m_int0, np.timedelta64[None]) +assert_type(m % m_int, np.timedelta64[int | None]) +assert_type(m_nat % m, np.timedelta64[None]) +assert_type(m_int % m_nat, np.timedelta64[None]) +assert_type(m_int % m_int0, np.timedelta64[None]) +assert_type(m_int % m_int, np.timedelta64[int | None]) +assert_type(m_int % m_td, np.timedelta64[int | None]) +assert_type(m_td % m_nat, np.timedelta64[None]) +assert_type(m_td % m_int0, np.timedelta64[None]) +assert_type(m_td % m_int, np.timedelta64[int | None]) +assert_type(m_td % m_td, np.timedelta64[dt.timedelta | None]) + +assert_type(AR_m % m, npt.NDArray[np.timedelta64]) +assert_type(m % AR_m, npt.NDArray[np.timedelta64]) + +assert_type(divmod(m, m), tuple[np.int64, np.timedelta64]) +assert_type(divmod(m, m_nat), tuple[np.int64, np.timedelta64[None]]) +assert_type(divmod(m, m_int0), tuple[np.int64, np.timedelta64[None]]) +# workarounds for https://github.com/microsoft/pyright/issues/9663 +assert_type(m.__divmod__(m_int), tuple[np.int64, np.timedelta64[int | None]]) +assert_type(divmod(m_nat, m), tuple[np.int64, np.timedelta64[None]]) +assert_type(divmod(m_int, m_nat), tuple[np.int64, np.timedelta64[None]]) +assert_type(divmod(m_int, m_int0), tuple[np.int64, np.timedelta64[None]]) +assert_type(divmod(m_int, m_int), tuple[np.int64, np.timedelta64[int | None]]) +assert_type(divmod(m_int, m_td), tuple[np.int64, np.timedelta64[int | None]]) +assert_type(divmod(m_td, m_nat), tuple[np.int64, np.timedelta64[None]]) +assert_type(divmod(m_td, m_int0), tuple[np.int64, 
np.timedelta64[None]]) +assert_type(divmod(m_td, m_int), tuple[np.int64, np.timedelta64[int | None]]) +assert_type(divmod(m_td, m_td), tuple[np.int64, np.timedelta64[dt.timedelta | None]]) + +assert_type(divmod(AR_m, m), tuple[npt.NDArray[np.int64], npt.NDArray[np.timedelta64]]) +assert_type(divmod(m, AR_m), tuple[npt.NDArray[np.int64], npt.NDArray[np.timedelta64]]) + +# Bool + +assert_type(b_ % b, np.int8) +assert_type(b_ % i, np.int_) +assert_type(b_ % f, np.float64) +assert_type(b_ % b_, np.int8) +assert_type(b_ % i8, np.int64) +assert_type(b_ % u8, np.uint64) +assert_type(b_ % f8, np.float64) +assert_type(b_ % AR_b, npt.NDArray[np.int8]) + +assert_type(divmod(b_, b), tuple[np.int8, np.int8]) +assert_type(divmod(b_, b_), tuple[np.int8, np.int8]) +# workarounds for https://github.com/microsoft/pyright/issues/9663 +assert_type(b_.__divmod__(i), tuple[np.int_, np.int_]) +assert_type(b_.__divmod__(f), tuple[np.float64, np.float64]) +assert_type(b_.__divmod__(i8), tuple[np.int64, np.int64]) +assert_type(b_.__divmod__(u8), tuple[np.uint64, np.uint64]) +assert_type(divmod(b_, f8), tuple[np.float64, np.float64]) +assert_type(divmod(b_, AR_b), tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) + +assert_type(b % b_, np.int8) +assert_type(i % b_, np.int_) +assert_type(f % b_, np.float64) +assert_type(b_ % b_, np.int8) +assert_type(i8 % b_, np.int64) +assert_type(u8 % b_, np.uint64) +assert_type(f8 % b_, np.float64) +assert_type(AR_b % b_, npt.NDArray[np.int8]) + +assert_type(divmod(b, b_), tuple[np.int8, np.int8]) +assert_type(divmod(i, b_), tuple[np.int_, np.int_]) +assert_type(divmod(f, b_), tuple[np.float64, np.float64]) +assert_type(divmod(b_, b_), tuple[np.int8, np.int8]) +assert_type(divmod(i8, b_), tuple[np.int64, np.int64]) +assert_type(divmod(u8, b_), tuple[np.uint64, np.uint64]) +assert_type(divmod(f8, b_), tuple[np.float64, np.float64]) +assert_type(divmod(AR_b, b_), tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) + +# int + +assert_type(i8 % b, np.int64) +assert_type(i8 % i8, np.int64) +assert_type(i8 % f, np.float64) +assert_type(i8 % f8, np.float64) +assert_type(i4 % i8, np.signedinteger) +assert_type(i4 % f8, np.float64) +assert_type(i4 % i4, np.int32) +assert_type(i4 % f4, np.floating) +assert_type(i8 % AR_b, npt.NDArray[np.int64]) + +assert_type(divmod(i8, b), tuple[np.int64, np.int64]) +assert_type(divmod(i8, i4), tuple[np.signedinteger, np.signedinteger]) +assert_type(divmod(i8, i8), tuple[np.int64, np.int64]) +# workarounds for https://github.com/microsoft/pyright/issues/9663 +assert_type(i8.__divmod__(f), tuple[np.float64, np.float64]) +assert_type(i8.__divmod__(f8), tuple[np.float64, np.float64]) +assert_type(divmod(i8, f4), tuple[np.floating, np.floating]) +assert_type(divmod(i4, i4), tuple[np.int32, np.int32]) +assert_type(divmod(i4, f4), tuple[np.floating, np.floating]) +assert_type(divmod(i8, AR_b), tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]) + +assert_type(b % i8, np.int64) +assert_type(f % i8, np.float64) +assert_type(i8 % i8, np.int64) +assert_type(f8 % i8, np.float64) +assert_type(i8 % i4, np.signedinteger) +assert_type(f8 % i4, np.float64) +assert_type(i4 % i4, np.int32) +assert_type(f4 % i4, np.floating) +assert_type(AR_b % i8, npt.NDArray[np.int64]) + +assert_type(divmod(b, i8), tuple[np.int64, np.int64]) +assert_type(divmod(f, i8), tuple[np.float64, np.float64]) +assert_type(divmod(i8, i8), tuple[np.int64, np.int64]) +assert_type(divmod(f8, i8), tuple[np.float64, np.float64]) +assert_type(divmod(i4, i8), tuple[np.signedinteger, np.signedinteger]) 
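A runtime check (toy scalars) of the mixed-width promotions the `%` and divmod assertions above encode: int32 combined with int64 promotes to int64, which the stubs can only name as the common base signedinteger.

import numpy as np

i4, i8 = np.int32(7), np.int64(3)
print(type(i4 % i8))                        # <class 'numpy.int64'>
print(divmod(i4, i8))                       # (2, 1) as int64 scalars
print(type(np.float32(7.5) % np.int32(2)))  # a floating scalar; exact width follows promotion rules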
+assert_type(divmod(i4, i4), tuple[np.int32, np.int32]) +# workarounds for https://github.com/microsoft/pyright/issues/9663 +assert_type(f4.__divmod__(i8), tuple[np.floating, np.floating]) +assert_type(f4.__divmod__(i4), tuple[np.floating, np.floating]) +assert_type(AR_b.__divmod__(i8), tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]) + +# float + +assert_type(f8 % b, np.float64) +assert_type(f8 % f, np.float64) +assert_type(i8 % f4, np.floating) +assert_type(f4 % f4, np.float32) +assert_type(f8 % AR_b, npt.NDArray[np.float64]) + +assert_type(divmod(f8, b), tuple[np.float64, np.float64]) +assert_type(divmod(f8, f), tuple[np.float64, np.float64]) +assert_type(divmod(f8, f8), tuple[np.float64, np.float64]) +assert_type(divmod(f8, f4), tuple[np.float64, np.float64]) +assert_type(divmod(f4, f4), tuple[np.float32, np.float32]) +assert_type(divmod(f8, AR_b), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]) + +assert_type(b % f8, np.float64) +assert_type(f % f8, np.float64) # pyright: ignore[reportAssertTypeFailure] # pyright incorrectly infers `builtins.float` +assert_type(f8 % f8, np.float64) +assert_type(f8 % f8, np.float64) +assert_type(f4 % f4, np.float32) +assert_type(AR_b % f8, npt.NDArray[np.float64]) + +assert_type(divmod(b, f8), tuple[np.float64, np.float64]) +assert_type(divmod(f8, f8), tuple[np.float64, np.float64]) +assert_type(divmod(f4, f4), tuple[np.float32, np.float32]) +# workarounds for https://github.com/microsoft/pyright/issues/9663 +assert_type(f8.__rdivmod__(f), tuple[np.float64, np.float64]) +assert_type(f8.__rdivmod__(f4), tuple[np.float64, np.float64]) +assert_type(AR_b.__divmod__(f8), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/modules.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/modules.pyi new file mode 100644 index 0000000000000000000000000000000000000000..628fb500bfda11752336fa52fea5b81c259501ae --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/modules.pyi @@ -0,0 +1,51 @@ +import types +from typing import assert_type + +import numpy as np +from numpy import f2py + +assert_type(np, types.ModuleType) + +assert_type(np.char, types.ModuleType) +assert_type(np.ctypeslib, types.ModuleType) +assert_type(np.emath, types.ModuleType) +assert_type(np.fft, types.ModuleType) +assert_type(np.lib, types.ModuleType) +assert_type(np.linalg, types.ModuleType) +assert_type(np.ma, types.ModuleType) +assert_type(np.matrixlib, types.ModuleType) +assert_type(np.polynomial, types.ModuleType) +assert_type(np.random, types.ModuleType) +assert_type(np.rec, types.ModuleType) +assert_type(np.testing, types.ModuleType) +assert_type(np.version, types.ModuleType) +assert_type(np.exceptions, types.ModuleType) +assert_type(np.dtypes, types.ModuleType) + +assert_type(np.lib.format, types.ModuleType) +assert_type(np.lib.mixins, types.ModuleType) +assert_type(np.lib.scimath, types.ModuleType) +assert_type(np.lib.stride_tricks, types.ModuleType) +assert_type(np.ma.extras, types.ModuleType) +assert_type(np.polynomial.chebyshev, types.ModuleType) +assert_type(np.polynomial.hermite, types.ModuleType) +assert_type(np.polynomial.hermite_e, types.ModuleType) +assert_type(np.polynomial.laguerre, types.ModuleType) +assert_type(np.polynomial.legendre, types.ModuleType) +assert_type(np.polynomial.polynomial, types.ModuleType) + +assert_type(np.__path__, list[str]) +assert_type(np.__version__, str) +assert_type(np.test, np._pytesttester.PytestTester) 
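The ModuleType assertions above have a direct runtime analogue; a tiny sketch (the submodule names are picked arbitrarily from those asserted):

import types
import numpy as np

for name in ("linalg", "fft", "ma", "random"):
    assert isinstance(getattr(np, name), types.ModuleType)
print(type(np.__all__), np.__version__)  # <class 'list'> plus the version string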
+assert_type(np.test.module_name, str) + +assert_type(np.__all__, list[str]) +assert_type(np.char.__all__, list[str]) +assert_type(np.ctypeslib.__all__, list[str]) +assert_type(np.emath.__all__, list[str]) +assert_type(np.lib.__all__, list[str]) +assert_type(np.ma.__all__, list[str]) +assert_type(np.random.__all__, list[str]) +assert_type(np.rec.__all__, list[str]) +assert_type(np.testing.__all__, list[str]) +assert_type(f2py.__all__, list[str]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/multiarray.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/multiarray.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6ba3fcde632f03328a8575b5bd60c504d9bac448 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/multiarray.pyi @@ -0,0 +1,194 @@ +import datetime as dt +from typing import Any, Literal, TypeVar, assert_type + +import numpy as np +import numpy.typing as npt + +_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True) + +class SubClass(npt.NDArray[_ScalarT_co]): ... + +subclass: SubClass[np.float64] + +AR_f8: npt.NDArray[np.float64] +AR_i8: npt.NDArray[np.int64] +AR_u1: npt.NDArray[np.uint8] +AR_m: npt.NDArray[np.timedelta64] +AR_M: npt.NDArray[np.datetime64] + +AR_LIKE_f: list[float] +AR_LIKE_i: list[int] + +m: np.timedelta64 +M: np.datetime64 + +b_f8 = np.broadcast(AR_f8) +b_i8_f8_f8 = np.broadcast(AR_i8, AR_f8, AR_f8) + +nditer_obj: np.nditer + +date_scalar: dt.date +date_seq: list[dt.date] +timedelta_seq: list[dt.timedelta] + +n1: Literal[1] +n2: Literal[2] +n3: Literal[3] + +f8: np.float64 + +def func11(a: int) -> bool: ... +def func21(a: int, b: int) -> int: ... +def func12(a: int) -> tuple[complex, bool]: ... + +assert_type(next(b_f8), tuple[Any, ...]) +assert_type(b_f8.reset(), None) +assert_type(b_f8.index, int) +assert_type(b_f8.iters, tuple[np.flatiter[Any], ...]) +assert_type(b_f8.nd, int) +assert_type(b_f8.ndim, int) +assert_type(b_f8.numiter, int) +assert_type(b_f8.shape, tuple[Any, ...]) +assert_type(b_f8.size, int) + +assert_type(next(b_i8_f8_f8), tuple[Any, ...]) +assert_type(b_i8_f8_f8.reset(), None) +assert_type(b_i8_f8_f8.index, int) +assert_type(b_i8_f8_f8.iters, tuple[np.flatiter[Any], ...]) +assert_type(b_i8_f8_f8.nd, int) +assert_type(b_i8_f8_f8.ndim, int) +assert_type(b_i8_f8_f8.numiter, int) +assert_type(b_i8_f8_f8.shape, tuple[Any, ...]) +assert_type(b_i8_f8_f8.size, int) + +assert_type(np.inner(AR_f8, AR_i8), Any) + +assert_type(np.where([True, True, False]), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.where([True, True, False], 1, 0), npt.NDArray[Any]) + +assert_type(np.lexsort([0, 1, 2]), Any) + +assert_type(np.can_cast(np.dtype("i8"), int), bool) +assert_type(np.can_cast(AR_f8, "f8"), bool) +assert_type(np.can_cast(AR_f8, np.complex128, casting="unsafe"), bool) + +assert_type(np.min_scalar_type([1]), np.dtype) +assert_type(np.min_scalar_type(AR_f8), np.dtype) + +assert_type(np.result_type(int, [1]), np.dtype) +assert_type(np.result_type(AR_f8, AR_u1), np.dtype) +assert_type(np.result_type(AR_f8, np.complex128), np.dtype) + +assert_type(np.dot(AR_LIKE_f, AR_i8), Any) +assert_type(np.dot(AR_u1, 1), Any) +assert_type(np.dot(1.5j, 1), Any) +assert_type(np.dot(AR_u1, 1, out=AR_f8), npt.NDArray[np.float64]) + +assert_type(np.vdot(AR_LIKE_f, AR_i8), np.floating) +assert_type(np.vdot(AR_u1, 1), np.signedinteger) +assert_type(np.vdot(1.5j, 1), np.complexfloating) + +assert_type(np.bincount(AR_i8), npt.NDArray[np.intp]) + 
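+# --- Editorial sketch, not part of the upstream test file: the assertions
+# above mirror promotion and casting rules that also hold at runtime. A
+# minimal runnable check, assuming only `numpy` is installed:
+#
+#     import numpy as np
+#     assert np.can_cast(np.dtype("i8"), np.complex128, casting="unsafe")
+#     assert np.result_type(np.float64, np.uint8) == np.dtype(np.float64)
+#     assert np.bincount(np.array([0, 1, 1])).dtype == np.dtype(np.intp)
+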
+assert_type(np.copyto(AR_f8, [1., 1.5, 1.6]), None) + +assert_type(np.putmask(AR_f8, [True, True, False], 1.5), None) + +assert_type(np.packbits(AR_i8), npt.NDArray[np.uint8]) +assert_type(np.packbits(AR_u1), npt.NDArray[np.uint8]) + +assert_type(np.unpackbits(AR_u1), npt.NDArray[np.uint8]) + +assert_type(np.shares_memory(1, 2), bool) +assert_type(np.shares_memory(AR_f8, AR_f8, max_work=1), bool) + +assert_type(np.may_share_memory(1, 2), bool) +assert_type(np.may_share_memory(AR_f8, AR_f8, max_work=1), bool) + +assert_type(np.promote_types(np.int32, np.int64), np.dtype) +assert_type(np.promote_types("f4", float), np.dtype) + +assert_type(np.frompyfunc(func11, n1, n1).nin, Literal[1]) +assert_type(np.frompyfunc(func11, n1, n1).nout, Literal[1]) +assert_type(np.frompyfunc(func11, n1, n1).nargs, Literal[2]) +assert_type(np.frompyfunc(func11, n1, n1).ntypes, Literal[1]) +assert_type(np.frompyfunc(func11, n1, n1).identity, None) +assert_type(np.frompyfunc(func11, n1, n1).signature, None) +assert_type(np.frompyfunc(func11, n1, n1)(f8), bool) +assert_type(np.frompyfunc(func11, n1, n1)(AR_f8), bool | npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func11, n1, n1).at(AR_f8, AR_i8), None) + +assert_type(np.frompyfunc(func21, n2, n1).nin, Literal[2]) +assert_type(np.frompyfunc(func21, n2, n1).nout, Literal[1]) +assert_type(np.frompyfunc(func21, n2, n1).nargs, Literal[3]) +assert_type(np.frompyfunc(func21, n2, n1).ntypes, Literal[1]) +assert_type(np.frompyfunc(func21, n2, n1).identity, None) +assert_type(np.frompyfunc(func21, n2, n1).signature, None) +assert_type(np.frompyfunc(func21, n2, n1)(f8, f8), int) +assert_type(np.frompyfunc(func21, n2, n1)(AR_f8, f8), int | npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func21, n2, n1)(f8, AR_f8), int | npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func21, n2, n1).reduce(AR_f8, axis=0), int | npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func21, n2, n1).accumulate(AR_f8), npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func21, n2, n1).reduceat(AR_f8, AR_i8), npt.NDArray[np.object_]) +assert_type(np.frompyfunc(func21, n2, n1).outer(f8, f8), int) +assert_type(np.frompyfunc(func21, n2, n1).outer(AR_f8, f8), int | npt.NDArray[np.object_]) + +assert_type(np.frompyfunc(func21, n2, n1, identity=0).nin, Literal[2]) +assert_type(np.frompyfunc(func21, n2, n1, identity=0).nout, Literal[1]) +assert_type(np.frompyfunc(func21, n2, n1, identity=0).nargs, Literal[3]) +assert_type(np.frompyfunc(func21, n2, n1, identity=0).ntypes, Literal[1]) +assert_type(np.frompyfunc(func21, n2, n1, identity=0).identity, int) +assert_type(np.frompyfunc(func21, n2, n1, identity=0).signature, None) + +assert_type(np.frompyfunc(func12, n1, n2).nin, Literal[1]) +assert_type(np.frompyfunc(func12, n1, n2).nout, Literal[2]) +assert_type(np.frompyfunc(func12, n1, n2).nargs, int) +assert_type(np.frompyfunc(func12, n1, n2).ntypes, Literal[1]) +assert_type(np.frompyfunc(func12, n1, n2).identity, None) +assert_type(np.frompyfunc(func12, n1, n2).signature, None) +assert_type( + np.frompyfunc(func12, n2, n2)(f8, f8), + tuple[complex, complex, *tuple[complex, ...]], +) +assert_type( + np.frompyfunc(func12, n2, n2)(AR_f8, f8), + tuple[ + complex | npt.NDArray[np.object_], + complex | npt.NDArray[np.object_], + *tuple[complex | npt.NDArray[np.object_], ...], + ], +) + +assert_type(np.datetime_data("m8[D]"), tuple[str, int]) +assert_type(np.datetime_data(np.datetime64), tuple[str, int]) +assert_type(np.datetime_data(np.dtype(np.timedelta64)), tuple[str, int]) + 
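+# --- Editorial sketch, not from the upstream file: the `frompyfunc` checks
+# above reflect that ufuncs wrapping Python callables produce object-dtype
+# results at runtime. A minimal runnable illustration, assuming only `numpy`:
+#
+#     import numpy as np
+#     add = np.frompyfunc(lambda a, b: a + b, 2, 1)  # nin=2, nout=1
+#     out = add(np.arange(3), 10)
+#     assert out.dtype == np.dtype(object)
+#     assert list(out) == [10, 11, 12]
+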
+assert_type(np.busday_count("2011-01", "2011-02"), np.int_) +assert_type(np.busday_count(["2011-01"], "2011-02"), npt.NDArray[np.int_]) +assert_type(np.busday_count(["2011-01"], date_scalar), npt.NDArray[np.int_]) + +assert_type(np.busday_offset(M, m), np.datetime64) +assert_type(np.busday_offset(date_scalar, m), np.datetime64) +assert_type(np.busday_offset(M, 5), np.datetime64) +assert_type(np.busday_offset(AR_M, m), npt.NDArray[np.datetime64]) +assert_type(np.busday_offset(M, timedelta_seq), npt.NDArray[np.datetime64]) +assert_type(np.busday_offset("2011-01", "2011-02", roll="forward"), np.datetime64) +assert_type(np.busday_offset(["2011-01"], "2011-02", roll="forward"), npt.NDArray[np.datetime64]) + +assert_type(np.is_busday("2012"), np.bool) +assert_type(np.is_busday(date_scalar), np.bool) +assert_type(np.is_busday(["2012"]), npt.NDArray[np.bool]) + +assert_type(np.datetime_as_string(M), np.str_) +assert_type(np.datetime_as_string(AR_M), npt.NDArray[np.str_]) + +assert_type(np.busdaycalendar(holidays=date_seq), np.busdaycalendar) +assert_type(np.busdaycalendar(holidays=[M]), np.busdaycalendar) + +assert_type(np.char.compare_chararrays("a", "b", "!=", rstrip=False), npt.NDArray[np.bool]) +assert_type(np.char.compare_chararrays(b"a", b"a", "==", True), npt.NDArray[np.bool]) + +assert_type(np.nested_iters([AR_i8, AR_i8], [[0], [1]], flags=["c_index"]), tuple[np.nditer, ...]) +assert_type(np.nested_iters([AR_i8, AR_i8], [[0], [1]], op_flags=[["readonly", "readonly"]]), tuple[np.nditer, ...]) +assert_type(np.nested_iters([AR_i8, AR_i8], [[0], [1]], op_dtypes=np.int_), tuple[np.nditer, ...]) +assert_type(np.nested_iters([AR_i8, AR_i8], [[0], [1]], order="C", casting="no"), tuple[np.nditer, ...]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nbit_base_example.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nbit_base_example.pyi new file mode 100644 index 0000000000000000000000000000000000000000..33229660b6f89e7a72a03d4e1e043afd0a378a4d --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nbit_base_example.pyi @@ -0,0 +1,21 @@ +from typing import TypeVar, assert_type + +import numpy as np +import numpy.typing as npt +from numpy._typing import _32Bit, _64Bit + +T1 = TypeVar("T1", bound=npt.NBitBase) # type: ignore[deprecated] # pyright: ignore[reportDeprecated] +T2 = TypeVar("T2", bound=npt.NBitBase) # type: ignore[deprecated] # pyright: ignore[reportDeprecated] + +def add(a: np.floating[T1], b: np.integer[T2]) -> np.floating[T1 | T2]: + return a + b + +i8: np.int64 +i4: np.int32 +f8: np.float64 +f4: np.float32 + +assert_type(add(f8, i8), np.floating[_64Bit]) +assert_type(add(f4, i8), np.floating[_32Bit | _64Bit]) +assert_type(add(f8, i4), np.floating[_32Bit | _64Bit]) +assert_type(add(f4, i4), np.floating[_32Bit]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_assignability.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_assignability.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d754a94003d3369765e4f25ba532e8311012d3b6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_assignability.pyi @@ -0,0 +1,77 @@ +from typing import Protocol, TypeAlias, TypeVar, assert_type + +import numpy as np +from numpy._typing import _64Bit + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +class CanAbs(Protocol[_T_co]): + def __abs__(self, /) -> _T_co: ... 
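+
+# --- Editorial sketch, not part of the upstream file: `CanAbs` above and the
+# protocols below each capture a single unary operator, so the `do_*` helpers
+# let a type checker reveal exactly what the operator returns. The runtime
+# analogue, assuming `numpy`:
+#
+#     import numpy as np
+#     x = np.array([3 + 4j], dtype=np.complex64)
+#     assert abs(x).dtype == np.float32  # __abs__ on complex64 yields float32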
+ +class CanInvert(Protocol[_T_co]): + def __invert__(self, /) -> _T_co: ... + +class CanNeg(Protocol[_T_co]): + def __neg__(self, /) -> _T_co: ... + +class CanPos(Protocol[_T_co]): + def __pos__(self, /) -> _T_co: ... + +def do_abs(x: CanAbs[_T]) -> _T: ... +def do_invert(x: CanInvert[_T]) -> _T: ... +def do_neg(x: CanNeg[_T]) -> _T: ... +def do_pos(x: CanPos[_T]) -> _T: ... + +_Bool_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.bool]] +_UInt8_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.uint8]] +_Int16_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.int16]] +_LongLong_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longlong]] +_Float32_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float32]] +_Float64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]] +_LongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.longdouble]] +_Complex64_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex64]] +_Complex128_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]] +_CLongDouble_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.clongdouble]] + +b1_1d: _Bool_1d +u1_1d: _UInt8_1d +i2_1d: _Int16_1d +q_1d: _LongLong_1d +f4_1d: _Float32_1d +f8_1d: _Float64_1d +g_1d: _LongDouble_1d +c8_1d: _Complex64_1d +c16_1d: _Complex128_1d +G_1d: _CLongDouble_1d + +assert_type(do_abs(b1_1d), _Bool_1d) +assert_type(do_abs(u1_1d), _UInt8_1d) +assert_type(do_abs(i2_1d), _Int16_1d) +assert_type(do_abs(q_1d), _LongLong_1d) +assert_type(do_abs(f4_1d), _Float32_1d) +assert_type(do_abs(f8_1d), _Float64_1d) +assert_type(do_abs(g_1d), _LongDouble_1d) + +assert_type(do_abs(c8_1d), _Float32_1d) +# NOTE: Unfortunately it's not possible to have this return a `float64` sctype, see +# https://github.com/python/mypy/issues/14070 +assert_type(do_abs(c16_1d), np.ndarray[tuple[int], np.dtype[np.floating[_64Bit]]]) +assert_type(do_abs(G_1d), _LongDouble_1d) + +assert_type(do_invert(b1_1d), _Bool_1d) +assert_type(do_invert(u1_1d), _UInt8_1d) +assert_type(do_invert(i2_1d), _Int16_1d) +assert_type(do_invert(q_1d), _LongLong_1d) + +assert_type(do_neg(u1_1d), _UInt8_1d) +assert_type(do_neg(i2_1d), _Int16_1d) +assert_type(do_neg(q_1d), _LongLong_1d) +assert_type(do_neg(f4_1d), _Float32_1d) +assert_type(do_neg(c16_1d), _Complex128_1d) + +assert_type(do_pos(u1_1d), _UInt8_1d) +assert_type(do_pos(i2_1d), _Int16_1d) +assert_type(do_pos(q_1d), _LongLong_1d) +assert_type(do_pos(f4_1d), _Float32_1d) +assert_type(do_pos(c16_1d), _Complex128_1d) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_conversion.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_conversion.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bbd42573a774fa12df006265bc4fea8010e09dcc --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_conversion.pyi @@ -0,0 +1,85 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +b1_0d: np.ndarray[tuple[()], np.dtype[np.bool]] +u2_1d: np.ndarray[tuple[int], np.dtype[np.uint16]] +i4_2d: np.ndarray[tuple[int, int], np.dtype[np.int32]] +f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]] +cG_4d: np.ndarray[tuple[int, int, int, int], np.dtype[np.clongdouble]] +i0_nd: npt.NDArray[np.int_] +uncertain_dtype: np.int32 | np.float64 | np.str_ + +# item +assert_type(i0_nd.item(), int) +assert_type(i0_nd.item(1), int) +assert_type(i0_nd.item(0, 1), int) +assert_type(i0_nd.item((0, 1)), int) + +assert_type(b1_0d.item(()), bool) 
+assert_type(u2_1d.item((0,)), int) +assert_type(i4_2d.item(-1, 2), int) +assert_type(f8_3d.item(2, 1, -1), float) +assert_type(cG_4d.item(-0xEd_fed_Deb_a_dead_bee), complex) # c'mon Ed, we talked about this... + +# tolist +assert_type(b1_0d.tolist(), bool) +assert_type(u2_1d.tolist(), list[int]) +assert_type(i4_2d.tolist(), list[list[int]]) +assert_type(f8_3d.tolist(), list[list[list[float]]]) +assert_type(cG_4d.tolist(), Any) +assert_type(i0_nd.tolist(), Any) + +# regression tests for numpy/numpy#27944 +any_dtype: np.ndarray[Any, Any] +any_sctype: np.ndarray[Any, Any] +assert_type(any_dtype.tolist(), Any) +assert_type(any_sctype.tolist(), Any) + + +# itemset does not return a value +# tobytes is pretty simple +# tofile does not return a value +# dump does not return a value +# dumps is pretty simple + +# astype +assert_type(i0_nd.astype("float"), npt.NDArray[Any]) +assert_type(i0_nd.astype(float), npt.NDArray[Any]) +assert_type(i0_nd.astype(np.float64), npt.NDArray[np.float64]) +assert_type(i0_nd.astype(np.float64, "K"), npt.NDArray[np.float64]) +assert_type(i0_nd.astype(np.float64, "K", "unsafe"), npt.NDArray[np.float64]) +assert_type(i0_nd.astype(np.float64, "K", "unsafe", True), npt.NDArray[np.float64]) +assert_type(i0_nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np.float64]) + +assert_type(np.astype(i0_nd, np.float64), npt.NDArray[np.float64]) + +assert_type(i4_2d.astype(np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]]) +assert_type(np.astype(i4_2d, np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]]) +assert_type(f8_3d.astype(np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]]) +assert_type(np.astype(f8_3d, np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]]) +assert_type(i4_2d.astype(uncertain_dtype), np.ndarray[tuple[int, int], np.dtype[np.generic]]) +assert_type(np.astype(i4_2d, uncertain_dtype), np.ndarray[tuple[int, int], np.dtype]) + +# byteswap +assert_type(i0_nd.byteswap(), npt.NDArray[np.int_]) +assert_type(i0_nd.byteswap(True), npt.NDArray[np.int_]) + +# copy +assert_type(i0_nd.copy(), npt.NDArray[np.int_]) +assert_type(i0_nd.copy("C"), npt.NDArray[np.int_]) + +assert_type(i0_nd.view(), npt.NDArray[np.int_]) +assert_type(i0_nd.view(np.float64), npt.NDArray[np.float64]) +assert_type(i0_nd.view(float), npt.NDArray[Any]) +assert_type(i0_nd.view(np.float64, np.matrix), np.matrix[tuple[int, int], Any]) + +# getfield +assert_type(i0_nd.getfield("float"), npt.NDArray[Any]) +assert_type(i0_nd.getfield(float), npt.NDArray[Any]) +assert_type(i0_nd.getfield(np.float64), npt.NDArray[np.float64]) +assert_type(i0_nd.getfield(np.float64, 8), npt.NDArray[np.float64]) + +# setflags does not return a value +# fill does not return a value diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_misc.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_misc.pyi new file mode 100644 index 0000000000000000000000000000000000000000..4cbb90621ca9728bdb6fe4183e0705310e09c331 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_misc.pyi @@ -0,0 +1,247 @@ +""" +Tests for miscellaneous (non-magic) ``np.ndarray``/``np.generic`` methods. + +More extensive tests are performed for the methods' +function-based counterpart in `../from_numeric.py`. 
+ +""" + +from collections.abc import Iterator +import ctypes as ct +import operator +from types import ModuleType +from typing import Any, Literal, assert_type + +from typing_extensions import CapsuleType + +import numpy as np +import numpy.typing as npt + +class SubClass(npt.NDArray[np.object_]): ... + +f8: np.float64 +i8: np.int64 +B: SubClass +AR_f8: npt.NDArray[np.float64] +AR_i8: npt.NDArray[np.int64] +AR_u1: npt.NDArray[np.uint8] +AR_c8: npt.NDArray[np.complex64] +AR_m: npt.NDArray[np.timedelta64] +AR_U: npt.NDArray[np.str_] +AR_V: npt.NDArray[np.void] + +AR_f8_1d: np.ndarray[tuple[int], np.dtype[np.float64]] +AR_f8_2d: np.ndarray[tuple[int, int], np.dtype[np.float64]] +AR_f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]] + +ctypes_obj = AR_f8.ctypes + +assert_type(AR_f8.__dlpack__(), CapsuleType) +assert_type(AR_f8.__dlpack_device__(), tuple[Literal[1], Literal[0]]) + +assert_type(ctypes_obj.data, int) +assert_type(ctypes_obj.shape, ct.Array[np.ctypeslib.c_intp]) +assert_type(ctypes_obj.strides, ct.Array[np.ctypeslib.c_intp]) +assert_type(ctypes_obj._as_parameter_, ct.c_void_p) + +assert_type(ctypes_obj.data_as(ct.c_void_p), ct.c_void_p) +assert_type(ctypes_obj.shape_as(ct.c_longlong), ct.Array[ct.c_longlong]) +assert_type(ctypes_obj.strides_as(ct.c_ubyte), ct.Array[ct.c_ubyte]) + +assert_type(f8.all(), np.bool) +assert_type(AR_f8.all(), np.bool) +assert_type(AR_f8.all(axis=0), np.bool | npt.NDArray[np.bool]) +assert_type(AR_f8.all(keepdims=True), np.bool | npt.NDArray[np.bool]) +assert_type(AR_f8.all(out=B), SubClass) + +assert_type(f8.any(), np.bool) +assert_type(AR_f8.any(), np.bool) +assert_type(AR_f8.any(axis=0), np.bool | npt.NDArray[np.bool]) +assert_type(AR_f8.any(keepdims=True), np.bool | npt.NDArray[np.bool]) +assert_type(AR_f8.any(out=B), SubClass) + +assert_type(f8.argmax(), np.intp) +assert_type(AR_f8.argmax(), np.intp) +assert_type(AR_f8.argmax(axis=0), Any) +assert_type(AR_f8.argmax(out=AR_i8), npt.NDArray[np.intp]) + +assert_type(f8.argmin(), np.intp) +assert_type(AR_f8.argmin(), np.intp) +assert_type(AR_f8.argmin(axis=0), Any) +assert_type(AR_f8.argmin(out=AR_i8), npt.NDArray[np.intp]) + +assert_type(f8.argsort(), npt.NDArray[Any]) +assert_type(AR_f8.argsort(), npt.NDArray[Any]) + +assert_type(f8.astype(np.int64).choose([()]), npt.NDArray[Any]) +assert_type(AR_f8.choose([0]), npt.NDArray[Any]) +assert_type(AR_f8.choose([0], out=B), SubClass) + +assert_type(f8.clip(1), npt.NDArray[Any]) +assert_type(AR_f8.clip(1), npt.NDArray[Any]) +assert_type(AR_f8.clip(None, 1), npt.NDArray[Any]) +assert_type(AR_f8.clip(1, out=B), SubClass) +assert_type(AR_f8.clip(None, 1, out=B), SubClass) + +assert_type(f8.compress([0]), npt.NDArray[Any]) +assert_type(AR_f8.compress([0]), npt.NDArray[Any]) +assert_type(AR_f8.compress([0], out=B), SubClass) + +assert_type(f8.conj(), np.float64) +assert_type(AR_f8.conj(), npt.NDArray[np.float64]) +assert_type(B.conj(), SubClass) + +assert_type(f8.conjugate(), np.float64) +assert_type(AR_f8.conjugate(), npt.NDArray[np.float64]) +assert_type(B.conjugate(), SubClass) + +assert_type(f8.cumprod(), npt.NDArray[Any]) +assert_type(AR_f8.cumprod(), npt.NDArray[Any]) +assert_type(AR_f8.cumprod(out=B), SubClass) + +assert_type(f8.cumsum(), npt.NDArray[Any]) +assert_type(AR_f8.cumsum(), npt.NDArray[Any]) +assert_type(AR_f8.cumsum(out=B), SubClass) + +assert_type(f8.max(), Any) +assert_type(AR_f8.max(), Any) +assert_type(AR_f8.max(axis=0), Any) +assert_type(AR_f8.max(keepdims=True), Any) +assert_type(AR_f8.max(out=B), SubClass) + 
+assert_type(f8.mean(), Any) +assert_type(AR_f8.mean(), Any) +assert_type(AR_f8.mean(axis=0), Any) +assert_type(AR_f8.mean(keepdims=True), Any) +assert_type(AR_f8.mean(out=B), SubClass) + +assert_type(f8.min(), Any) +assert_type(AR_f8.min(), Any) +assert_type(AR_f8.min(axis=0), Any) +assert_type(AR_f8.min(keepdims=True), Any) +assert_type(AR_f8.min(out=B), SubClass) + +assert_type(f8.prod(), Any) +assert_type(AR_f8.prod(), Any) +assert_type(AR_f8.prod(axis=0), Any) +assert_type(AR_f8.prod(keepdims=True), Any) +assert_type(AR_f8.prod(out=B), SubClass) + +assert_type(f8.round(), np.float64) +assert_type(AR_f8.round(), npt.NDArray[np.float64]) +assert_type(AR_f8.round(out=B), SubClass) + +assert_type(f8.repeat(1), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(f8.repeat(1, axis=0), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(AR_f8.repeat(1), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(AR_f8.repeat(1, axis=0), npt.NDArray[np.float64]) +assert_type(B.repeat(1), np.ndarray[tuple[int], np.dtype[np.object_]]) +assert_type(B.repeat(1, axis=0), npt.NDArray[np.object_]) + +assert_type(f8.std(), Any) +assert_type(AR_f8.std(), Any) +assert_type(AR_f8.std(axis=0), Any) +assert_type(AR_f8.std(keepdims=True), Any) +assert_type(AR_f8.std(out=B), SubClass) + +assert_type(f8.sum(), Any) +assert_type(AR_f8.sum(), Any) +assert_type(AR_f8.sum(axis=0), Any) +assert_type(AR_f8.sum(keepdims=True), Any) +assert_type(AR_f8.sum(out=B), SubClass) + +assert_type(f8.take(0), np.float64) +assert_type(AR_f8.take(0), np.float64) +assert_type(AR_f8.take([0]), npt.NDArray[np.float64]) +assert_type(AR_f8.take(0, out=B), SubClass) +assert_type(AR_f8.take([0], out=B), SubClass) + +assert_type(f8.var(), Any) +assert_type(AR_f8.var(), Any) +assert_type(AR_f8.var(axis=0), Any) +assert_type(AR_f8.var(keepdims=True), Any) +assert_type(AR_f8.var(out=B), SubClass) + +assert_type(AR_f8.argpartition([0]), npt.NDArray[np.intp]) + +assert_type(AR_f8.diagonal(), npt.NDArray[np.float64]) + +assert_type(AR_f8.dot(1), npt.NDArray[Any]) +assert_type(AR_f8.dot([1]), Any) +assert_type(AR_f8.dot(1, out=B), SubClass) + +assert_type(AR_f8.nonzero(), tuple[npt.NDArray[np.intp], ...]) + +assert_type(AR_f8.searchsorted(1), np.intp) +assert_type(AR_f8.searchsorted([1]), npt.NDArray[np.intp]) + +assert_type(AR_f8.trace(), Any) +assert_type(AR_f8.trace(out=B), SubClass) + +assert_type(AR_f8.item(), float) +assert_type(AR_U.item(), str) + +assert_type(AR_f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(AR_U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]]) + +assert_type(AR_f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(AR_U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]]) + +assert_type(AR_i8.reshape(None), npt.NDArray[np.int64]) +assert_type(AR_f8.reshape(-1), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(AR_c8.reshape(2, 3, 4, 5), np.ndarray[tuple[int, int, int, int], np.dtype[np.complex64]]) +assert_type(AR_m.reshape(()), np.ndarray[tuple[()], np.dtype[np.timedelta64]]) +assert_type(AR_U.reshape([]), np.ndarray[tuple[()], np.dtype[np.str_]]) +assert_type(AR_V.reshape((480, 720, 4)), np.ndarray[tuple[int, int, int], np.dtype[np.void]]) + +assert_type(int(AR_f8), int) +assert_type(int(AR_U), int) + +assert_type(float(AR_f8), float) +assert_type(float(AR_U), float) + +assert_type(complex(AR_f8), complex) + +assert_type(operator.index(AR_i8), int) + +assert_type(AR_f8.__array_wrap__(B), npt.NDArray[np.object_]) + +assert_type(AR_V[0], Any) 
+assert_type(AR_V[0, 0], Any) +assert_type(AR_V[AR_i8], npt.NDArray[np.void]) +assert_type(AR_V[AR_i8, AR_i8], npt.NDArray[np.void]) +assert_type(AR_V[AR_i8, None], npt.NDArray[np.void]) +assert_type(AR_V[0, ...], npt.NDArray[np.void]) +assert_type(AR_V[[0]], npt.NDArray[np.void]) +assert_type(AR_V[[0], [0]], npt.NDArray[np.void]) +assert_type(AR_V[:], npt.NDArray[np.void]) +assert_type(AR_V["a"], npt.NDArray[Any]) +assert_type(AR_V[["a", "b"]], npt.NDArray[np.void]) + +assert_type(AR_f8.dump("test_file"), None) +assert_type(AR_f8.dump(b"test_file"), None) +with open("test_file", "wb") as f: + assert_type(AR_f8.dump(f), None) + +assert_type(AR_f8.__array_finalize__(None), None) +assert_type(AR_f8.__array_finalize__(B), None) +assert_type(AR_f8.__array_finalize__(AR_f8), None) + +assert_type(f8.device, Literal["cpu"]) +assert_type(AR_f8.device, Literal["cpu"]) + +assert_type(f8.to_device("cpu"), np.float64) +assert_type(i8.to_device("cpu"), np.int64) +assert_type(AR_f8.to_device("cpu"), npt.NDArray[np.float64]) +assert_type(AR_i8.to_device("cpu"), npt.NDArray[np.int64]) +assert_type(AR_u1.to_device("cpu"), npt.NDArray[np.uint8]) +assert_type(AR_c8.to_device("cpu"), npt.NDArray[np.complex64]) +assert_type(AR_m.to_device("cpu"), npt.NDArray[np.timedelta64]) + +assert_type(f8.__array_namespace__(), ModuleType) +assert_type(AR_f8.__array_namespace__(), ModuleType) + +assert_type(iter(AR_f8), Iterator[Any]) # any-D +assert_type(iter(AR_f8_1d), Iterator[np.float64]) # 1-D +assert_type(iter(AR_f8_2d), Iterator[npt.NDArray[np.float64]]) # 2-D +assert_type(iter(AR_f8_3d), Iterator[npt.NDArray[np.float64]]) # 3-D diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi new file mode 100644 index 0000000000000000000000000000000000000000..4447bb13d2ad9abe78dddda6f08ee933f4cd050c --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi @@ -0,0 +1,39 @@ +from typing import assert_type + +import numpy as np +import numpy.typing as npt + +nd: npt.NDArray[np.int64] + +# reshape +assert_type(nd.reshape(None), npt.NDArray[np.int64]) +assert_type(nd.reshape(4), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(nd.reshape((4,)), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(nd.reshape(2, 2), np.ndarray[tuple[int, int], np.dtype[np.int64]]) +assert_type(nd.reshape((2, 2)), np.ndarray[tuple[int, int], np.dtype[np.int64]]) + +assert_type(nd.reshape((2, 2), order="C"), np.ndarray[tuple[int, int], np.dtype[np.int64]]) +assert_type(nd.reshape(4, order="C"), np.ndarray[tuple[int], np.dtype[np.int64]]) + +# resize does not return a value + +# transpose +assert_type(nd.transpose(), npt.NDArray[np.int64]) +assert_type(nd.transpose(1, 0), npt.NDArray[np.int64]) +assert_type(nd.transpose((1, 0)), npt.NDArray[np.int64]) + +# swapaxes +assert_type(nd.swapaxes(0, 1), npt.NDArray[np.int64]) + +# flatten +assert_type(nd.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(nd.flatten("C"), np.ndarray[tuple[int], np.dtype[np.int64]]) + +# ravel +assert_type(nd.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(nd.ravel("C"), np.ndarray[tuple[int], np.dtype[np.int64]]) + +# squeeze +assert_type(nd.squeeze(), npt.NDArray[np.int64]) +assert_type(nd.squeeze(0), npt.NDArray[np.int64]) +assert_type(nd.squeeze((0, 2)), npt.NDArray[np.int64]) diff --git 
a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nditer.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nditer.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8965f3c03e6de46740c7b1aa8978f1730a0ac799 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nditer.pyi @@ -0,0 +1,49 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +nditer_obj: np.nditer + +assert_type(np.nditer([0, 1], flags=["c_index"]), np.nditer) +assert_type(np.nditer([0, 1], op_flags=[["readonly", "readonly"]]), np.nditer) +assert_type(np.nditer([0, 1], op_dtypes=np.int_), np.nditer) +assert_type(np.nditer([0, 1], order="C", casting="no"), np.nditer) + +assert_type(nditer_obj.dtypes, tuple[np.dtype, ...]) +assert_type(nditer_obj.finished, bool) +assert_type(nditer_obj.has_delayed_bufalloc, bool) +assert_type(nditer_obj.has_index, bool) +assert_type(nditer_obj.has_multi_index, bool) +assert_type(nditer_obj.index, int) +assert_type(nditer_obj.iterationneedsapi, bool) +assert_type(nditer_obj.iterindex, int) +assert_type(nditer_obj.iterrange, tuple[int, ...]) +assert_type(nditer_obj.itersize, int) +assert_type(nditer_obj.itviews, tuple[npt.NDArray[Any], ...]) +assert_type(nditer_obj.multi_index, tuple[int, ...]) +assert_type(nditer_obj.ndim, int) +assert_type(nditer_obj.nop, int) +assert_type(nditer_obj.operands, tuple[npt.NDArray[Any], ...]) +assert_type(nditer_obj.shape, tuple[int, ...]) +assert_type(nditer_obj.value, tuple[npt.NDArray[Any], ...]) + +assert_type(nditer_obj.close(), None) +assert_type(nditer_obj.copy(), np.nditer) +assert_type(nditer_obj.debug_print(), None) +assert_type(nditer_obj.enable_external_loop(), None) +assert_type(nditer_obj.iternext(), bool) +assert_type(nditer_obj.remove_axis(0), None) +assert_type(nditer_obj.remove_multi_index(), None) +assert_type(nditer_obj.reset(), None) + +assert_type(len(nditer_obj), int) +assert_type(iter(nditer_obj), np.nditer) +assert_type(next(nditer_obj), tuple[npt.NDArray[Any], ...]) +assert_type(nditer_obj.__copy__(), np.nditer) +with nditer_obj as f: + assert_type(f, np.nditer) +assert_type(nditer_obj[0], npt.NDArray[Any]) +assert_type(nditer_obj[:], tuple[npt.NDArray[Any], ...]) +nditer_obj[0] = 0 +nditer_obj[:] = [0, 1] diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nested_sequence.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nested_sequence.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b4f98b79c3339fa856a09cd639961b5bb8156fd3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/nested_sequence.pyi @@ -0,0 +1,25 @@ +from collections.abc import Sequence +from typing import Any, assert_type + +from numpy._typing import _NestedSequence + +a: Sequence[int] +b: Sequence[Sequence[int]] +c: Sequence[Sequence[Sequence[int]]] +d: Sequence[Sequence[Sequence[Sequence[int]]]] +e: Sequence[bool] +f: tuple[int, ...] +g: list[int] +h: Sequence[Any] + +def func(a: _NestedSequence[int]) -> None: ... 
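+
+# --- Editorial sketch, not part of the upstream file: `_NestedSequence[int]`
+# matches sequences of ints nested to any depth, roughly the recursive alias
+# `Sequence[int | "Nested"]`, which the assertions below exercise. A
+# runtime-flavoured analogue for intuition (hypothetical helper):
+#
+#     from collections.abc import Sequence
+#     def flat_sum(x: object) -> int:
+#         if isinstance(x, Sequence):
+#             return sum(flat_sum(i) for i in x)
+#         assert isinstance(x, int)
+#         return x
+#
+#     assert flat_sum([[1, 2], [3, [4]]]) == 10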
+ +assert_type(func(a), None) +assert_type(func(b), None) +assert_type(func(c), None) +assert_type(func(d), None) +assert_type(func(e), None) +assert_type(func(f), None) +assert_type(func(g), None) +assert_type(func(h), None) +assert_type(func(range(15)), None) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/npyio.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/npyio.pyi new file mode 100644 index 0000000000000000000000000000000000000000..40da72c8544eac029383c3314a3423e04b3b0a6c --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/npyio.pyi @@ -0,0 +1,83 @@ +import pathlib +import re +import zipfile +from collections.abc import Mapping +from typing import IO, Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy.lib._npyio_impl import BagObj + +str_path: str +pathlib_path: pathlib.Path +str_file: IO[str] +bytes_file: IO[bytes] + +npz_file: np.lib.npyio.NpzFile + +AR_i8: npt.NDArray[np.int64] +AR_LIKE_f8: list[float] + +class BytesWriter: + def write(self, data: bytes) -> None: ... + +class BytesReader: + def read(self, n: int = ...) -> bytes: ... + def seek(self, offset: int, whence: int = ...) -> int: ... + +bytes_writer: BytesWriter +bytes_reader: BytesReader + +assert_type(npz_file.zip, zipfile.ZipFile) +assert_type(npz_file.fid, IO[str] | None) +assert_type(npz_file.files, list[str]) +assert_type(npz_file.allow_pickle, bool) +assert_type(npz_file.pickle_kwargs, Mapping[str, Any] | None) +assert_type(npz_file.f, BagObj[np.lib.npyio.NpzFile]) +assert_type(npz_file["test"], npt.NDArray[Any]) +assert_type(len(npz_file), int) +with npz_file as f: + assert_type(f, np.lib.npyio.NpzFile) + +assert_type(np.load(bytes_file), Any) +assert_type(np.load(pathlib_path, allow_pickle=True), Any) +assert_type(np.load(str_path, encoding="bytes"), Any) +assert_type(np.load(bytes_reader), Any) + +assert_type(np.save(bytes_file, AR_LIKE_f8), None) +assert_type(np.save(pathlib_path, AR_i8, allow_pickle=True), None) +assert_type(np.save(str_path, AR_LIKE_f8), None) +assert_type(np.save(bytes_writer, AR_LIKE_f8), None) + +assert_type(np.savez(bytes_file, AR_LIKE_f8), None) +assert_type(np.savez(pathlib_path, ar1=AR_i8, ar2=AR_i8), None) +assert_type(np.savez(str_path, AR_LIKE_f8, ar1=AR_i8), None) +assert_type(np.savez(bytes_writer, AR_LIKE_f8, ar1=AR_i8), None) + +assert_type(np.savez_compressed(bytes_file, AR_LIKE_f8), None) +assert_type(np.savez_compressed(pathlib_path, ar1=AR_i8, ar2=AR_i8), None) +assert_type(np.savez_compressed(str_path, AR_LIKE_f8, ar1=AR_i8), None) +assert_type(np.savez_compressed(bytes_writer, AR_LIKE_f8, ar1=AR_i8), None) + +assert_type(np.loadtxt(bytes_file), npt.NDArray[np.float64]) +assert_type(np.loadtxt(pathlib_path, dtype=np.str_), npt.NDArray[np.str_]) +assert_type(np.loadtxt(str_path, dtype=str, skiprows=2), npt.NDArray[Any]) +assert_type(np.loadtxt(str_file, comments="test"), npt.NDArray[np.float64]) +assert_type(np.loadtxt(str_file, comments=None), npt.NDArray[np.float64]) +assert_type(np.loadtxt(str_path, delimiter="\n"), npt.NDArray[np.float64]) +assert_type(np.loadtxt(str_path, ndmin=2), npt.NDArray[np.float64]) +assert_type(np.loadtxt(["1", "2", "3"]), npt.NDArray[np.float64]) + +assert_type(np.fromregex(bytes_file, "test", np.float64), npt.NDArray[np.float64]) +assert_type(np.fromregex(str_file, b"test", dtype=float), npt.NDArray[Any]) +assert_type(np.fromregex(str_path, re.compile("test"), dtype=np.str_, encoding="utf8"), npt.NDArray[np.str_]) 
+assert_type(np.fromregex(pathlib_path, "test", np.float64), npt.NDArray[np.float64]) +assert_type(np.fromregex(bytes_reader, "test", np.float64), npt.NDArray[np.float64]) + +assert_type(np.genfromtxt(bytes_file), npt.NDArray[Any]) +assert_type(np.genfromtxt(pathlib_path, dtype=np.str_), npt.NDArray[np.str_]) +assert_type(np.genfromtxt(str_path, dtype=str, skip_header=2), npt.NDArray[Any]) +assert_type(np.genfromtxt(str_file, comments="test"), npt.NDArray[Any]) +assert_type(np.genfromtxt(str_path, delimiter="\n"), npt.NDArray[Any]) +assert_type(np.genfromtxt(str_path, ndmin=2), npt.NDArray[Any]) +assert_type(np.genfromtxt(["1", "2", "3"], ndmin=2), npt.NDArray[Any]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numeric.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numeric.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7c1ea8958e3bb005467a18179c234860a2a54647 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numeric.pyi @@ -0,0 +1,134 @@ +""" +Tests for :mod:`_core.numeric`. + +Does not include tests which fall under ``array_constructors``. + +""" + +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +class SubClass(npt.NDArray[np.int64]): ... + +i8: np.int64 + +AR_b: npt.NDArray[np.bool] +AR_u8: npt.NDArray[np.uint64] +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_m: npt.NDArray[np.timedelta64] +AR_O: npt.NDArray[np.object_] + +B: list[int] +C: SubClass + +assert_type(np.count_nonzero(i8), np.intp) +assert_type(np.count_nonzero(AR_i8), np.intp) +assert_type(np.count_nonzero(B), np.intp) +assert_type(np.count_nonzero(AR_i8, keepdims=True), npt.NDArray[np.intp]) +assert_type(np.count_nonzero(AR_i8, axis=0), Any) + +assert_type(np.isfortran(i8), bool) +assert_type(np.isfortran(AR_i8), bool) + +assert_type(np.argwhere(i8), npt.NDArray[np.intp]) +assert_type(np.argwhere(AR_i8), npt.NDArray[np.intp]) + +assert_type(np.flatnonzero(i8), npt.NDArray[np.intp]) +assert_type(np.flatnonzero(AR_i8), npt.NDArray[np.intp]) + +assert_type(np.correlate(B, AR_i8, mode="valid"), npt.NDArray[np.signedinteger]) +assert_type(np.correlate(AR_i8, AR_i8, mode="same"), npt.NDArray[np.signedinteger]) +assert_type(np.correlate(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.correlate(AR_b, AR_u8), npt.NDArray[np.unsignedinteger]) +assert_type(np.correlate(AR_i8, AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.correlate(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.correlate(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.correlate(AR_i8, AR_m), npt.NDArray[np.timedelta64]) +assert_type(np.correlate(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.convolve(B, AR_i8, mode="valid"), npt.NDArray[np.signedinteger]) +assert_type(np.convolve(AR_i8, AR_i8, mode="same"), npt.NDArray[np.signedinteger]) +assert_type(np.convolve(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.convolve(AR_b, AR_u8), npt.NDArray[np.unsignedinteger]) +assert_type(np.convolve(AR_i8, AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.convolve(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.convolve(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.convolve(AR_i8, AR_m), npt.NDArray[np.timedelta64]) +assert_type(np.convolve(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.outer(i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.outer(B, AR_i8), 
npt.NDArray[np.signedinteger]) +assert_type(np.outer(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.outer(AR_i8, AR_i8, out=C), SubClass) +assert_type(np.outer(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.outer(AR_b, AR_u8), npt.NDArray[np.unsignedinteger]) +assert_type(np.outer(AR_i8, AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.outer(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.outer(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.outer(AR_i8, AR_m), npt.NDArray[np.timedelta64]) +assert_type(np.outer(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.tensordot(B, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.tensordot(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.tensordot(AR_i8, AR_i8, axes=0), npt.NDArray[np.signedinteger]) +assert_type(np.tensordot(AR_i8, AR_i8, axes=(0, 1)), npt.NDArray[np.signedinteger]) +assert_type(np.tensordot(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.tensordot(AR_b, AR_u8), npt.NDArray[np.unsignedinteger]) +assert_type(np.tensordot(AR_i8, AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.tensordot(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.tensordot(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.tensordot(AR_i8, AR_m), npt.NDArray[np.timedelta64]) +assert_type(np.tensordot(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.isscalar(i8), bool) +assert_type(np.isscalar(AR_i8), bool) +assert_type(np.isscalar(B), bool) + +assert_type(np.roll(AR_i8, 1), npt.NDArray[np.int64]) +assert_type(np.roll(AR_i8, (1, 2)), npt.NDArray[np.int64]) +assert_type(np.roll(B, 1), npt.NDArray[Any]) + +assert_type(np.rollaxis(AR_i8, 0, 1), npt.NDArray[np.int64]) + +assert_type(np.moveaxis(AR_i8, 0, 1), npt.NDArray[np.int64]) +assert_type(np.moveaxis(AR_i8, (0, 1), (1, 2)), npt.NDArray[np.int64]) + +assert_type(np.cross(B, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.cross(AR_i8, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.cross(AR_b, AR_u8), npt.NDArray[np.unsignedinteger]) +assert_type(np.cross(AR_i8, AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.cross(AR_i8, AR_f8), npt.NDArray[np.floating]) +assert_type(np.cross(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(np.cross(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(np.indices([0, 1, 2]), npt.NDArray[np.int_]) +assert_type(np.indices([0, 1, 2], sparse=True), tuple[npt.NDArray[np.int_], ...]) +assert_type(np.indices([0, 1, 2], dtype=np.float64), npt.NDArray[np.float64]) +assert_type(np.indices([0, 1, 2], sparse=True, dtype=np.float64), tuple[npt.NDArray[np.float64], ...]) +assert_type(np.indices([0, 1, 2], dtype=float), npt.NDArray[Any]) +assert_type(np.indices([0, 1, 2], sparse=True, dtype=float), tuple[npt.NDArray[Any], ...]) + +assert_type(np.binary_repr(1), str) + +assert_type(np.base_repr(1), str) + +assert_type(np.allclose(i8, AR_i8), bool) +assert_type(np.allclose(B, AR_i8), bool) +assert_type(np.allclose(AR_i8, AR_i8), bool) + +assert_type(np.isclose(i8, i8), np.bool) +assert_type(np.isclose(i8, AR_i8), npt.NDArray[np.bool]) +assert_type(np.isclose(B, AR_i8), npt.NDArray[np.bool]) +assert_type(np.isclose(AR_i8, AR_i8), npt.NDArray[np.bool]) + +assert_type(np.array_equal(i8, AR_i8), bool) +assert_type(np.array_equal(B, AR_i8), bool) +assert_type(np.array_equal(AR_i8, AR_i8), bool) + +assert_type(np.array_equiv(i8, AR_i8), bool) +assert_type(np.array_equiv(B, AR_i8), bool) +assert_type(np.array_equiv(AR_i8, AR_i8), bool) diff --git 
a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numerictypes.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numerictypes.pyi new file mode 100644 index 0000000000000000000000000000000000000000..75d108ce5a0ffb3fdbd6cf4c4712909298ad5780 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/numerictypes.pyi @@ -0,0 +1,16 @@ +from typing import Literal, assert_type + +import numpy as np + +assert_type(np.ScalarType[0], type[int]) +assert_type(np.ScalarType[3], type[bool]) +assert_type(np.ScalarType[8], type[np.complex64]) +assert_type(np.ScalarType[9], type[np.complex128]) +assert_type(np.ScalarType[-1], type[np.void]) +assert_type(np.bool_(object()), np.bool) + +assert_type(np.typecodes["Character"], Literal["c"]) +assert_type(np.typecodes["Complex"], Literal["FDG"]) +assert_type(np.typecodes["All"], Literal["?bhilqnpBHILQNPefdgFDGSUVOMm"]) + +assert_type(np.sctypeDict["uint8"], type[np.generic]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polybase.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polybase.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bb927035e40cc4f4aa4b2229fef8cb7f94d15884 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polybase.pyi @@ -0,0 +1,220 @@ +from collections.abc import Sequence +from decimal import Decimal +from fractions import Fraction +from typing import Any, LiteralString, TypeAlias, TypeVar, assert_type +from typing import Literal as L + +import numpy as np +import numpy.polynomial as npp +import numpy.typing as npt + +_Ar_x: TypeAlias = npt.NDArray[np.inexact | np.object_] +_Ar_f: TypeAlias = npt.NDArray[np.floating] +_Ar_c: TypeAlias = npt.NDArray[np.complexfloating] +_Ar_O: TypeAlias = npt.NDArray[np.object_] + +_Ar_x_n: TypeAlias = np.ndarray[tuple[int], np.dtype[np.inexact | np.object_]] +_Ar_f_n: TypeAlias = np.ndarray[tuple[int], np.dtype[np.floating]] +_Ar_c_n: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complexfloating]] +_Ar_O_n: TypeAlias = np.ndarray[tuple[int], np.dtype[np.object_]] + +_Ar_x_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.inexact | np.object_]] +_Ar_f_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.floating]] +_Ar_c_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.complexfloating]] +_Ar_O_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.object_]] + +_ScalarT = TypeVar("_ScalarT", bound=np.generic) +_Ar_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]] + +_BasisName: TypeAlias = L["X"] + +SC_i: np.int_ +SC_i_co: int | np.int_ +SC_f: np.float64 +SC_f_co: float | np.float64 | np.int_ +SC_c: np.complex128 +SC_c_co: complex | np.complex128 +SC_O: Decimal + +AR_i: npt.NDArray[np.int_] +AR_f: npt.NDArray[np.float64] +AR_f_co: npt.NDArray[np.float64] | npt.NDArray[np.int_] +AR_c: npt.NDArray[np.complex128] +AR_c_co: npt.NDArray[np.complex128] | npt.NDArray[np.float64] | npt.NDArray[np.int_] +AR_O: npt.NDArray[np.object_] +AR_O_co: npt.NDArray[np.object_ | np.number] + +SQ_i: Sequence[int] +SQ_f: Sequence[float] +SQ_c: Sequence[complex] +SQ_O: Sequence[Decimal] + +PS_poly: npp.Polynomial +PS_cheb: npp.Chebyshev +PS_herm: npp.Hermite +PS_herme: npp.HermiteE +PS_lag: npp.Laguerre +PS_leg: npp.Legendre +PS_all: ( + npp.Polynomial + | npp.Chebyshev + | npp.Hermite + | npp.HermiteE + | npp.Laguerre + | npp.Legendre +) + +# static- and classmethods + +assert_type(type(PS_poly).basis_name, None) 
+assert_type(type(PS_cheb).basis_name, L['T']) +assert_type(type(PS_herm).basis_name, L['H']) +assert_type(type(PS_herme).basis_name, L['He']) +assert_type(type(PS_lag).basis_name, L['L']) +assert_type(type(PS_leg).basis_name, L['P']) + +assert_type(type(PS_all).__hash__, None) +assert_type(type(PS_all).__array_ufunc__, None) +assert_type(type(PS_all).maxpower, L[100]) + +assert_type(type(PS_poly).fromroots(SC_i), npp.Polynomial) +assert_type(type(PS_poly).fromroots(SQ_i), npp.Polynomial) +assert_type(type(PS_poly).fromroots(AR_i), npp.Polynomial) +assert_type(type(PS_cheb).fromroots(SC_f), npp.Chebyshev) +assert_type(type(PS_cheb).fromroots(SQ_f), npp.Chebyshev) +assert_type(type(PS_cheb).fromroots(AR_f_co), npp.Chebyshev) +assert_type(type(PS_herm).fromroots(SC_c), npp.Hermite) +assert_type(type(PS_herm).fromroots(SQ_c), npp.Hermite) +assert_type(type(PS_herm).fromroots(AR_c_co), npp.Hermite) +assert_type(type(PS_leg).fromroots(SC_O), npp.Legendre) +assert_type(type(PS_leg).fromroots(SQ_O), npp.Legendre) +assert_type(type(PS_leg).fromroots(AR_O_co), npp.Legendre) + +assert_type(type(PS_poly).identity(), npp.Polynomial) +assert_type(type(PS_cheb).identity(symbol='z'), npp.Chebyshev) + +assert_type(type(PS_lag).basis(SC_i), npp.Laguerre) +assert_type(type(PS_leg).basis(32, symbol='u'), npp.Legendre) + +assert_type(type(PS_herm).cast(PS_poly), npp.Hermite) +assert_type(type(PS_herme).cast(PS_leg), npp.HermiteE) + +# attributes / properties + +assert_type(PS_all.coef, _Ar_x_n) +assert_type(PS_all.domain, _Ar_x_2) +assert_type(PS_all.window, _Ar_x_2) +assert_type(PS_all.symbol, LiteralString) + +# instance methods + +assert_type(PS_all.has_samecoef(PS_all), bool) +assert_type(PS_all.has_samedomain(PS_all), bool) +assert_type(PS_all.has_samewindow(PS_all), bool) +assert_type(PS_all.has_sametype(PS_all), bool) +assert_type(PS_poly.has_sametype(PS_poly), bool) +assert_type(PS_poly.has_sametype(PS_leg), bool) +assert_type(PS_poly.has_sametype(NotADirectoryError), L[False]) + +assert_type(PS_poly.copy(), npp.Polynomial) +assert_type(PS_cheb.copy(), npp.Chebyshev) +assert_type(PS_herm.copy(), npp.Hermite) +assert_type(PS_herme.copy(), npp.HermiteE) +assert_type(PS_lag.copy(), npp.Laguerre) +assert_type(PS_leg.copy(), npp.Legendre) + +assert_type(PS_leg.cutdeg(), npp.Legendre) +assert_type(PS_leg.trim(), npp.Legendre) +assert_type(PS_leg.trim(tol=SC_f_co), npp.Legendre) +assert_type(PS_leg.truncate(SC_i_co), npp.Legendre) + +assert_type(PS_all.convert(None, npp.Chebyshev), npp.Chebyshev) +assert_type(PS_all.convert((0, 1), npp.Laguerre), npp.Laguerre) +assert_type(PS_all.convert([0, 1], npp.Hermite, [-1, 1]), npp.Hermite) + +assert_type(PS_all.degree(), int) +assert_type(PS_all.mapparms(), tuple[Any, Any]) + +assert_type(PS_poly.integ(), npp.Polynomial) +assert_type(PS_herme.integ(SC_i_co), npp.HermiteE) +assert_type(PS_lag.integ(SC_i_co, SC_f_co), npp.Laguerre) +assert_type(PS_poly.deriv(), npp.Polynomial) +assert_type(PS_herm.deriv(SC_i_co), npp.Hermite) + +assert_type(PS_poly.roots(), _Ar_x_n) + +assert_type( + PS_poly.linspace(), + tuple[_Ar_1d[np.float64 | np.complex128], _Ar_1d[np.float64 | np.complex128]], +) + +assert_type( + PS_poly.linspace(9), + tuple[_Ar_1d[np.float64 | np.complex128], _Ar_1d[np.float64 | np.complex128]], +) + +assert_type(PS_cheb.fit(AR_c_co, AR_c_co, SC_i_co), npp.Chebyshev) +assert_type(PS_leg.fit(AR_c_co, AR_c_co, AR_i), npp.Legendre) +assert_type(PS_herm.fit(AR_c_co, AR_c_co, SQ_i), npp.Hermite) +assert_type(PS_poly.fit(AR_c_co, SQ_c, SQ_i), npp.Polynomial) 
+assert_type(PS_lag.fit(SQ_c, SQ_c, SQ_i, full=False), npp.Laguerre) +assert_type( + PS_herme.fit(SQ_c, AR_c_co, SC_i_co, full=True), + tuple[npp.HermiteE, Sequence[np.inexact | np.int32]], +) + +# custom operations + +assert_type(PS_all.__hash__, None) +assert_type(PS_all.__array_ufunc__, None) + +assert_type(str(PS_all), str) +assert_type(repr(PS_all), str) +assert_type(format(PS_all), str) + +assert_type(len(PS_all), int) +assert_type(next(iter(PS_all)), np.inexact | object) + +assert_type(PS_all(SC_f_co), np.float64 | np.complex128) +assert_type(PS_all(SC_c_co), np.complex128) +assert_type(PS_all(Decimal()), np.float64 | np.complex128) +assert_type(PS_all(Fraction()), np.float64 | np.complex128) +assert_type(PS_poly(SQ_f), npt.NDArray[np.float64] | npt.NDArray[np.complex128] | npt.NDArray[np.object_]) +assert_type(PS_poly(SQ_c), npt.NDArray[np.complex128] | npt.NDArray[np.object_]) +assert_type(PS_poly(SQ_O), npt.NDArray[np.object_]) +assert_type(PS_poly(AR_f), npt.NDArray[np.float64] | npt.NDArray[np.complex128] | npt.NDArray[np.object_]) +assert_type(PS_poly(AR_c), npt.NDArray[np.complex128] | npt.NDArray[np.object_]) +assert_type(PS_poly(AR_O), npt.NDArray[np.object_]) +assert_type(PS_all(PS_poly), npp.Polynomial) + +assert_type(PS_poly == PS_poly, bool) +assert_type(PS_poly != PS_poly, bool) + +assert_type(-PS_poly, npp.Polynomial) +assert_type(+PS_poly, npp.Polynomial) + +assert_type(PS_poly + 5, npp.Polynomial) +assert_type(PS_poly - 5, npp.Polynomial) +assert_type(PS_poly * 5, npp.Polynomial) +assert_type(PS_poly / 5, npp.Polynomial) +assert_type(PS_poly // 5, npp.Polynomial) +assert_type(PS_poly % 5, npp.Polynomial) + +assert_type(PS_poly + PS_leg, npp.Polynomial) +assert_type(PS_poly - PS_leg, npp.Polynomial) +assert_type(PS_poly * PS_leg, npp.Polynomial) +assert_type(PS_poly / PS_leg, npp.Polynomial) +assert_type(PS_poly // PS_leg, npp.Polynomial) +assert_type(PS_poly % PS_leg, npp.Polynomial) + +assert_type(5 + PS_poly, npp.Polynomial) +assert_type(5 - PS_poly, npp.Polynomial) +assert_type(5 * PS_poly, npp.Polynomial) +assert_type(5 / PS_poly, npp.Polynomial) +assert_type(5 // PS_poly, npp.Polynomial) +assert_type(5 % PS_poly, npp.Polynomial) +assert_type(divmod(PS_poly, 5), tuple[npp.Polynomial, npp.Polynomial]) +assert_type(divmod(5, PS_poly), tuple[npp.Polynomial, npp.Polynomial]) + +assert_type(PS_poly**1, npp.Polynomial) +assert_type(PS_poly**1.0, npp.Polynomial) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polyutils.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polyutils.pyi new file mode 100644 index 0000000000000000000000000000000000000000..45522e72102f0c0b6a48cf35e2dc4ae6e9719c84 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_polyutils.pyi @@ -0,0 +1,219 @@ +from collections.abc import Sequence +from decimal import Decimal +from fractions import Fraction +from typing import Any, TypeAlias, assert_type +from typing import Literal as L + +import numpy as np +import numpy.polynomial.polyutils as pu +import numpy.typing as npt +from numpy.polynomial._polytypes import _Tuple2 + +_ArrFloat1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.floating]] +_ArrComplex1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complexfloating]] +_ArrObject1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.object_]] + +_ArrFloat1D_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.float64]] +_ArrComplex1D_2: TypeAlias = np.ndarray[tuple[L[2]], 
np.dtype[np.complex128]] +_ArrObject1D_2: TypeAlias = np.ndarray[tuple[L[2]], np.dtype[np.object_]] + +num_int: int +num_float: float +num_complex: complex +# will result in an `object_` dtype +num_object: Decimal | Fraction + +sct_int: np.int_ +sct_float: np.float64 +sct_complex: np.complex128 +sct_object: np.object_ # doesn't exist at runtime + +arr_int: npt.NDArray[np.int_] +arr_float: npt.NDArray[np.float64] +arr_complex: npt.NDArray[np.complex128] +arr_object: npt.NDArray[np.object_] + +seq_num_int: Sequence[int] +seq_num_float: Sequence[float] +seq_num_complex: Sequence[complex] +seq_num_object: Sequence[Decimal | Fraction] + +seq_sct_int: Sequence[np.int_] +seq_sct_float: Sequence[np.float64] +seq_sct_complex: Sequence[np.complex128] +seq_sct_object: Sequence[np.object_] + +seq_arr_int: Sequence[npt.NDArray[np.int_]] +seq_arr_float: Sequence[npt.NDArray[np.float64]] +seq_arr_complex: Sequence[npt.NDArray[np.complex128]] +seq_arr_object: Sequence[npt.NDArray[np.object_]] + +seq_seq_num_int: Sequence[Sequence[int]] +seq_seq_num_float: Sequence[Sequence[float]] +seq_seq_num_complex: Sequence[Sequence[complex]] +seq_seq_num_object: Sequence[Sequence[Decimal | Fraction]] + +seq_seq_sct_int: Sequence[Sequence[np.int_]] +seq_seq_sct_float: Sequence[Sequence[np.float64]] +seq_seq_sct_complex: Sequence[Sequence[np.complex128]] +seq_seq_sct_object: Sequence[Sequence[np.object_]] # doesn't exist at runtime + +# as_series + +assert_type(pu.as_series(arr_int), list[_ArrFloat1D]) +assert_type(pu.as_series(arr_float), list[_ArrFloat1D]) +assert_type(pu.as_series(arr_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(arr_object), list[_ArrObject1D]) + +assert_type(pu.as_series(seq_num_int), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_num_float), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_num_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(seq_num_object), list[_ArrObject1D]) + +assert_type(pu.as_series(seq_sct_int), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_sct_float), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_sct_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(seq_sct_object), list[_ArrObject1D]) + +assert_type(pu.as_series(seq_arr_int), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_arr_float), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_arr_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(seq_arr_object), list[_ArrObject1D]) + +assert_type(pu.as_series(seq_seq_num_int), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_seq_num_float), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_seq_num_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(seq_seq_num_object), list[_ArrObject1D]) + +assert_type(pu.as_series(seq_seq_sct_int), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_seq_sct_float), list[_ArrFloat1D]) +assert_type(pu.as_series(seq_seq_sct_complex), list[_ArrComplex1D]) +assert_type(pu.as_series(seq_seq_sct_object), list[_ArrObject1D]) + +# trimcoef + +assert_type(pu.trimcoef(num_int), _ArrFloat1D) +assert_type(pu.trimcoef(num_float), _ArrFloat1D) +assert_type(pu.trimcoef(num_complex), _ArrComplex1D) +assert_type(pu.trimcoef(num_object), _ArrObject1D) + +assert_type(pu.trimcoef(sct_int), _ArrFloat1D) +assert_type(pu.trimcoef(sct_float), _ArrFloat1D) +assert_type(pu.trimcoef(sct_complex), _ArrComplex1D) +assert_type(pu.trimcoef(sct_object), _ArrObject1D) + +assert_type(pu.trimcoef(arr_int), _ArrFloat1D) +assert_type(pu.trimcoef(arr_float), _ArrFloat1D)
+assert_type(pu.trimcoef(arr_complex), _ArrComplex1D) +assert_type(pu.trimcoef(arr_object), _ArrObject1D) + +assert_type(pu.trimcoef(seq_num_int), _ArrFloat1D) +assert_type(pu.trimcoef(seq_num_float), _ArrFloat1D) +assert_type(pu.trimcoef(seq_num_complex), _ArrComplex1D) +assert_type(pu.trimcoef(seq_num_object), _ArrObject1D) + +assert_type(pu.trimcoef(seq_sct_int), _ArrFloat1D) +assert_type(pu.trimcoef(seq_sct_float), _ArrFloat1D) +assert_type(pu.trimcoef(seq_sct_complex), _ArrComplex1D) +assert_type(pu.trimcoef(seq_sct_object), _ArrObject1D) + +# getdomain + +assert_type(pu.getdomain(num_int), _ArrFloat1D_2) +assert_type(pu.getdomain(num_float), _ArrFloat1D_2) +assert_type(pu.getdomain(num_complex), _ArrComplex1D_2) +assert_type(pu.getdomain(num_object), _ArrObject1D_2) + +assert_type(pu.getdomain(sct_int), _ArrFloat1D_2) +assert_type(pu.getdomain(sct_float), _ArrFloat1D_2) +assert_type(pu.getdomain(sct_complex), _ArrComplex1D_2) +assert_type(pu.getdomain(sct_object), _ArrObject1D_2) + +assert_type(pu.getdomain(arr_int), _ArrFloat1D_2) +assert_type(pu.getdomain(arr_float), _ArrFloat1D_2) +assert_type(pu.getdomain(arr_complex), _ArrComplex1D_2) +assert_type(pu.getdomain(arr_object), _ArrObject1D_2) + +assert_type(pu.getdomain(seq_num_int), _ArrFloat1D_2) +assert_type(pu.getdomain(seq_num_float), _ArrFloat1D_2) +assert_type(pu.getdomain(seq_num_complex), _ArrComplex1D_2) +assert_type(pu.getdomain(seq_num_object), _ArrObject1D_2) + +assert_type(pu.getdomain(seq_sct_int), _ArrFloat1D_2) +assert_type(pu.getdomain(seq_sct_float), _ArrFloat1D_2) +assert_type(pu.getdomain(seq_sct_complex), _ArrComplex1D_2) +assert_type(pu.getdomain(seq_sct_object), _ArrObject1D_2) + +# mapparms + +assert_type(pu.mapparms(seq_num_int, seq_num_int), _Tuple2[float]) +assert_type(pu.mapparms(seq_num_int, seq_num_float), _Tuple2[float]) +assert_type(pu.mapparms(seq_num_float, seq_num_float), _Tuple2[float]) +assert_type(pu.mapparms(seq_num_float, seq_num_complex), _Tuple2[complex]) +assert_type(pu.mapparms(seq_num_complex, seq_num_complex), _Tuple2[complex]) +assert_type(pu.mapparms(seq_num_complex, seq_num_object), _Tuple2[object]) +assert_type(pu.mapparms(seq_num_object, seq_num_object), _Tuple2[object]) + +assert_type(pu.mapparms(seq_sct_int, seq_sct_int), _Tuple2[np.floating]) +assert_type(pu.mapparms(seq_sct_int, seq_sct_float), _Tuple2[np.floating]) +assert_type(pu.mapparms(seq_sct_float, seq_sct_float), _Tuple2[float]) +assert_type(pu.mapparms(seq_sct_float, seq_sct_complex), _Tuple2[complex]) +assert_type(pu.mapparms(seq_sct_complex, seq_sct_complex), _Tuple2[complex]) +assert_type(pu.mapparms(seq_sct_complex, seq_sct_object), _Tuple2[object]) +assert_type(pu.mapparms(seq_sct_object, seq_sct_object), _Tuple2[object]) + +assert_type(pu.mapparms(arr_int, arr_int), _Tuple2[np.floating]) +assert_type(pu.mapparms(arr_int, arr_float), _Tuple2[np.floating]) +assert_type(pu.mapparms(arr_float, arr_float), _Tuple2[np.floating]) +assert_type(pu.mapparms(arr_float, arr_complex), _Tuple2[np.complexfloating]) +assert_type(pu.mapparms(arr_complex, arr_complex), _Tuple2[np.complexfloating]) +assert_type(pu.mapparms(arr_complex, arr_object), _Tuple2[object]) +assert_type(pu.mapparms(arr_object, arr_object), _Tuple2[object]) + +# mapdomain + +assert_type(pu.mapdomain(num_int, seq_num_int, seq_num_int), np.floating) +assert_type(pu.mapdomain(num_int, seq_num_int, seq_num_float), np.floating) +assert_type(pu.mapdomain(num_int, seq_num_float, seq_num_float), np.floating)
+assert_type(pu.mapdomain(num_float, seq_num_float, seq_num_float), np.floating) +assert_type(pu.mapdomain(num_float, seq_num_float, seq_num_complex), np.complexfloating) +assert_type(pu.mapdomain(num_float, seq_num_complex, seq_num_complex), np.complexfloating) +assert_type(pu.mapdomain(num_complex, seq_num_complex, seq_num_complex), np.complexfloating) +assert_type(pu.mapdomain(num_complex, seq_num_complex, seq_num_object), object) +assert_type(pu.mapdomain(num_complex, seq_num_object, seq_num_object), object) +assert_type(pu.mapdomain(num_object, seq_num_object, seq_num_object), object) + +assert_type(pu.mapdomain(seq_num_int, seq_num_int, seq_num_int), _ArrFloat1D) +assert_type(pu.mapdomain(seq_num_int, seq_num_int, seq_num_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_num_int, seq_num_float, seq_num_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_num_float, seq_num_float, seq_num_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_num_float, seq_num_float, seq_num_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_num_float, seq_num_complex, seq_num_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_num_complex, seq_num_complex, seq_num_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_num_complex, seq_num_complex, seq_num_object), _ArrObject1D) +assert_type(pu.mapdomain(seq_num_complex, seq_num_object, seq_num_object), _ArrObject1D) +assert_type(pu.mapdomain(seq_num_object, seq_num_object, seq_num_object), _ArrObject1D) + +assert_type(pu.mapdomain(seq_sct_int, seq_sct_int, seq_sct_int), _ArrFloat1D) +assert_type(pu.mapdomain(seq_sct_int, seq_sct_int, seq_sct_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_sct_int, seq_sct_float, seq_sct_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_sct_float, seq_sct_float, seq_sct_float), _ArrFloat1D) +assert_type(pu.mapdomain(seq_sct_float, seq_sct_float, seq_sct_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_sct_float, seq_sct_complex, seq_sct_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_sct_complex, seq_sct_complex, seq_sct_complex), _ArrComplex1D) +assert_type(pu.mapdomain(seq_sct_complex, seq_sct_complex, seq_sct_object), _ArrObject1D) +assert_type(pu.mapdomain(seq_sct_complex, seq_sct_object, seq_sct_object), _ArrObject1D) +assert_type(pu.mapdomain(seq_sct_object, seq_sct_object, seq_sct_object), _ArrObject1D) + +assert_type(pu.mapdomain(arr_int, arr_int, arr_int), _ArrFloat1D) +assert_type(pu.mapdomain(arr_int, arr_int, arr_float), _ArrFloat1D) +assert_type(pu.mapdomain(arr_int, arr_float, arr_float), _ArrFloat1D) +assert_type(pu.mapdomain(arr_float, arr_float, arr_float), _ArrFloat1D) +assert_type(pu.mapdomain(arr_float, arr_float, arr_complex), _ArrComplex1D) +assert_type(pu.mapdomain(arr_float, arr_complex, arr_complex), _ArrComplex1D) +assert_type(pu.mapdomain(arr_complex, arr_complex, arr_complex), _ArrComplex1D) +assert_type(pu.mapdomain(arr_complex, arr_complex, arr_object), _ArrObject1D) +assert_type(pu.mapdomain(arr_complex, arr_object, arr_object), _ArrObject1D) +assert_type(pu.mapdomain(arr_object, arr_object, arr_object), _ArrObject1D) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_series.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_series.pyi new file mode 100644 index 0000000000000000000000000000000000000000..93f0799c818d983b1130cc152c931e587579fdf6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/polynomial_series.pyi @@ -0,0 +1,138 @@ +from collections.abc import 
Sequence +from typing import Any, TypeAlias, assert_type + +import numpy as np +import numpy.polynomial as npp +import numpy.typing as npt + +_ArrFloat1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.floating]] +_ArrFloat1D64: TypeAlias = np.ndarray[tuple[int], np.dtype[np.float64]] +_ArrComplex1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complexfloating]] +_ArrComplex1D128: TypeAlias = np.ndarray[tuple[int], np.dtype[np.complex128]] +_ArrObject1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.object_]] + +AR_b: npt.NDArray[np.bool] +AR_u4: npt.NDArray[np.uint32] +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] +AR_c16: npt.NDArray[np.complex128] +AR_O: npt.NDArray[np.object_] + +PS_poly: npp.Polynomial +PS_cheb: npp.Chebyshev + +assert_type(npp.polynomial.polyroots(AR_f8), _ArrFloat1D64) +assert_type(npp.polynomial.polyroots(AR_c16), _ArrComplex1D128) +assert_type(npp.polynomial.polyroots(AR_O), _ArrObject1D) + +assert_type(npp.polynomial.polyfromroots(AR_f8), _ArrFloat1D) +assert_type(npp.polynomial.polyfromroots(AR_c16), _ArrComplex1D) +assert_type(npp.polynomial.polyfromroots(AR_O), _ArrObject1D) + +# assert_type(npp.polynomial.polyadd(AR_b, AR_b), NoReturn) +assert_type(npp.polynomial.polyadd(AR_u4, AR_b), _ArrFloat1D) +assert_type(npp.polynomial.polyadd(AR_i8, AR_i8), _ArrFloat1D) +assert_type(npp.polynomial.polyadd(AR_f8, AR_i8), _ArrFloat1D) +assert_type(npp.polynomial.polyadd(AR_i8, AR_c16), _ArrComplex1D) +assert_type(npp.polynomial.polyadd(AR_O, AR_O), _ArrObject1D) + +assert_type(npp.polynomial.polymulx(AR_u4), _ArrFloat1D) +assert_type(npp.polynomial.polymulx(AR_i8), _ArrFloat1D) +assert_type(npp.polynomial.polymulx(AR_f8), _ArrFloat1D) +assert_type(npp.polynomial.polymulx(AR_c16), _ArrComplex1D) +assert_type(npp.polynomial.polymulx(AR_O), _ArrObject1D) + +assert_type(npp.polynomial.polypow(AR_u4, 2), _ArrFloat1D) +assert_type(npp.polynomial.polypow(AR_i8, 2), _ArrFloat1D) +assert_type(npp.polynomial.polypow(AR_f8, 2), _ArrFloat1D) +assert_type(npp.polynomial.polypow(AR_c16, 2), _ArrComplex1D) +assert_type(npp.polynomial.polypow(AR_O, 2), _ArrObject1D) + +# assert_type(npp.polynomial.polyder(PS_poly), npt.NDArray[np.object_]) +assert_type(npp.polynomial.polyder(AR_f8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyder(AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyder(AR_O, m=2), npt.NDArray[np.object_]) + +# assert_type(npp.polynomial.polyint(PS_poly), npt.NDArray[np.object_]) +assert_type(npp.polynomial.polyint(AR_f8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyint(AR_f8, k=AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyint(AR_O, m=2), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyval(AR_b, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval(AR_u4, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval(AR_i8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval(AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyval(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyval2d(AR_b, AR_b, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval2d(AR_u4, AR_u4, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval2d(AR_i8, AR_i8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval2d(AR_f8, AR_f8, AR_i8), npt.NDArray[np.floating]) 
+assert_type(npp.polynomial.polyval2d(AR_i8, AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyval2d(AR_O, AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyval3d(AR_b, AR_b, AR_b, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval3d(AR_u4, AR_u4, AR_u4, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval3d(AR_i8, AR_i8, AR_i8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval3d(AR_f8, AR_f8, AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyval3d(AR_i8, AR_i8, AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyval3d(AR_O, AR_O, AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyvalfromroots(AR_b, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvalfromroots(AR_u4, AR_b), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvalfromroots(AR_i8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvalfromroots(AR_f8, AR_i8), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvalfromroots(AR_i8, AR_c16), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyvalfromroots(AR_O, AR_O), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyvander(AR_f8, 3), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvander(AR_c16, 3), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyvander(AR_O, 3), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyvander2d(AR_f8, AR_f8, [4, 2]), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvander2d(AR_c16, AR_c16, [4, 2]), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyvander2d(AR_O, AR_O, [4, 2]), npt.NDArray[np.object_]) + +assert_type(npp.polynomial.polyvander3d(AR_f8, AR_f8, AR_f8, [4, 3, 2]), npt.NDArray[np.floating]) +assert_type(npp.polynomial.polyvander3d(AR_c16, AR_c16, AR_c16, [4, 3, 2]), npt.NDArray[np.complexfloating]) +assert_type(npp.polynomial.polyvander3d(AR_O, AR_O, AR_O, [4, 3, 2]), npt.NDArray[np.object_]) + +assert_type( + npp.polynomial.polyfit(AR_f8, AR_f8, 2), + npt.NDArray[np.floating], +) +assert_type( + npp.polynomial.polyfit(AR_f8, AR_i8, 1, full=True), + tuple[npt.NDArray[np.floating], Sequence[np.inexact | np.int32]], +) +assert_type( + npp.polynomial.polyfit(AR_c16, AR_f8, 2), + npt.NDArray[np.complexfloating], +) +assert_type( + npp.polynomial.polyfit(AR_f8, AR_c16, 1, full=True)[0], + npt.NDArray[np.complexfloating], +) + +assert_type(npp.chebyshev.chebgauss(2), tuple[_ArrFloat1D64, _ArrFloat1D64]) + +assert_type(npp.chebyshev.chebweight(AR_f8), npt.NDArray[np.float64]) +assert_type(npp.chebyshev.chebweight(AR_c16), npt.NDArray[np.complex128]) +assert_type(npp.chebyshev.chebweight(AR_O), npt.NDArray[np.object_]) + +assert_type(npp.chebyshev.poly2cheb(AR_f8), _ArrFloat1D) +assert_type(npp.chebyshev.poly2cheb(AR_c16), _ArrComplex1D) +assert_type(npp.chebyshev.poly2cheb(AR_O), _ArrObject1D) + +assert_type(npp.chebyshev.cheb2poly(AR_f8), _ArrFloat1D) +assert_type(npp.chebyshev.cheb2poly(AR_c16), _ArrComplex1D) +assert_type(npp.chebyshev.cheb2poly(AR_O), _ArrObject1D) + +assert_type(npp.chebyshev.chebpts1(6), _ArrFloat1D64) +assert_type(npp.chebyshev.chebpts2(6), _ArrFloat1D64) + +assert_type( + npp.chebyshev.chebinterpolate(np.tanh, 3), + npt.NDArray[np.float64 | np.complex128 | np.object_], +) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/random.pyi 
b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/random.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e188eb02893f06faf35b1b5265c83dc24f1d84c9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/random.pyi @@ -0,0 +1,1546 @@ +import threading +from collections.abc import Sequence +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt +from numpy.random._generator import Generator +from numpy.random._mt19937 import MT19937 +from numpy.random._pcg64 import PCG64 +from numpy.random._philox import Philox +from numpy.random._sfc64 import SFC64 +from numpy.random.bit_generator import SeedlessSeedSequence, SeedSequence + +def_rng = np.random.default_rng() +seed_seq = np.random.SeedSequence() +mt19937 = np.random.MT19937() +pcg64 = np.random.PCG64() +sfc64 = np.random.SFC64() +philox = np.random.Philox() +seedless_seq = SeedlessSeedSequence() + +assert_type(def_rng, Generator) +assert_type(mt19937, MT19937) +assert_type(pcg64, PCG64) +assert_type(sfc64, SFC64) +assert_type(philox, Philox) +assert_type(seed_seq, SeedSequence) +assert_type(seedless_seq, SeedlessSeedSequence) + +mt19937_jumped = mt19937.jumped() +mt19937_jumped3 = mt19937.jumped(3) +mt19937_raw = mt19937.random_raw() +mt19937_raw_arr = mt19937.random_raw(5) + +assert_type(mt19937_jumped, MT19937) +assert_type(mt19937_jumped3, MT19937) +assert_type(mt19937_raw, int) +assert_type(mt19937_raw_arr, npt.NDArray[np.uint64]) +assert_type(mt19937.lock, threading.Lock) + +pcg64_jumped = pcg64.jumped() +pcg64_jumped3 = pcg64.jumped(3) +pcg64_adv = pcg64.advance(3) +pcg64_raw = pcg64.random_raw() +pcg64_raw_arr = pcg64.random_raw(5) + +assert_type(pcg64_jumped, PCG64) +assert_type(pcg64_jumped3, PCG64) +assert_type(pcg64_adv, PCG64) +assert_type(pcg64_raw, int) +assert_type(pcg64_raw_arr, npt.NDArray[np.uint64]) +assert_type(pcg64.lock, threading.Lock) + +philox_jumped = philox.jumped() +philox_jumped3 = philox.jumped(3) +philox_adv = philox.advance(3) +philox_raw = philox.random_raw() +philox_raw_arr = philox.random_raw(5) + +assert_type(philox_jumped, Philox) +assert_type(philox_jumped3, Philox) +assert_type(philox_adv, Philox) +assert_type(philox_raw, int) +assert_type(philox_raw_arr, npt.NDArray[np.uint64]) +assert_type(philox.lock, threading.Lock) + +sfc64_raw = sfc64.random_raw() +sfc64_raw_arr = sfc64.random_raw(5) + +assert_type(sfc64_raw, int) +assert_type(sfc64_raw_arr, npt.NDArray[np.uint64]) +assert_type(sfc64.lock, threading.Lock) + +assert_type(seed_seq.pool, npt.NDArray[np.uint32]) +assert_type(seed_seq.entropy, int | Sequence[int] | None) +assert_type(seed_seq.spawn(1), list[np.random.SeedSequence]) +assert_type(seed_seq.generate_state(8, "uint32"), npt.NDArray[np.uint32 | np.uint64]) +assert_type(seed_seq.generate_state(8, "uint64"), npt.NDArray[np.uint32 | np.uint64]) + +def_gen: np.random.Generator = np.random.default_rng() + +D_arr_0p1: npt.NDArray[np.float64] = np.array([0.1]) +D_arr_0p5: npt.NDArray[np.float64] = np.array([0.5]) +D_arr_0p9: npt.NDArray[np.float64] = np.array([0.9]) +D_arr_1p5: npt.NDArray[np.float64] = np.array([1.5]) +I_arr_10: npt.NDArray[np.int_] = np.array([10], dtype=np.int_) +I_arr_20: npt.NDArray[np.int_] = np.array([20], dtype=np.int_) +D_arr_like_0p1: list[float] = [0.1] +D_arr_like_0p5: list[float] = [0.5] +D_arr_like_0p9: list[float] = [0.9] +D_arr_like_1p5: list[float] = [1.5] +I_arr_like_10: list[int] = [10] +I_arr_like_20: list[int] = [20] +D_2D_like: list[list[float]] = [[1, 2], [2, 
3], [3, 4], [4, 5.1]] +D_2D: npt.NDArray[np.float64] = np.array(D_2D_like) +S_out: npt.NDArray[np.float32] = np.empty(1, dtype=np.float32) +D_out: npt.NDArray[np.float64] = np.empty(1) + +assert_type(def_gen.standard_normal(), float) +assert_type(def_gen.standard_normal(dtype=np.float32), float) +assert_type(def_gen.standard_normal(dtype="float32"), float) +assert_type(def_gen.standard_normal(dtype="double"), float) +assert_type(def_gen.standard_normal(dtype=np.float64), float) +assert_type(def_gen.standard_normal(size=None), float) +assert_type(def_gen.standard_normal(size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_normal(size=1, dtype=np.float32), npt.NDArray[np.float32]) +assert_type(def_gen.standard_normal(size=1, dtype="f4"), npt.NDArray[np.float32]) +assert_type(def_gen.standard_normal(size=1, dtype="float32", out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_normal(dtype=np.float32, out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_normal(size=1, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(def_gen.standard_normal(size=1, dtype="float64"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_normal(size=1, dtype="f8"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_normal(out=D_out), npt.NDArray[np.float64]) +assert_type(def_gen.standard_normal(size=1, dtype="float64", out=D_out), npt.NDArray[np.float64]) + +assert_type(def_gen.random(), float) +assert_type(def_gen.random(dtype=np.float32), float) +assert_type(def_gen.random(dtype="float32"), float) +assert_type(def_gen.random(dtype="double"), float) +assert_type(def_gen.random(dtype=np.float64), float) +assert_type(def_gen.random(size=None), float) +assert_type(def_gen.random(size=1), npt.NDArray[np.float64]) +assert_type(def_gen.random(size=1, dtype=np.float32), npt.NDArray[np.float32]) +assert_type(def_gen.random(size=1, dtype="f4"), npt.NDArray[np.float32]) +assert_type(def_gen.random(size=1, dtype="float32", out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.random(dtype=np.float32, out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.random(size=1, dtype=np.float64), npt.NDArray[np.float64]) +assert_type(def_gen.random(size=1, dtype="float64"), npt.NDArray[np.float64]) +assert_type(def_gen.random(size=1, dtype="f8"), npt.NDArray[np.float64]) +assert_type(def_gen.random(out=D_out), npt.NDArray[np.float64]) +assert_type(def_gen.random(size=1, dtype="float64", out=D_out), npt.NDArray[np.float64]) + +assert_type(def_gen.standard_cauchy(), float) +assert_type(def_gen.standard_cauchy(size=None), float) +assert_type(def_gen.standard_cauchy(size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.standard_exponential(), float) +assert_type(def_gen.standard_exponential(method="inv"), float) +assert_type(def_gen.standard_exponential(dtype=np.float32), float) +assert_type(def_gen.standard_exponential(dtype="float32"), float) +assert_type(def_gen.standard_exponential(dtype="double"), float) +assert_type(def_gen.standard_exponential(dtype=np.float64), float) +assert_type(def_gen.standard_exponential(size=None), float) +assert_type(def_gen.standard_exponential(size=None, method="inv"), float) +assert_type(def_gen.standard_exponential(size=1, method="inv"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_exponential(size=1, dtype=np.float32), npt.NDArray[np.float32]) 
+assert_type(def_gen.standard_exponential(size=1, dtype="f4", method="inv"), npt.NDArray[np.float32]) +assert_type(def_gen.standard_exponential(size=1, dtype="float32", out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_exponential(dtype=np.float32, out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_exponential(size=1, dtype=np.float64, method="inv"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_exponential(size=1, dtype="float64"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_exponential(size=1, dtype="f8"), npt.NDArray[np.float64]) +assert_type(def_gen.standard_exponential(out=D_out), npt.NDArray[np.float64]) +assert_type(def_gen.standard_exponential(size=1, dtype="float64", out=D_out), npt.NDArray[np.float64]) + +assert_type(def_gen.zipf(1.5), int) +assert_type(def_gen.zipf(1.5, size=None), int) +assert_type(def_gen.zipf(1.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.zipf(D_arr_1p5), npt.NDArray[np.int64]) +assert_type(def_gen.zipf(D_arr_1p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.zipf(D_arr_like_1p5), npt.NDArray[np.int64]) +assert_type(def_gen.zipf(D_arr_like_1p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.weibull(0.5), float) +assert_type(def_gen.weibull(0.5, size=None), float) +assert_type(def_gen.weibull(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.weibull(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.weibull(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.weibull(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.weibull(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.standard_t(0.5), float) +assert_type(def_gen.standard_t(0.5, size=None), float) +assert_type(def_gen.standard_t(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_t(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.standard_t(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_t(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.standard_t(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.poisson(0.5), int) +assert_type(def_gen.poisson(0.5, size=None), int) +assert_type(def_gen.poisson(0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.poisson(D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.poisson(D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.poisson(D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.poisson(D_arr_like_0p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.power(0.5), float) +assert_type(def_gen.power(0.5, size=None), float) +assert_type(def_gen.power(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.power(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.power(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.power(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.power(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.pareto(0.5), float) +assert_type(def_gen.pareto(0.5, size=None), float) +assert_type(def_gen.pareto(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.pareto(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.pareto(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.pareto(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.pareto(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + 
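+# NOTE: the distribution methods checked below all follow the overload +# pattern exercised above: scalar parameters with size=None reveal a builtin +# float (or int for the discrete distributions), while an array-like +# parameter or an explicit size reveals an NDArray of the matching dtype. +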
+assert_type(def_gen.chisquare(0.5), float) +assert_type(def_gen.chisquare(0.5, size=None), float) +assert_type(def_gen.chisquare(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.chisquare(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.chisquare(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.chisquare(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.chisquare(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.exponential(0.5), float) +assert_type(def_gen.exponential(0.5, size=None), float) +assert_type(def_gen.exponential(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.exponential(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.exponential(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.exponential(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.exponential(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.geometric(0.5), int) +assert_type(def_gen.geometric(0.5, size=None), int) +assert_type(def_gen.geometric(0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.geometric(D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.geometric(D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.geometric(D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.geometric(D_arr_like_0p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.logseries(0.5), int) +assert_type(def_gen.logseries(0.5, size=None), int) +assert_type(def_gen.logseries(0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.logseries(D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.logseries(D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.logseries(D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.logseries(D_arr_like_0p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.rayleigh(0.5), float) +assert_type(def_gen.rayleigh(0.5, size=None), float) +assert_type(def_gen.rayleigh(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.rayleigh(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.rayleigh(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.rayleigh(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.rayleigh(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.standard_gamma(0.5), float) +assert_type(def_gen.standard_gamma(0.5, size=None), float) +assert_type(def_gen.standard_gamma(0.5, dtype="float32"), float) +assert_type(def_gen.standard_gamma(0.5, size=None, dtype="float32"), float) +assert_type(def_gen.standard_gamma(0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(D_arr_0p5, dtype="f4"), npt.NDArray[np.float32]) +assert_type(def_gen.standard_gamma(0.5, size=1, dtype="float32", out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_gamma(D_arr_0p5, dtype=np.float32, out=S_out), npt.NDArray[np.float32]) +assert_type(def_gen.standard_gamma(D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(D_arr_like_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(0.5, out=D_out), npt.NDArray[np.float64]) +assert_type(def_gen.standard_gamma(D_arr_like_0p5, out=D_out), npt.NDArray[np.float64]) 
+assert_type(def_gen.standard_gamma(D_arr_like_0p5, size=1, out=D_out, dtype=np.float64), npt.NDArray[np.float64]) + +assert_type(def_gen.vonmises(0.5, 0.5), float) +assert_type(def_gen.vonmises(0.5, 0.5, size=None), float) +assert_type(def_gen.vonmises(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.vonmises(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.wald(0.5, 0.5), float) +assert_type(def_gen.wald(0.5, 0.5, size=None), float) +assert_type(def_gen.wald(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.wald(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.wald(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.uniform(0.5, 0.5), float) +assert_type(def_gen.uniform(0.5, 0.5, size=None), float) +assert_type(def_gen.uniform(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.uniform(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.beta(0.5, 0.5), float) +assert_type(def_gen.beta(0.5, 0.5, size=None), float) +assert_type(def_gen.beta(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.beta(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.beta(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) 
+assert_type(def_gen.beta(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.beta(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.f(0.5, 0.5), float) +assert_type(def_gen.f(0.5, 0.5, size=None), float) +assert_type(def_gen.f(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.f(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.f(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.f(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.f(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.gamma(0.5, 0.5), float) +assert_type(def_gen.gamma(0.5, 0.5, size=None), float) +assert_type(def_gen.gamma(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gamma(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.gumbel(0.5, 0.5), float) +assert_type(def_gen.gumbel(0.5, 0.5, size=None), float) +assert_type(def_gen.gumbel(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.gumbel(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.laplace(0.5, 0.5), float) +assert_type(def_gen.laplace(0.5, 0.5, size=None), float) +assert_type(def_gen.laplace(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(0.5, D_arr_0p5), npt.NDArray[np.float64]) 
+assert_type(def_gen.laplace(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.laplace(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.logistic(0.5, 0.5), float) +assert_type(def_gen.logistic(0.5, 0.5, size=None), float) +assert_type(def_gen.logistic(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.logistic(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.lognormal(0.5, 0.5), float) +assert_type(def_gen.lognormal(0.5, 0.5, size=None), float) +assert_type(def_gen.lognormal(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.lognormal(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.noncentral_chisquare(0.5, 0.5), float) +assert_type(def_gen.noncentral_chisquare(0.5, 0.5, size=None), float) +assert_type(def_gen.noncentral_chisquare(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_like_0p5, 
D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.normal(0.5, 0.5), float) +assert_type(def_gen.normal(0.5, 0.5, size=None), float) +assert_type(def_gen.normal(0.5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(0.5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.normal(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_like_0p5, 0.5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(0.5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.normal(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.triangular(0.1, 0.5, 0.9), float) +assert_type(def_gen.triangular(0.1, 0.5, 0.9, size=None), float) +assert_type(def_gen.triangular(0.1, 0.5, 0.9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_0p1, 0.5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(0.1, D_arr_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_0p1, 0.5, D_arr_like_0p9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(0.1, D_arr_0p5, 0.9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_like_0p1, 0.5, D_arr_0p9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(0.5, D_arr_like_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_0p1, D_arr_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_like_0p1, D_arr_like_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.triangular(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.noncentral_f(0.1, 0.5, 0.9), float) +assert_type(def_gen.noncentral_f(0.1, 0.5, 0.9, size=None), float) +assert_type(def_gen.noncentral_f(0.1, 0.5, 0.9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_0p1, 0.5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(0.1, D_arr_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_0p1, 0.5, D_arr_like_0p9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(0.1, D_arr_0p5, 0.9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_like_0p1, 0.5, D_arr_0p9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(0.5, D_arr_like_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_0p1, D_arr_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, 0.9), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1), npt.NDArray[np.float64]) +assert_type(def_gen.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1), npt.NDArray[np.float64]) + +assert_type(def_gen.binomial(10, 0.5), int) +assert_type(def_gen.binomial(10, 0.5, 
size=None), int) +assert_type(def_gen.binomial(10, 0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_10, 0.5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(10, D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_10, 0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(10, D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_like_10, 0.5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(10, D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_10, D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_like_10, D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_10, D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.binomial(I_arr_like_10, D_arr_like_0p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.negative_binomial(10, 0.5), int) +assert_type(def_gen.negative_binomial(10, 0.5, size=None), int) +assert_type(def_gen.negative_binomial(10, 0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_10, 0.5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(10, D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_10, 0.5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(10, D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_like_10, 0.5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(10, D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_10, D_arr_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_like_10, D_arr_like_0p5), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_10, D_arr_0p5, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.negative_binomial(I_arr_like_10, D_arr_like_0p5, size=1), npt.NDArray[np.int64]) + +assert_type(def_gen.hypergeometric(20, 20, 10), int) +assert_type(def_gen.hypergeometric(20, 20, 10, size=None), int) +assert_type(def_gen.hypergeometric(20, 20, 10, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_20, 20, 10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(20, I_arr_20, 10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_20, 20, I_arr_like_10, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(20, I_arr_20, 10, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_like_20, 20, I_arr_10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(20, I_arr_like_20, 10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_20, I_arr_20, 10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_like_20, I_arr_like_20, 10), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_20, I_arr_20, I_arr_10, size=1), npt.NDArray[np.int64]) +assert_type(def_gen.hypergeometric(I_arr_like_20, I_arr_like_20, I_arr_like_10, size=1), npt.NDArray[np.int64]) + +I_int64_100: npt.NDArray[np.int64] = np.array([100], dtype=np.int64) + +assert_type(def_gen.integers(0, 100), np.int64) +assert_type(def_gen.integers(100), np.int64) +assert_type(def_gen.integers([100]), npt.NDArray[np.int64]) +assert_type(def_gen.integers(0, [100]), npt.NDArray[np.int64]) + +I_bool_low: npt.NDArray[np.bool] = np.array([0], dtype=np.bool) +I_bool_low_like: list[int] = [0] +I_bool_high_open: npt.NDArray[np.bool] = np.array([1], dtype=np.bool) +I_bool_high_closed: npt.NDArray[np.bool] = 
np.array([1], dtype=np.bool) + +assert_type(def_gen.integers(2, dtype=bool), bool) +assert_type(def_gen.integers(0, 2, dtype=bool), bool) +assert_type(def_gen.integers(1, dtype=bool, endpoint=True), bool) +assert_type(def_gen.integers(0, 1, dtype=bool, endpoint=True), bool) +assert_type(def_gen.integers(I_bool_low_like, 1, dtype=bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_high_open, dtype=bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_low, I_bool_high_open, dtype=bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(0, I_bool_high_open, dtype=bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_high_closed, dtype=bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_low, I_bool_high_closed, dtype=bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(0, I_bool_high_closed, dtype=bool, endpoint=True), npt.NDArray[np.bool]) + +assert_type(def_gen.integers(2, dtype=np.bool), np.bool) +assert_type(def_gen.integers(0, 2, dtype=np.bool), np.bool) +assert_type(def_gen.integers(1, dtype=np.bool, endpoint=True), np.bool) +assert_type(def_gen.integers(0, 1, dtype=np.bool, endpoint=True), np.bool) +assert_type(def_gen.integers(I_bool_low_like, 1, dtype=np.bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_low, I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(0, I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_high_closed, dtype=np.bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(I_bool_low, I_bool_high_closed, dtype=np.bool, endpoint=True), npt.NDArray[np.bool]) +assert_type(def_gen.integers(0, I_bool_high_closed, dtype=np.bool, endpoint=True), npt.NDArray[np.bool]) + +I_u1_low: npt.NDArray[np.uint8] = np.array([0], dtype=np.uint8) +I_u1_low_like: list[int] = [0] +I_u1_high_open: npt.NDArray[np.uint8] = np.array([255], dtype=np.uint8) +I_u1_high_closed: npt.NDArray[np.uint8] = np.array([255], dtype=np.uint8) + +assert_type(def_gen.integers(256, dtype="u1"), np.uint8) +assert_type(def_gen.integers(0, 256, dtype="u1"), np.uint8) +assert_type(def_gen.integers(255, dtype="u1", endpoint=True), np.uint8) +assert_type(def_gen.integers(0, 255, dtype="u1", endpoint=True), np.uint8) +assert_type(def_gen.integers(I_u1_low_like, 255, dtype="u1", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_closed, dtype="u1", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_closed, dtype="u1", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_closed, dtype="u1", endpoint=True), npt.NDArray[np.uint8]) + +assert_type(def_gen.integers(256, dtype="uint8"), np.uint8) +assert_type(def_gen.integers(0, 256, dtype="uint8"), np.uint8) +assert_type(def_gen.integers(255, dtype="uint8", endpoint=True), np.uint8) +assert_type(def_gen.integers(0, 255, dtype="uint8", endpoint=True), np.uint8) +assert_type(def_gen.integers(I_u1_low_like, 255, dtype="uint8", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_open, 
dtype="uint8"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_open, dtype="uint8"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_open, dtype="uint8"), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_closed, dtype="uint8", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_closed, dtype="uint8", endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_closed, dtype="uint8", endpoint=True), npt.NDArray[np.uint8]) + +assert_type(def_gen.integers(256, dtype=np.uint8), np.uint8) +assert_type(def_gen.integers(0, 256, dtype=np.uint8), np.uint8) +assert_type(def_gen.integers(255, dtype=np.uint8, endpoint=True), np.uint8) +assert_type(def_gen.integers(0, 255, dtype=np.uint8, endpoint=True), np.uint8) +assert_type(def_gen.integers(I_u1_low_like, 255, dtype=np.uint8, endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_high_closed, dtype=np.uint8, endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(I_u1_low, I_u1_high_closed, dtype=np.uint8, endpoint=True), npt.NDArray[np.uint8]) +assert_type(def_gen.integers(0, I_u1_high_closed, dtype=np.uint8, endpoint=True), npt.NDArray[np.uint8]) + +I_u2_low: npt.NDArray[np.uint16] = np.array([0], dtype=np.uint16) +I_u2_low_like: list[int] = [0] +I_u2_high_open: npt.NDArray[np.uint16] = np.array([65535], dtype=np.uint16) +I_u2_high_closed: npt.NDArray[np.uint16] = np.array([65535], dtype=np.uint16) + +assert_type(def_gen.integers(65536, dtype="u2"), np.uint16) +assert_type(def_gen.integers(0, 65536, dtype="u2"), np.uint16) +assert_type(def_gen.integers(65535, dtype="u2", endpoint=True), np.uint16) +assert_type(def_gen.integers(0, 65535, dtype="u2", endpoint=True), np.uint16) +assert_type(def_gen.integers(I_u2_low_like, 65535, dtype="u2", endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_low, I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_closed, dtype="u2", endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_low, I_u2_high_closed, dtype="u2", endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_closed, dtype="u2", endpoint=True), npt.NDArray[np.uint16]) + +assert_type(def_gen.integers(65536, dtype="uint16"), np.uint16) +assert_type(def_gen.integers(0, 65536, dtype="uint16"), np.uint16) +assert_type(def_gen.integers(65535, dtype="uint16", endpoint=True), np.uint16) +assert_type(def_gen.integers(0, 65535, dtype="uint16", endpoint=True), np.uint16) +assert_type(def_gen.integers(I_u2_low_like, 65535, dtype="uint16", endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_low, I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_closed, dtype="uint16", endpoint=True), npt.NDArray[np.uint16]) 
+assert_type(def_gen.integers(I_u2_low, I_u2_high_closed, dtype="uint16", endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_closed, dtype="uint16", endpoint=True), npt.NDArray[np.uint16]) + +assert_type(def_gen.integers(65536, dtype=np.uint16), np.uint16) +assert_type(def_gen.integers(0, 65536, dtype=np.uint16), np.uint16) +assert_type(def_gen.integers(65535, dtype=np.uint16, endpoint=True), np.uint16) +assert_type(def_gen.integers(0, 65535, dtype=np.uint16, endpoint=True), np.uint16) +assert_type(def_gen.integers(I_u2_low_like, 65535, dtype=np.uint16, endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_low, I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_high_closed, dtype=np.uint16, endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(I_u2_low, I_u2_high_closed, dtype=np.uint16, endpoint=True), npt.NDArray[np.uint16]) +assert_type(def_gen.integers(0, I_u2_high_closed, dtype=np.uint16, endpoint=True), npt.NDArray[np.uint16]) + +I_u4_low: npt.NDArray[np.uint32] = np.array([0], dtype=np.uint32) +I_u4_low_like: list[int] = [0] +I_u4_high_open: npt.NDArray[np.uint32] = np.array([4294967295], dtype=np.uint32) +I_u4_high_closed: npt.NDArray[np.uint32] = np.array([4294967295], dtype=np.uint32) + +assert_type(def_gen.integers(4294967296, dtype=np.int_), np.int_) +assert_type(def_gen.integers(0, 4294967296, dtype=np.int_), np.int_) +assert_type(def_gen.integers(4294967295, dtype=np.int_, endpoint=True), np.int_) +assert_type(def_gen.integers(0, 4294967295, dtype=np.int_, endpoint=True), np.int_) +assert_type(def_gen.integers(I_u4_low_like, 4294967295, dtype=np.int_, endpoint=True), npt.NDArray[np.int_]) +assert_type(def_gen.integers(I_u4_high_open, dtype=np.int_), npt.NDArray[np.int_]) +assert_type(def_gen.integers(I_u4_low, I_u4_high_open, dtype=np.int_), npt.NDArray[np.int_]) +assert_type(def_gen.integers(0, I_u4_high_open, dtype=np.int_), npt.NDArray[np.int_]) +assert_type(def_gen.integers(I_u4_high_closed, dtype=np.int_, endpoint=True), npt.NDArray[np.int_]) +assert_type(def_gen.integers(I_u4_low, I_u4_high_closed, dtype=np.int_, endpoint=True), npt.NDArray[np.int_]) +assert_type(def_gen.integers(0, I_u4_high_closed, dtype=np.int_, endpoint=True), npt.NDArray[np.int_]) + +assert_type(def_gen.integers(4294967296, dtype="u4"), np.uint32) +assert_type(def_gen.integers(0, 4294967296, dtype="u4"), np.uint32) +assert_type(def_gen.integers(4294967295, dtype="u4", endpoint=True), np.uint32) +assert_type(def_gen.integers(0, 4294967295, dtype="u4", endpoint=True), np.uint32) +assert_type(def_gen.integers(I_u4_low_like, 4294967295, dtype="u4", endpoint=True), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(I_u4_low, I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(0, I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(I_u4_high_closed, dtype="u4", endpoint=True), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(I_u4_low, I_u4_high_closed, dtype="u4", endpoint=True), npt.NDArray[np.uint32]) +assert_type(def_gen.integers(0, I_u4_high_closed, dtype="u4", endpoint=True), npt.NDArray[np.uint32]) + +assert_type(def_gen.integers(4294967296, dtype="uint32"), 
np.uint32)
+assert_type(def_gen.integers(0, 4294967296, dtype="uint32"), np.uint32)
+assert_type(def_gen.integers(4294967295, dtype="uint32", endpoint=True), np.uint32)
+assert_type(def_gen.integers(0, 4294967295, dtype="uint32", endpoint=True), np.uint32)
+assert_type(def_gen.integers(I_u4_low_like, 4294967295, dtype="uint32", endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(0, I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_high_closed, dtype="uint32", endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_closed, dtype="uint32", endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(0, I_u4_high_closed, dtype="uint32", endpoint=True), npt.NDArray[np.uint32])
+
+assert_type(def_gen.integers(4294967296, dtype=np.uint32), np.uint32)
+assert_type(def_gen.integers(0, 4294967296, dtype=np.uint32), np.uint32)
+assert_type(def_gen.integers(4294967295, dtype=np.uint32, endpoint=True), np.uint32)
+assert_type(def_gen.integers(0, 4294967295, dtype=np.uint32, endpoint=True), np.uint32)
+assert_type(def_gen.integers(I_u4_low_like, 4294967295, dtype=np.uint32, endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(0, I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_high_closed, dtype=np.uint32, endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_closed, dtype=np.uint32, endpoint=True), npt.NDArray[np.uint32])
+assert_type(def_gen.integers(0, I_u4_high_closed, dtype=np.uint32, endpoint=True), npt.NDArray[np.uint32])
+
+assert_type(def_gen.integers(4294967296, dtype=np.uint), np.uint)
+assert_type(def_gen.integers(0, 4294967296, dtype=np.uint), np.uint)
+assert_type(def_gen.integers(4294967295, dtype=np.uint, endpoint=True), np.uint)
+assert_type(def_gen.integers(0, 4294967295, dtype=np.uint, endpoint=True), np.uint)
+assert_type(def_gen.integers(I_u4_low_like, 4294967295, dtype=np.uint, endpoint=True), npt.NDArray[np.uint])
+assert_type(def_gen.integers(I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+assert_type(def_gen.integers(0, I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+assert_type(def_gen.integers(I_u4_high_closed, dtype=np.uint, endpoint=True), npt.NDArray[np.uint])
+assert_type(def_gen.integers(I_u4_low, I_u4_high_closed, dtype=np.uint, endpoint=True), npt.NDArray[np.uint])
+assert_type(def_gen.integers(0, I_u4_high_closed, dtype=np.uint, endpoint=True), npt.NDArray[np.uint])
+
+I_u8_low: npt.NDArray[np.uint64] = np.array([0], dtype=np.uint64)
+I_u8_low_like: list[int] = [0]
+I_u8_high_open: npt.NDArray[np.uint64] = np.array([18446744073709551615], dtype=np.uint64)
+I_u8_high_closed: npt.NDArray[np.uint64] = np.array([18446744073709551615], dtype=np.uint64)
+
+assert_type(def_gen.integers(18446744073709551616, dtype="u8"), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551616, dtype="u8"), np.uint64)
+assert_type(def_gen.integers(18446744073709551615, dtype="u8", endpoint=True), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551615, dtype="u8", endpoint=True), np.uint64)
+assert_type(def_gen.integers(I_u8_low_like, 18446744073709551615, dtype="u8", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_closed, dtype="u8", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_closed, dtype="u8", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_closed, dtype="u8", endpoint=True), npt.NDArray[np.uint64])
+
+assert_type(def_gen.integers(18446744073709551616, dtype="uint64"), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551616, dtype="uint64"), np.uint64)
+assert_type(def_gen.integers(18446744073709551615, dtype="uint64", endpoint=True), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551615, dtype="uint64", endpoint=True), np.uint64)
+assert_type(def_gen.integers(I_u8_low_like, 18446744073709551615, dtype="uint64", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_closed, dtype="uint64", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_closed, dtype="uint64", endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_closed, dtype="uint64", endpoint=True), npt.NDArray[np.uint64])
+
+assert_type(def_gen.integers(18446744073709551616, dtype=np.uint64), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551616, dtype=np.uint64), np.uint64)
+assert_type(def_gen.integers(18446744073709551615, dtype=np.uint64, endpoint=True), np.uint64)
+assert_type(def_gen.integers(0, 18446744073709551615, dtype=np.uint64, endpoint=True), np.uint64)
+assert_type(def_gen.integers(I_u8_low_like, 18446744073709551615, dtype=np.uint64, endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_high_closed, dtype=np.uint64, endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(I_u8_low, I_u8_high_closed, dtype=np.uint64, endpoint=True), npt.NDArray[np.uint64])
+assert_type(def_gen.integers(0, I_u8_high_closed, dtype=np.uint64, endpoint=True), npt.NDArray[np.uint64])
+
+I_i1_low: npt.NDArray[np.int8] = np.array([-128], dtype=np.int8)
+I_i1_low_like: list[int] = [-128]
+I_i1_high_open: npt.NDArray[np.int8] = np.array([127], dtype=np.int8)
+I_i1_high_closed: npt.NDArray[np.int8] = np.array([127], dtype=np.int8)
+
+assert_type(def_gen.integers(128, dtype="i1"), np.int8)
+assert_type(def_gen.integers(-128, 128, dtype="i1"), np.int8)
+assert_type(def_gen.integers(127, dtype="i1", endpoint=True), np.int8)
+assert_type(def_gen.integers(-128, 127, dtype="i1", endpoint=True), np.int8)
+assert_type(def_gen.integers(I_i1_low_like, 127, dtype="i1", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_closed, dtype="i1", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_closed, dtype="i1", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_closed, dtype="i1", endpoint=True), npt.NDArray[np.int8])
+
+assert_type(def_gen.integers(128, dtype="int8"), np.int8)
+assert_type(def_gen.integers(-128, 128, dtype="int8"), np.int8)
+assert_type(def_gen.integers(127, dtype="int8", endpoint=True), np.int8)
+assert_type(def_gen.integers(-128, 127, dtype="int8", endpoint=True), np.int8)
+assert_type(def_gen.integers(I_i1_low_like, 127, dtype="int8", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_closed, dtype="int8", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_closed, dtype="int8", endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_closed, dtype="int8", endpoint=True), npt.NDArray[np.int8])
+
+assert_type(def_gen.integers(128, dtype=np.int8), np.int8)
+assert_type(def_gen.integers(-128, 128, dtype=np.int8), np.int8)
+assert_type(def_gen.integers(127, dtype=np.int8, endpoint=True), np.int8)
+assert_type(def_gen.integers(-128, 127, dtype=np.int8, endpoint=True), np.int8)
+assert_type(def_gen.integers(I_i1_low_like, 127, dtype=np.int8, endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_high_closed, dtype=np.int8, endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(I_i1_low, I_i1_high_closed, dtype=np.int8, endpoint=True), npt.NDArray[np.int8])
+assert_type(def_gen.integers(-128, I_i1_high_closed, dtype=np.int8, endpoint=True), npt.NDArray[np.int8])
+
+I_i2_low: npt.NDArray[np.int16] = np.array([-32768], dtype=np.int16)
+I_i2_low_like: list[int] = [-32768]
+I_i2_high_open: npt.NDArray[np.int16] = np.array([32767], dtype=np.int16)
+I_i2_high_closed: npt.NDArray[np.int16] = np.array([32767], dtype=np.int16)
+
+assert_type(def_gen.integers(32768, dtype="i2"), np.int16)
+assert_type(def_gen.integers(-32768, 32768, dtype="i2"), np.int16)
+assert_type(def_gen.integers(32767, dtype="i2", endpoint=True), np.int16)
+assert_type(def_gen.integers(-32768, 32767, dtype="i2", endpoint=True), np.int16)
+assert_type(def_gen.integers(I_i2_low_like, 32767, dtype="i2", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_closed, dtype="i2", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_closed, dtype="i2", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_closed, dtype="i2", endpoint=True), npt.NDArray[np.int16])
+
+assert_type(def_gen.integers(32768, dtype="int16"), np.int16)
+assert_type(def_gen.integers(-32768, 32768, dtype="int16"), np.int16)
+assert_type(def_gen.integers(32767, dtype="int16", endpoint=True), np.int16)
+assert_type(def_gen.integers(-32768, 32767, dtype="int16", endpoint=True), np.int16)
+assert_type(def_gen.integers(I_i2_low_like, 32767, dtype="int16", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_closed, dtype="int16", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_closed, dtype="int16", endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_closed, dtype="int16", endpoint=True), npt.NDArray[np.int16])
+
+assert_type(def_gen.integers(32768, dtype=np.int16), np.int16)
+assert_type(def_gen.integers(-32768, 32768, dtype=np.int16), np.int16)
+assert_type(def_gen.integers(32767, dtype=np.int16, endpoint=True), np.int16)
+assert_type(def_gen.integers(-32768, 32767, dtype=np.int16, endpoint=True), np.int16)
+assert_type(def_gen.integers(I_i2_low_like, 32767, dtype=np.int16, endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_high_closed, dtype=np.int16, endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(I_i2_low, I_i2_high_closed, dtype=np.int16, endpoint=True), npt.NDArray[np.int16])
+assert_type(def_gen.integers(-32768, I_i2_high_closed, dtype=np.int16, endpoint=True), npt.NDArray[np.int16])
+
+I_i4_low: npt.NDArray[np.int32] = np.array([-2147483648], dtype=np.int32)
+I_i4_low_like: list[int] = [-2147483648]
+I_i4_high_open: npt.NDArray[np.int32] = np.array([2147483647], dtype=np.int32)
+I_i4_high_closed: npt.NDArray[np.int32] = np.array([2147483647], dtype=np.int32)
+
+assert_type(def_gen.integers(2147483648, dtype="i4"), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483648, dtype="i4"), np.int32)
+assert_type(def_gen.integers(2147483647, dtype="i4", endpoint=True), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483647, dtype="i4", endpoint=True), np.int32)
+assert_type(def_gen.integers(I_i4_low_like, 2147483647, dtype="i4", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_closed, dtype="i4", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_closed, dtype="i4", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_closed, dtype="i4", endpoint=True), npt.NDArray[np.int32])
+
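+# Note: 2147483648 == 2**31 is the exclusive upper bound of int32, while
+# 2147483647 == 2**31 - 1 is the inclusive bound used with endpoint=True;
+# the same assertions repeat below for the "int32" alias and np.int32.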
+assert_type(def_gen.integers(2147483648, dtype="int32"), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483648, dtype="int32"), np.int32)
+assert_type(def_gen.integers(2147483647, dtype="int32", endpoint=True), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483647, dtype="int32", endpoint=True), np.int32)
+assert_type(def_gen.integers(I_i4_low_like, 2147483647, dtype="int32", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_closed, dtype="int32", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_closed, dtype="int32", endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_closed, dtype="int32", endpoint=True), npt.NDArray[np.int32])
+
+assert_type(def_gen.integers(2147483648, dtype=np.int32), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483648, dtype=np.int32), np.int32)
+assert_type(def_gen.integers(2147483647, dtype=np.int32, endpoint=True), np.int32)
+assert_type(def_gen.integers(-2147483648, 2147483647, dtype=np.int32, endpoint=True), np.int32)
+assert_type(def_gen.integers(I_i4_low_like, 2147483647, dtype=np.int32, endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_high_closed, dtype=np.int32, endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(I_i4_low, I_i4_high_closed, dtype=np.int32, endpoint=True), npt.NDArray[np.int32])
+assert_type(def_gen.integers(-2147483648, I_i4_high_closed, dtype=np.int32, endpoint=True), npt.NDArray[np.int32])
+
+I_i8_low: npt.NDArray[np.int64] = np.array([-9223372036854775808], dtype=np.int64)
+I_i8_low_like: list[int] = [-9223372036854775808]
+I_i8_high_open: npt.NDArray[np.int64] = np.array([9223372036854775807], dtype=np.int64)
+I_i8_high_closed: npt.NDArray[np.int64] = np.array([9223372036854775807], dtype=np.int64)
+
+assert_type(def_gen.integers(9223372036854775808, dtype="i8"), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775808, dtype="i8"), np.int64)
+assert_type(def_gen.integers(9223372036854775807, dtype="i8", endpoint=True), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775807, dtype="i8", endpoint=True), np.int64)
+assert_type(def_gen.integers(I_i8_low_like, 9223372036854775807, dtype="i8", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_closed, dtype="i8", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_closed, dtype="i8", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype="i8", endpoint=True), npt.NDArray[np.int64])
+
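+# Note: 9223372036854775808 == 2**63 is the exclusive upper bound of int64;
+# with endpoint=True the inclusive bound 2**63 - 1 is asserted instead.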
+assert_type(def_gen.integers(9223372036854775808, dtype="int64"), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775808, dtype="int64"), np.int64)
+assert_type(def_gen.integers(9223372036854775807, dtype="int64", endpoint=True), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775807, dtype="int64", endpoint=True), np.int64)
+assert_type(def_gen.integers(I_i8_low_like, 9223372036854775807, dtype="int64", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_closed, dtype="int64", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_closed, dtype="int64", endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype="int64", endpoint=True), npt.NDArray[np.int64])
+
+assert_type(def_gen.integers(9223372036854775808, dtype=np.int64), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775808, dtype=np.int64), np.int64)
+assert_type(def_gen.integers(9223372036854775807, dtype=np.int64, endpoint=True), np.int64)
+assert_type(def_gen.integers(-9223372036854775808, 9223372036854775807, dtype=np.int64, endpoint=True), np.int64)
+assert_type(def_gen.integers(I_i8_low_like, 9223372036854775807, dtype=np.int64, endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_high_closed, dtype=np.int64, endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(I_i8_low, I_i8_high_closed, dtype=np.int64, endpoint=True), npt.NDArray[np.int64])
+assert_type(def_gen.integers(-9223372036854775808, I_i8_high_closed, dtype=np.int64, endpoint=True), npt.NDArray[np.int64])
+
+assert_type(def_gen.bit_generator, np.random.BitGenerator)
+
+assert_type(def_gen.bytes(2), bytes)
+
+assert_type(def_gen.choice(5), int)
+assert_type(def_gen.choice(5, 3), npt.NDArray[np.int64])
+assert_type(def_gen.choice(5, 3, replace=True), npt.NDArray[np.int64])
+assert_type(def_gen.choice(5, 3, p=[1 / 5] * 5), npt.NDArray[np.int64])
+assert_type(def_gen.choice(5, 3, p=[1 / 5] * 5, replace=False), npt.NDArray[np.int64])
+
+assert_type(def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"]), Any)
+assert_type(def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3), npt.NDArray[Any])
+assert_type(def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, p=[1 / 4] * 4), npt.NDArray[Any])
+assert_type(def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=True), npt.NDArray[Any])
+assert_type(def_gen.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=False, p=np.array([1 / 8, 1 / 8, 1 / 2, 1 / 4])), npt.NDArray[Any])
+
+assert_type(def_gen.dirichlet([0.5, 0.5]), npt.NDArray[np.float64])
+assert_type(def_gen.dirichlet(np.array([0.5, 0.5])), npt.NDArray[np.float64])
+assert_type(def_gen.dirichlet(np.array([0.5, 0.5]), size=3), npt.NDArray[np.float64])
+
+assert_type(def_gen.multinomial(20, [1 / 6.0] * 6), npt.NDArray[np.int64])
+assert_type(def_gen.multinomial(20, np.array([0.5, 0.5])), npt.NDArray[np.int64])
+assert_type(def_gen.multinomial(20, [1 / 6.0] * 6, size=2), npt.NDArray[np.int64])
+assert_type(def_gen.multinomial([[10], [20]], [1 / 6.0] * 6, size=(2, 2)), npt.NDArray[np.int64])
+assert_type(def_gen.multinomial(np.array([[10], [20]]), np.array([0.5, 0.5]), size=(2, 2)), npt.NDArray[np.int64])
+
+assert_type(def_gen.multivariate_hypergeometric([3, 5, 7], 2), npt.NDArray[np.int64])
+assert_type(def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2), npt.NDArray[np.int64])
+assert_type(def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, size=4), npt.NDArray[np.int64])
+assert_type(def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, size=(4, 7)), npt.NDArray[np.int64])
+assert_type(def_gen.multivariate_hypergeometric([3, 5, 7], 2, method="count"), npt.NDArray[np.int64])
+assert_type(def_gen.multivariate_hypergeometric(np.array([3, 5, 7]), 2, method="marginals"), npt.NDArray[np.int64])
+
+assert_type(def_gen.multivariate_normal([0.0], [[1.0]]), npt.NDArray[np.float64])
+assert_type(def_gen.multivariate_normal([0.0], np.array([[1.0]])), npt.NDArray[np.float64])
+assert_type(def_gen.multivariate_normal(np.array([0.0]), [[1.0]]), npt.NDArray[np.float64])
+assert_type(def_gen.multivariate_normal(np.array([0.0]), np.array([[1.0]])), npt.NDArray[np.float64])
+
+assert_type(def_gen.permutation(10), npt.NDArray[np.int64])
+assert_type(def_gen.permutation([1, 2, 3, 4]), npt.NDArray[Any])
+assert_type(def_gen.permutation(np.array([1, 2, 3, 4])), npt.NDArray[Any])
+assert_type(def_gen.permutation(D_2D, axis=1), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D_like), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D, axis=1), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D, out=D_2D), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D_like, out=D_2D), npt.NDArray[Any])
+assert_type(def_gen.permuted(D_2D, axis=1, out=D_2D), npt.NDArray[Any])
+
+assert_type(def_gen.shuffle(np.arange(10)), None)
+assert_type(def_gen.shuffle([1, 2, 3, 4, 5]), None)
+assert_type(def_gen.shuffle(D_2D, axis=1), None)
+
+assert_type(np.random.Generator(pcg64), np.random.Generator)
+assert_type(def_gen.__str__(), str)
+assert_type(def_gen.__repr__(), str)
+assert_type(def_gen.__setstate__(dict(def_gen.bit_generator.state)), None)
+
+# RandomState
+random_st: np.random.RandomState = np.random.RandomState()
+
+assert_type(random_st.standard_normal(), float)
+assert_type(random_st.standard_normal(size=None), float)
+assert_type(random_st.standard_normal(size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.random(), float)
+assert_type(random_st.random(size=None), float)
+assert_type(random_st.random(size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.standard_cauchy(), float)
+assert_type(random_st.standard_cauchy(size=None), float)
+assert_type(random_st.standard_cauchy(size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.standard_exponential(), float)
+assert_type(random_st.standard_exponential(size=None), float)
+assert_type(random_st.standard_exponential(size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.zipf(1.5), int)
+assert_type(random_st.zipf(1.5, size=None), int)
+assert_type(random_st.zipf(1.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.zipf(D_arr_1p5), npt.NDArray[np.long])
+assert_type(random_st.zipf(D_arr_1p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.zipf(D_arr_like_1p5), npt.NDArray[np.long])
+assert_type(random_st.zipf(D_arr_like_1p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.weibull(0.5), float)
+assert_type(random_st.weibull(0.5, size=None), float)
+assert_type(random_st.weibull(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.weibull(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.weibull(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.weibull(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.weibull(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.standard_t(0.5), float)
+assert_type(random_st.standard_t(0.5, size=None), float)
+assert_type(random_st.standard_t(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.standard_t(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.standard_t(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.standard_t(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.standard_t(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.poisson(0.5), int)
+assert_type(random_st.poisson(0.5, size=None), int)
+assert_type(random_st.poisson(0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.poisson(D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.poisson(D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.poisson(D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.poisson(D_arr_like_0p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.power(0.5), float)
+assert_type(random_st.power(0.5, size=None), float)
+assert_type(random_st.power(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.power(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.power(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.power(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.power(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.pareto(0.5), float)
+assert_type(random_st.pareto(0.5, size=None), float)
+assert_type(random_st.pareto(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.pareto(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.pareto(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.pareto(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.pareto(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.chisquare(0.5), float)
+assert_type(random_st.chisquare(0.5, size=None), float)
+assert_type(random_st.chisquare(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.chisquare(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.chisquare(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.chisquare(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.chisquare(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.exponential(0.5), float)
+assert_type(random_st.exponential(0.5, size=None), float)
+assert_type(random_st.exponential(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.exponential(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.exponential(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.exponential(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.exponential(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.geometric(0.5), int)
+assert_type(random_st.geometric(0.5, size=None), int)
+assert_type(random_st.geometric(0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.geometric(D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.geometric(D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.geometric(D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.geometric(D_arr_like_0p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.logseries(0.5), int)
+assert_type(random_st.logseries(0.5, size=None), int)
+assert_type(random_st.logseries(0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.logseries(D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.logseries(D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.logseries(D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.logseries(D_arr_like_0p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.rayleigh(0.5), float)
+assert_type(random_st.rayleigh(0.5, size=None), float)
+assert_type(random_st.rayleigh(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.rayleigh(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.rayleigh(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.rayleigh(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.rayleigh(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.standard_gamma(0.5), float)
+assert_type(random_st.standard_gamma(0.5, size=None), float)
+assert_type(random_st.standard_gamma(0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.standard_gamma(D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.standard_gamma(D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.standard_gamma(D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.standard_gamma(D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.vonmises(0.5, 0.5), float)
+assert_type(random_st.vonmises(0.5, 0.5, size=None), float)
+assert_type(random_st.vonmises(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.vonmises(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.wald(0.5, 0.5), float)
+assert_type(random_st.wald(0.5, 0.5, size=None), float)
+assert_type(random_st.wald(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.wald(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.wald(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.wald(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.wald(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.uniform(0.5, 0.5), float)
+assert_type(random_st.uniform(0.5, 0.5, size=None), float)
+assert_type(random_st.uniform(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.uniform(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.uniform(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.beta(0.5, 0.5), float)
+assert_type(random_st.beta(0.5, 0.5, size=None), float)
+assert_type(random_st.beta(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.beta(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.beta(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.beta(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.beta(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.f(0.5, 0.5), float)
+assert_type(random_st.f(0.5, 0.5, size=None), float)
+assert_type(random_st.f(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.f(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.f(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.f(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.f(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.gamma(0.5, 0.5), float)
+assert_type(random_st.gamma(0.5, 0.5, size=None), float)
+assert_type(random_st.gamma(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gamma(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gamma(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.gumbel(0.5, 0.5), float)
+assert_type(random_st.gumbel(0.5, 0.5, size=None), float)
+assert_type(random_st.gumbel(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.gumbel(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.laplace(0.5, 0.5), float)
+assert_type(random_st.laplace(0.5, 0.5, size=None), float)
+assert_type(random_st.laplace(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.laplace(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.laplace(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.logistic(0.5, 0.5), float)
+assert_type(random_st.logistic(0.5, 0.5, size=None), float)
+assert_type(random_st.logistic(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.logistic(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.logistic(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.lognormal(0.5, 0.5), float)
+assert_type(random_st.lognormal(0.5, 0.5, size=None), float)
+assert_type(random_st.lognormal(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.lognormal(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.noncentral_chisquare(0.5, 0.5), float)
+assert_type(random_st.noncentral_chisquare(0.5, 0.5, size=None), float)
+assert_type(random_st.noncentral_chisquare(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_chisquare(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.normal(0.5, 0.5), float)
+assert_type(random_st.normal(0.5, 0.5, size=None), float)
+assert_type(random_st.normal(0.5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.normal(0.5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_0p5, 0.5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.normal(0.5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_like_0p5, 0.5), npt.NDArray[np.float64])
+assert_type(random_st.normal(0.5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_0p5, D_arr_0p5), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_like_0p5, D_arr_like_0p5), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_0p5, D_arr_0p5, size=1), npt.NDArray[np.float64])
+assert_type(random_st.normal(D_arr_like_0p5, D_arr_like_0p5, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.triangular(0.1, 0.5, 0.9), float)
+assert_type(random_st.triangular(0.1, 0.5, 0.9, size=None), float)
+assert_type(random_st.triangular(0.1, 0.5, 0.9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_0p1, 0.5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(0.1, D_arr_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_0p1, 0.5, D_arr_like_0p9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.triangular(0.1, D_arr_0p5, 0.9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_like_0p1, 0.5, D_arr_0p9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(0.5, D_arr_like_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_0p1, D_arr_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_like_0p1, D_arr_like_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.triangular(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.noncentral_f(0.1, 0.5, 0.9), float)
+assert_type(random_st.noncentral_f(0.1, 0.5, 0.9, size=None), float)
+assert_type(random_st.noncentral_f(0.1, 0.5, 0.9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_0p1, 0.5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(0.1, D_arr_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_0p1, 0.5, D_arr_like_0p9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(0.1, D_arr_0p5, 0.9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_like_0p1, 0.5, D_arr_0p9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(0.5, D_arr_like_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_0p1, D_arr_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, 0.9), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_0p1, D_arr_0p5, D_arr_0p9, size=1), npt.NDArray[np.float64])
+assert_type(random_st.noncentral_f(D_arr_like_0p1, D_arr_like_0p5, D_arr_like_0p9, size=1), npt.NDArray[np.float64])
+
+assert_type(random_st.binomial(10, 0.5), int)
+assert_type(random_st.binomial(10, 0.5, size=None), int)
+assert_type(random_st.binomial(10, 0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_10, 0.5), npt.NDArray[np.long])
+assert_type(random_st.binomial(10, D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_10, 0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.binomial(10, D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_like_10, 0.5), npt.NDArray[np.long])
+assert_type(random_st.binomial(10, D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_10, D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_like_10, D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_10, D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.binomial(I_arr_like_10, D_arr_like_0p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.negative_binomial(10, 0.5), int)
+assert_type(random_st.negative_binomial(10, 0.5, size=None), int)
+assert_type(random_st.negative_binomial(10, 0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_10, 0.5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(10, D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_10, 0.5, size=1), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(10, D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_like_10, 0.5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(10, D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_10, D_arr_0p5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_like_10, D_arr_like_0p5), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_10, D_arr_0p5, size=1), npt.NDArray[np.long])
+assert_type(random_st.negative_binomial(I_arr_like_10, D_arr_like_0p5, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.hypergeometric(20, 20, 10), int)
+assert_type(random_st.hypergeometric(20, 20, 10, size=None), int)
+assert_type(random_st.hypergeometric(20, 20, 10, size=1), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_20, 20, 10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(20, I_arr_20, 10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_20, 20, I_arr_like_10, size=1), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(20, I_arr_20, 10, size=1), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_like_20, 20, I_arr_10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(20, I_arr_like_20, 10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_20, I_arr_20, 10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_like_20, I_arr_like_20, 10), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_20, I_arr_20, I_arr_10, size=1), npt.NDArray[np.long])
+assert_type(random_st.hypergeometric(I_arr_like_20, I_arr_like_20, I_arr_like_10, size=1), npt.NDArray[np.long])
+
+assert_type(random_st.randint(0, 100), int)
+assert_type(random_st.randint(100), int)
+assert_type(random_st.randint([100]), npt.NDArray[np.long])
+assert_type(random_st.randint(0, [100]), npt.NDArray[np.long])
+
+assert_type(random_st.randint(2, dtype=bool), bool)
+assert_type(random_st.randint(0, 2, dtype=bool), bool)
+assert_type(random_st.randint(I_bool_high_open, dtype=bool), npt.NDArray[np.bool])
+assert_type(random_st.randint(I_bool_low, I_bool_high_open, dtype=bool), npt.NDArray[np.bool])
+assert_type(random_st.randint(0, I_bool_high_open, dtype=bool), npt.NDArray[np.bool])
+
+assert_type(random_st.randint(2, dtype=np.bool), np.bool)
+assert_type(random_st.randint(0, 2, dtype=np.bool), np.bool)
+assert_type(random_st.randint(I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool])
+assert_type(random_st.randint(I_bool_low, I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool])
+assert_type(random_st.randint(0, I_bool_high_open, dtype=np.bool), npt.NDArray[np.bool])
+
+assert_type(random_st.randint(256, dtype="u1"), np.uint8)
+assert_type(random_st.randint(0, 256, dtype="u1"), np.uint8)
+assert_type(random_st.randint(I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8])
+assert_type(random_st.randint(I_u1_low, I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8])
+assert_type(random_st.randint(0, I_u1_high_open, dtype="u1"), npt.NDArray[np.uint8])
+
+assert_type(random_st.randint(256, dtype="uint8"), np.uint8)
+assert_type(random_st.randint(0, 256, dtype="uint8"), np.uint8)
+assert_type(random_st.randint(I_u1_high_open, dtype="uint8"), npt.NDArray[np.uint8])
+assert_type(random_st.randint(I_u1_low, I_u1_high_open, dtype="uint8"), npt.NDArray[np.uint8])
+assert_type(random_st.randint(0, I_u1_high_open, dtype="uint8"), npt.NDArray[np.uint8])
+
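+# Unlike Generator.integers, the legacy RandomState.randint has no endpoint
+# parameter, so only half-open [low, high) ranges are exercised below.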
+assert_type(random_st.randint(256, dtype=np.uint8), np.uint8)
+assert_type(random_st.randint(0, 256, dtype=np.uint8), np.uint8)
+assert_type(random_st.randint(I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8])
+assert_type(random_st.randint(I_u1_low, I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8])
+assert_type(random_st.randint(0, I_u1_high_open, dtype=np.uint8), npt.NDArray[np.uint8])
+
+assert_type(random_st.randint(65536, dtype="u2"), np.uint16)
+assert_type(random_st.randint(0, 65536, dtype="u2"), np.uint16)
+assert_type(random_st.randint(I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16])
+assert_type(random_st.randint(I_u2_low, I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16])
+assert_type(random_st.randint(0, I_u2_high_open, dtype="u2"), npt.NDArray[np.uint16])
+
+assert_type(random_st.randint(65536, dtype="uint16"), np.uint16)
+assert_type(random_st.randint(0, 65536, dtype="uint16"), np.uint16)
+assert_type(random_st.randint(I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16])
+assert_type(random_st.randint(I_u2_low, I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16])
+assert_type(random_st.randint(0, I_u2_high_open, dtype="uint16"), npt.NDArray[np.uint16])
+
+assert_type(random_st.randint(65536, dtype=np.uint16), np.uint16)
+assert_type(random_st.randint(0, 65536, dtype=np.uint16), np.uint16)
+assert_type(random_st.randint(I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16])
+assert_type(random_st.randint(I_u2_low, I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16])
+assert_type(random_st.randint(0, I_u2_high_open, dtype=np.uint16), npt.NDArray[np.uint16])
+
+assert_type(random_st.randint(4294967296, dtype="u4"), np.uint32)
+assert_type(random_st.randint(0, 4294967296, dtype="u4"), np.uint32)
+assert_type(random_st.randint(I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32])
+assert_type(random_st.randint(I_u4_low, I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32])
+assert_type(random_st.randint(0, I_u4_high_open, dtype="u4"), npt.NDArray[np.uint32])
+
+assert_type(random_st.randint(4294967296, dtype="uint32"), np.uint32)
+assert_type(random_st.randint(0, 4294967296, dtype="uint32"), np.uint32)
+assert_type(random_st.randint(I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+assert_type(random_st.randint(I_u4_low, I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+assert_type(random_st.randint(0, I_u4_high_open, dtype="uint32"), npt.NDArray[np.uint32])
+
+assert_type(random_st.randint(4294967296, dtype=np.uint32), np.uint32)
+assert_type(random_st.randint(0, 4294967296, dtype=np.uint32), np.uint32)
+assert_type(random_st.randint(I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+assert_type(random_st.randint(I_u4_low, I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+assert_type(random_st.randint(0, I_u4_high_open, dtype=np.uint32), npt.NDArray[np.uint32])
+
+assert_type(random_st.randint(4294967296, dtype=np.uint), np.uint)
+assert_type(random_st.randint(0, 4294967296, dtype=np.uint), np.uint)
+assert_type(random_st.randint(I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+assert_type(random_st.randint(I_u4_low, I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+assert_type(random_st.randint(0, I_u4_high_open, dtype=np.uint), npt.NDArray[np.uint])
+
+assert_type(random_st.randint(18446744073709551616, dtype="u8"), np.uint64)
+assert_type(random_st.randint(0, 18446744073709551616, dtype="u8"), np.uint64)
+assert_type(random_st.randint(I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+assert_type(random_st.randint(I_u8_low, I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+assert_type(random_st.randint(0, I_u8_high_open, dtype="u8"), npt.NDArray[np.uint64])
+
+assert_type(random_st.randint(18446744073709551616, dtype="uint64"), np.uint64)
+assert_type(random_st.randint(0, 18446744073709551616, dtype="uint64"), np.uint64)
+assert_type(random_st.randint(I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+assert_type(random_st.randint(I_u8_low, I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+assert_type(random_st.randint(0, I_u8_high_open, dtype="uint64"), npt.NDArray[np.uint64])
+
+assert_type(random_st.randint(18446744073709551616, dtype=np.uint64), np.uint64)
+assert_type(random_st.randint(0, 18446744073709551616, dtype=np.uint64), np.uint64)
+assert_type(random_st.randint(I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+assert_type(random_st.randint(I_u8_low, I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+assert_type(random_st.randint(0, I_u8_high_open, dtype=np.uint64), npt.NDArray[np.uint64])
+
+assert_type(random_st.randint(128, dtype="i1"), np.int8)
+assert_type(random_st.randint(-128, 128, dtype="i1"), np.int8)
+assert_type(random_st.randint(I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+assert_type(random_st.randint(I_i1_low, I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+assert_type(random_st.randint(-128, I_i1_high_open, dtype="i1"), npt.NDArray[np.int8])
+
+assert_type(random_st.randint(128, dtype="int8"), np.int8)
+assert_type(random_st.randint(-128, 128, dtype="int8"), np.int8)
+assert_type(random_st.randint(I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+assert_type(random_st.randint(I_i1_low, I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+assert_type(random_st.randint(-128, I_i1_high_open, dtype="int8"), npt.NDArray[np.int8])
+
+assert_type(random_st.randint(128, dtype=np.int8), np.int8)
+assert_type(random_st.randint(-128, 128, dtype=np.int8), np.int8)
+assert_type(random_st.randint(I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+assert_type(random_st.randint(I_i1_low, I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+assert_type(random_st.randint(-128, I_i1_high_open, dtype=np.int8), npt.NDArray[np.int8])
+
+assert_type(random_st.randint(32768, dtype="i2"), np.int16)
+assert_type(random_st.randint(-32768, 32768, dtype="i2"), np.int16)
+assert_type(random_st.randint(I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+assert_type(random_st.randint(I_i2_low, I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+assert_type(random_st.randint(-32768, I_i2_high_open, dtype="i2"), npt.NDArray[np.int16])
+
+assert_type(random_st.randint(32768, dtype="int16"), np.int16)
+assert_type(random_st.randint(-32768, 32768, dtype="int16"), np.int16)
+assert_type(random_st.randint(I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+assert_type(random_st.randint(I_i2_low, I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+assert_type(random_st.randint(-32768, I_i2_high_open, dtype="int16"), npt.NDArray[np.int16])
+
+assert_type(random_st.randint(32768, dtype=np.int16), np.int16)
+assert_type(random_st.randint(-32768, 32768, dtype=np.int16), np.int16)
+assert_type(random_st.randint(I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+assert_type(random_st.randint(I_i2_low, I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+assert_type(random_st.randint(-32768, I_i2_high_open, dtype=np.int16), npt.NDArray[np.int16])
+
+assert_type(random_st.randint(2147483648, dtype="i4"), np.int32)
+assert_type(random_st.randint(-2147483648, 2147483648, dtype="i4"), np.int32)
+assert_type(random_st.randint(I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+assert_type(random_st.randint(I_i4_low, I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+assert_type(random_st.randint(-2147483648, I_i4_high_open, dtype="i4"), npt.NDArray[np.int32])
+
+assert_type(random_st.randint(2147483648, dtype="int32"), np.int32)
+assert_type(random_st.randint(-2147483648, 2147483648, dtype="int32"), np.int32)
+assert_type(random_st.randint(I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+assert_type(random_st.randint(I_i4_low, I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+assert_type(random_st.randint(-2147483648, I_i4_high_open, dtype="int32"), npt.NDArray[np.int32])
+
+assert_type(random_st.randint(2147483648, dtype=np.int32), np.int32)
+assert_type(random_st.randint(-2147483648, 2147483648, dtype=np.int32), np.int32)
+assert_type(random_st.randint(I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+assert_type(random_st.randint(I_i4_low, I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+assert_type(random_st.randint(-2147483648, I_i4_high_open, dtype=np.int32), npt.NDArray[np.int32])
+
+assert_type(random_st.randint(2147483648, dtype=np.int_), np.int_)
+assert_type(random_st.randint(-2147483648, 2147483648, dtype=np.int_), np.int_)
+assert_type(random_st.randint(I_i4_high_open, dtype=np.int_), npt.NDArray[np.int_])
+assert_type(random_st.randint(I_i4_low, I_i4_high_open, dtype=np.int_), npt.NDArray[np.int_])
+assert_type(random_st.randint(-2147483648, I_i4_high_open, dtype=np.int_), npt.NDArray[np.int_])
+
+assert_type(random_st.randint(9223372036854775808, dtype="i8"), np.int64)
+assert_type(random_st.randint(-9223372036854775808, 9223372036854775808, dtype="i8"), np.int64)
+assert_type(random_st.randint(I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+assert_type(random_st.randint(I_i8_low, I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+assert_type(random_st.randint(-9223372036854775808, I_i8_high_open, dtype="i8"), npt.NDArray[np.int64])
+
+assert_type(random_st.randint(9223372036854775808, dtype="int64"), np.int64)
+assert_type(random_st.randint(-9223372036854775808, 9223372036854775808, dtype="int64"), np.int64)
+assert_type(random_st.randint(I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+assert_type(random_st.randint(I_i8_low, I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+assert_type(random_st.randint(-9223372036854775808, I_i8_high_open, dtype="int64"), npt.NDArray[np.int64])
+
+assert_type(random_st.randint(9223372036854775808, dtype=np.int64), np.int64)
+assert_type(random_st.randint(-9223372036854775808, 9223372036854775808, dtype=np.int64), np.int64)
+assert_type(random_st.randint(I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+assert_type(random_st.randint(I_i8_low, I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+assert_type(random_st.randint(-9223372036854775808, I_i8_high_open, dtype=np.int64), npt.NDArray[np.int64])
+
+assert_type(random_st._bit_generator, np.random.BitGenerator)
+
+assert_type(random_st.bytes(2), bytes)
+
+assert_type(random_st.choice(5), int)
+assert_type(random_st.choice(5, 3), npt.NDArray[np.long])
+assert_type(random_st.choice(5, 3, replace=True), npt.NDArray[np.long])
+assert_type(random_st.choice(5, 3, p=[1 / 5] * 5), npt.NDArray[np.long])
+assert_type(random_st.choice(5, 3, p=[1 / 5] * 5, replace=False), npt.NDArray[np.long])
+
+assert_type(random_st.choice(["pooh", "rabbit", "piglet", "Christopher"]), Any)
"rabbit", "piglet", "Christopher"]), Any) +assert_type(random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3), npt.NDArray[Any]) +assert_type(random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, p=[1 / 4] * 4), npt.NDArray[Any]) +assert_type(random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=True), npt.NDArray[Any]) +assert_type(random_st.choice(["pooh", "rabbit", "piglet", "Christopher"], 3, replace=False, p=np.array([1 / 8, 1 / 8, 1 / 2, 1 / 4])), npt.NDArray[Any]) + +assert_type(random_st.dirichlet([0.5, 0.5]), npt.NDArray[np.float64]) +assert_type(random_st.dirichlet(np.array([0.5, 0.5])), npt.NDArray[np.float64]) +assert_type(random_st.dirichlet(np.array([0.5, 0.5]), size=3), npt.NDArray[np.float64]) + +assert_type(random_st.multinomial(20, [1 / 6.0] * 6), npt.NDArray[np.long]) +assert_type(random_st.multinomial(20, np.array([0.5, 0.5])), npt.NDArray[np.long]) +assert_type(random_st.multinomial(20, [1 / 6.0] * 6, size=2), npt.NDArray[np.long]) + +assert_type(random_st.multivariate_normal([0.0], [[1.0]]), npt.NDArray[np.float64]) +assert_type(random_st.multivariate_normal([0.0], np.array([[1.0]])), npt.NDArray[np.float64]) +assert_type(random_st.multivariate_normal(np.array([0.0]), [[1.0]]), npt.NDArray[np.float64]) +assert_type(random_st.multivariate_normal([0.0], np.array([[1.0]])), npt.NDArray[np.float64]) + +assert_type(random_st.permutation(10), npt.NDArray[np.long]) +assert_type(random_st.permutation([1, 2, 3, 4]), npt.NDArray[Any]) +assert_type(random_st.permutation(np.array([1, 2, 3, 4])), npt.NDArray[Any]) +assert_type(random_st.permutation(D_2D), npt.NDArray[Any]) + +assert_type(random_st.shuffle(np.arange(10)), None) +assert_type(random_st.shuffle([1, 2, 3, 4, 5]), None) +assert_type(random_st.shuffle(D_2D), None) + +assert_type(np.random.RandomState(pcg64), np.random.RandomState) +assert_type(np.random.RandomState(0), np.random.RandomState) +assert_type(np.random.RandomState([0, 1, 2]), np.random.RandomState) +assert_type(random_st.__str__(), str) +assert_type(random_st.__repr__(), str) +random_st_state = random_st.__getstate__() +assert_type(random_st_state, dict[str, Any]) +assert_type(random_st.__setstate__(random_st_state), None) +assert_type(random_st.seed(), None) +assert_type(random_st.seed(1), None) +assert_type(random_st.seed([0, 1]), None) +random_st_get_state = random_st.get_state() +assert_type(random_st_state, dict[str, Any]) +random_st_get_state_legacy = random_st.get_state(legacy=True) +assert_type(random_st_get_state_legacy, dict[str, Any] | tuple[str, npt.NDArray[np.uint32], int, int, float]) +assert_type(random_st.set_state(random_st_get_state), None) + +assert_type(random_st.rand(), float) +assert_type(random_st.rand(1), npt.NDArray[np.float64]) +assert_type(random_st.rand(1, 2), npt.NDArray[np.float64]) +assert_type(random_st.randn(), float) +assert_type(random_st.randn(1), npt.NDArray[np.float64]) +assert_type(random_st.randn(1, 2), npt.NDArray[np.float64]) +assert_type(random_st.random_sample(), float) +assert_type(random_st.random_sample(1), npt.NDArray[np.float64]) +assert_type(random_st.random_sample(size=(1, 2)), npt.NDArray[np.float64]) + +assert_type(random_st.tomaxint(), int) +assert_type(random_st.tomaxint(1), npt.NDArray[np.int64]) +assert_type(random_st.tomaxint((1,)), npt.NDArray[np.int64]) + +assert_type(np.random.mtrand.set_bit_generator(pcg64), None) +assert_type(np.random.mtrand.get_bit_generator(), np.random.BitGenerator) diff --git 
diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/rec.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/rec.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..aacf217e42072cacc9a245d7d7faef8cb340f2c2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/rec.pyi
@@ -0,0 +1,171 @@
+import io
+from typing import Any, TypeAlias, assert_type
+
+import numpy as np
+import numpy.typing as npt
+
+_RecArray: TypeAlias = np.recarray[tuple[Any, ...], np.dtype[np.record]]
+
+AR_i8: npt.NDArray[np.int64]
+REC_AR_V: _RecArray
+AR_LIST: list[npt.NDArray[np.int64]]
+
+record: np.record
+file_obj: io.BufferedIOBase
+
+assert_type(np.rec.format_parser(
+    formats=[np.float64, np.int64, np.bool],
+    names=["f8", "i8", "?"],
+    titles=None,
+    aligned=True,
+), np.rec.format_parser)
+assert_type(np.rec.format_parser.dtype, np.dtype[np.void])
+
+assert_type(record.field_a, Any)
+assert_type(record.field_b, Any)
+assert_type(record["field_a"], Any)
+assert_type(record["field_b"], Any)
+assert_type(record.pprint(), str)
+record.field_c = 5
+
+assert_type(REC_AR_V.field(0), Any)
+assert_type(REC_AR_V.field("field_a"), Any)
+assert_type(REC_AR_V.field(0, AR_i8), None)
+assert_type(REC_AR_V.field("field_a", AR_i8), None)
+assert_type(REC_AR_V["field_a"], npt.NDArray[Any])
+assert_type(REC_AR_V.field_a, Any)
+assert_type(REC_AR_V.__array_finalize__(object()), None)
+
+assert_type(
+    np.recarray(
+        shape=(10, 5),
+        formats=[np.float64, np.int64, np.bool],
+        order="K",
+        byteorder="|",
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.recarray(
+        shape=(10, 5),
+        dtype=[("f8", np.float64), ("i8", np.int64)],
+        strides=(5, 5),
+    ),
+    np.recarray,
+)
+
+assert_type(np.rec.fromarrays(AR_LIST), np.recarray)
+assert_type(
+    np.rec.fromarrays(AR_LIST, dtype=np.int64),
+    np.recarray,
+)
+assert_type(
+    np.rec.fromarrays(
+        AR_LIST,
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"]
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.fromrecords((1, 1.5)),
+    _RecArray
+)
+
+assert_type(
+    np.rec.fromrecords(
+        [(1, 1.5)],
+        dtype=[("i8", np.int64), ("f8", np.float64)],
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.fromrecords(
+        REC_AR_V,
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"]
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.fromstring(
+        b"(1, 1.5)",
+        dtype=[("i8", np.int64), ("f8", np.float64)],
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.fromstring(
+        REC_AR_V,
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"]
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.fromfile(
+        "test_file.txt",
+        dtype=[("i8", np.int64), ("f8", np.float64)],
+    ),
+    np.recarray,
+)
+
+assert_type(
+    np.rec.fromfile(
+        file_obj,
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"]
+    ),
+    _RecArray,
+)
+
+assert_type(np.rec.array(AR_i8), np.recarray[tuple[Any, ...], np.dtype[np.int64]])
+
+assert_type(
+    np.rec.array([(1, 1.5)], dtype=[("i8", np.int64), ("f8", np.float64)]),
+    np.recarray,
+)
+
+assert_type(
+    np.rec.array(
+        [(1, 1.5)],
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"]
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.array(
+        None,
+        dtype=np.float64,
+        shape=(10, 3),
+    ),
+    np.recarray,
+)
+
+assert_type(
+    np.rec.array(
+        None,
+        formats=[np.int64, np.float64],
+        names=["i8", "f8"],
+        shape=(10, 3),
+    ),
+    _RecArray,
+)
+
+assert_type(
+    np.rec.array(file_obj, dtype=np.float64),
+    np.recarray,
+)
+
+assert_type(
+    np.rec.array(file_obj, formats=[np.int64, np.float64], names=["i8", "f8"]),
+    _RecArray,
+)
a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/scalars.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/scalars.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d7b277735c7ccb7f3323f282f921d22ee6cf04b2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/scalars.pyi @@ -0,0 +1,191 @@ +from typing import Any, Literal, TypeAlias, assert_type + +import numpy as np + +_1: TypeAlias = Literal[1] + +b: np.bool +u8: np.uint64 +i8: np.int64 +f8: np.float64 +c8: np.complex64 +c16: np.complex128 +m: np.timedelta64 +U: np.str_ +S: np.bytes_ +V: np.void +O: np.object_ # cannot exist at runtime + +array_nd: np.ndarray[Any, Any] +array_0d: np.ndarray[tuple[()], Any] +array_2d_2x2: np.ndarray[tuple[Literal[2], Literal[2]], Any] + +assert_type(c8.real, np.float32) +assert_type(c8.imag, np.float32) + +assert_type(c8.real.real, np.float32) +assert_type(c8.real.imag, np.float32) + +assert_type(c8.itemsize, int) +assert_type(c8.shape, tuple[()]) +assert_type(c8.strides, tuple[()]) + +assert_type(c8.ndim, Literal[0]) +assert_type(c8.size, Literal[1]) + +assert_type(c8.squeeze(), np.complex64) +assert_type(c8.byteswap(), np.complex64) +assert_type(c8.transpose(), np.complex64) + +assert_type(c8.dtype, np.dtype[np.complex64]) + +assert_type(c8.real, np.float32) +assert_type(c16.imag, np.float64) + +assert_type(np.str_('foo'), np.str_) + +assert_type(V[0], Any) +assert_type(V["field1"], Any) +assert_type(V[["field1", "field2"]], np.void) +V[0] = 5 + +# Aliases +assert_type(np.bool_(), np.bool[Literal[False]]) +assert_type(np.byte(), np.byte) +assert_type(np.short(), np.short) +assert_type(np.intc(), np.intc) +assert_type(np.intp(), np.intp) +assert_type(np.int_(), np.int_) +assert_type(np.long(), np.long) +assert_type(np.longlong(), np.longlong) + +assert_type(np.ubyte(), np.ubyte) +assert_type(np.ushort(), np.ushort) +assert_type(np.uintc(), np.uintc) +assert_type(np.uintp(), np.uintp) +assert_type(np.uint(), np.uint) +assert_type(np.ulong(), np.ulong) +assert_type(np.ulonglong(), np.ulonglong) + +assert_type(np.half(), np.half) +assert_type(np.single(), np.single) +assert_type(np.double(), np.double) +assert_type(np.longdouble(), np.longdouble) + +assert_type(np.csingle(), np.csingle) +assert_type(np.cdouble(), np.cdouble) +assert_type(np.clongdouble(), np.clongdouble) + +assert_type(b.item(), bool) +assert_type(i8.item(), int) +assert_type(u8.item(), int) +assert_type(f8.item(), float) +assert_type(c16.item(), complex) +assert_type(U.item(), str) +assert_type(S.item(), bytes) + +assert_type(b.tolist(), bool) +assert_type(i8.tolist(), int) +assert_type(u8.tolist(), int) +assert_type(f8.tolist(), float) +assert_type(c16.tolist(), complex) +assert_type(U.tolist(), str) +assert_type(S.tolist(), bytes) + +assert_type(b.ravel(), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(i8.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(u8.ravel(), np.ndarray[tuple[int], np.dtype[np.uint64]]) +assert_type(f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(c16.ravel(), np.ndarray[tuple[int], np.dtype[np.complex128]]) +assert_type(U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]]) +assert_type(S.ravel(), np.ndarray[tuple[int], np.dtype[np.bytes_]]) + +assert_type(b.flatten(), np.ndarray[tuple[int], np.dtype[np.bool]]) +assert_type(i8.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]]) +assert_type(u8.flatten(), np.ndarray[tuple[int], np.dtype[np.uint64]]) 
+assert_type(f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]]) +assert_type(c16.flatten(), np.ndarray[tuple[int], np.dtype[np.complex128]]) +assert_type(U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]]) +assert_type(S.flatten(), np.ndarray[tuple[int], np.dtype[np.bytes_]]) + +assert_type(b.reshape(()), np.bool) +assert_type(i8.reshape([]), np.int64) +assert_type(b.reshape(1), np.ndarray[tuple[_1], np.dtype[np.bool]]) +assert_type(i8.reshape(-1), np.ndarray[tuple[_1], np.dtype[np.int64]]) +assert_type(u8.reshape(1, 1), np.ndarray[tuple[_1, _1], np.dtype[np.uint64]]) +assert_type(f8.reshape(1, -1), np.ndarray[tuple[_1, _1], np.dtype[np.float64]]) +assert_type(c16.reshape(1, 1, 1), np.ndarray[tuple[_1, _1, _1], np.dtype[np.complex128]]) +assert_type(U.reshape(1, 1, 1, 1), np.ndarray[tuple[_1, _1, _1, _1], np.dtype[np.str_]]) +assert_type( + S.reshape(1, 1, 1, 1, 1), + np.ndarray[ + # len(shape) >= 5 + tuple[_1, _1, _1, _1, _1, *tuple[_1, ...]], + np.dtype[np.bytes_], + ], +) + +assert_type(i8.astype(float), Any) +assert_type(i8.astype(np.float64), np.float64) + +assert_type(i8.view(), np.int64) +assert_type(i8.view(np.float64), np.float64) +assert_type(i8.view(float), Any) +assert_type(i8.view(np.float64, np.ndarray), np.float64) + +assert_type(i8.getfield(float), Any) +assert_type(i8.getfield(np.float64), np.float64) +assert_type(i8.getfield(np.float64, 8), np.float64) + +assert_type(f8.as_integer_ratio(), tuple[int, int]) +assert_type(f8.is_integer(), bool) +assert_type(f8.__trunc__(), int) +assert_type(f8.__getformat__("float"), str) +assert_type(f8.hex(), str) +assert_type(np.float64.fromhex("0x0.0p+0"), np.float64) + +assert_type(f8.__getnewargs__(), tuple[float]) +assert_type(c16.__getnewargs__(), tuple[float, float]) + +assert_type(i8.numerator, np.int64) +assert_type(i8.denominator, Literal[1]) +assert_type(u8.numerator, np.uint64) +assert_type(u8.denominator, Literal[1]) +assert_type(m.numerator, np.timedelta64) +assert_type(m.denominator, Literal[1]) + +assert_type(round(i8), int) +assert_type(round(i8, 3), np.int64) +assert_type(round(u8), int) +assert_type(round(u8, 3), np.uint64) +assert_type(round(f8), int) +assert_type(round(f8, 3), np.float64) + +assert_type(f8.__ceil__(), int) +assert_type(f8.__floor__(), int) + +assert_type(i8.is_integer(), Literal[True]) + +assert_type(O.real, np.object_) +assert_type(O.imag, np.object_) +assert_type(int(O), int) +assert_type(float(O), float) +assert_type(complex(O), complex) + +# These fail because of a mypy __new__ bug: +# https://github.com/python/mypy/issues/15182 +# According to the typing spec, the following statements are valid, see +# https://typing.readthedocs.io/en/latest/spec/constructors.html#new-method + +# assert_type(np.object_(), None) +# assert_type(np.object_(None), None) +# assert_type(np.object_(array_nd), np.ndarray[Any, np.dtype[np.object_]]) +# assert_type(np.object_([]), npt.NDArray[np.object_]) +# assert_type(np.object_(()), npt.NDArray[np.object_]) +# assert_type(np.object_(range(4)), npt.NDArray[np.object_]) +# assert_type(np.object_(+42), int) +# assert_type(np.object_(1 / 137), float) +# assert_type(np.object_('Developers! 
' * (1 << 6)), str) +# assert_type(np.object_(object()), object) +# assert_type(np.object_({False, True, NotADirectoryError}), set[Any]) +# assert_type(np.object_({'spam': 'food', 'ham': 'food'}), dict[str, str]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2406a39f9682b66623bd9879afdfcfe311c5f869 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape.pyi @@ -0,0 +1,13 @@ +from typing import Any, NamedTuple, assert_type + +import numpy as np + +# Subtype of tuple[int, int] +class XYGrid(NamedTuple): + x_axis: int + y_axis: int + +arr: np.ndarray[XYGrid, Any] + +# Test shape property matches shape typevar +assert_type(arr.shape, XYGrid) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape_base.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape_base.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e409a53bcef95ad549769011cf0a7d28ea78320e --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/shape_base.pyi @@ -0,0 +1,52 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +i8: np.int64 +f8: np.float64 + +AR_b: npt.NDArray[np.bool] +AR_i8: npt.NDArray[np.int64] +AR_f8: npt.NDArray[np.float64] + +AR_LIKE_f8: list[float] + +assert_type(np.take_along_axis(AR_f8, AR_i8, axis=1), npt.NDArray[np.float64]) +assert_type(np.take_along_axis(f8, AR_i8, axis=None), npt.NDArray[np.float64]) + +assert_type(np.put_along_axis(AR_f8, AR_i8, "1.0", axis=1), None) + +assert_type(np.expand_dims(AR_i8, 2), npt.NDArray[np.int64]) +assert_type(np.expand_dims(AR_LIKE_f8, 2), npt.NDArray[Any]) + +assert_type(np.column_stack([AR_i8]), npt.NDArray[np.int64]) +assert_type(np.column_stack([AR_LIKE_f8]), npt.NDArray[Any]) + +assert_type(np.dstack([AR_i8]), npt.NDArray[np.int64]) +assert_type(np.dstack([AR_LIKE_f8]), npt.NDArray[Any]) + +assert_type(np.array_split(AR_i8, [3, 5, 6, 10]), list[npt.NDArray[np.int64]]) +assert_type(np.array_split(AR_LIKE_f8, [3, 5, 6, 10]), list[npt.NDArray[Any]]) + +assert_type(np.split(AR_i8, [3, 5, 6, 10]), list[npt.NDArray[np.int64]]) +assert_type(np.split(AR_LIKE_f8, [3, 5, 6, 10]), list[npt.NDArray[Any]]) + +assert_type(np.hsplit(AR_i8, [3, 5, 6, 10]), list[npt.NDArray[np.int64]]) +assert_type(np.hsplit(AR_LIKE_f8, [3, 5, 6, 10]), list[npt.NDArray[Any]]) + +assert_type(np.vsplit(AR_i8, [3, 5, 6, 10]), list[npt.NDArray[np.int64]]) +assert_type(np.vsplit(AR_LIKE_f8, [3, 5, 6, 10]), list[npt.NDArray[Any]]) + +assert_type(np.dsplit(AR_i8, [3, 5, 6, 10]), list[npt.NDArray[np.int64]]) +assert_type(np.dsplit(AR_LIKE_f8, [3, 5, 6, 10]), list[npt.NDArray[Any]]) + +assert_type(np.kron(AR_b, AR_b), npt.NDArray[np.bool]) +assert_type(np.kron(AR_b, AR_i8), npt.NDArray[np.signedinteger]) +assert_type(np.kron(AR_f8, AR_f8), npt.NDArray[np.floating]) + +assert_type(np.tile(AR_i8, 5), npt.NDArray[np.int64]) +assert_type(np.tile(AR_LIKE_f8, [2, 2]), npt.NDArray[Any]) + +assert_type(np.unstack(AR_i8, axis=0), tuple[npt.NDArray[np.int64], ...]) +assert_type(np.unstack(AR_LIKE_f8, axis=0), tuple[npt.NDArray[Any], ...]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/stride_tricks.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/stride_tricks.pyi new file mode 100644 index 
0000000000000000000000000000000000000000..8fde9b8ae30dd6dd25b1d4f6f68c1321e6f3c0a0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/stride_tricks.pyi @@ -0,0 +1,27 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_LIKE_f: list[float] +interface_dict: dict[str, Any] + +assert_type(np.lib.stride_tricks.as_strided(AR_f8), npt.NDArray[np.float64]) +assert_type(np.lib.stride_tricks.as_strided(AR_LIKE_f), npt.NDArray[Any]) +assert_type(np.lib.stride_tricks.as_strided(AR_f8, strides=(1, 5)), npt.NDArray[np.float64]) +assert_type(np.lib.stride_tricks.as_strided(AR_f8, shape=[9, 20]), npt.NDArray[np.float64]) + +assert_type(np.lib.stride_tricks.sliding_window_view(AR_f8, 5), npt.NDArray[np.float64]) +assert_type(np.lib.stride_tricks.sliding_window_view(AR_LIKE_f, (1, 5)), npt.NDArray[Any]) +assert_type(np.lib.stride_tricks.sliding_window_view(AR_f8, [9], axis=1), npt.NDArray[np.float64]) + +assert_type(np.broadcast_to(AR_f8, 5), npt.NDArray[np.float64]) +assert_type(np.broadcast_to(AR_LIKE_f, (1, 5)), npt.NDArray[Any]) +assert_type(np.broadcast_to(AR_f8, [4, 6], subok=True), npt.NDArray[np.float64]) + +assert_type(np.broadcast_shapes((1, 2), [3, 1], (3, 2)), tuple[Any, ...]) +assert_type(np.broadcast_shapes((6, 7), (5, 6, 1), 7, (5, 1, 7)), tuple[Any, ...]) + +assert_type(np.broadcast_arrays(AR_f8, AR_f8), tuple[npt.NDArray[Any], ...]) +assert_type(np.broadcast_arrays(AR_f8, AR_LIKE_f), tuple[npt.NDArray[Any], ...]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/strings.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/strings.pyi new file mode 100644 index 0000000000000000000000000000000000000000..18bd252d5ff90f34878ba34fd150bf17093cdc61 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/strings.pyi @@ -0,0 +1,196 @@ +from typing import TypeAlias, assert_type + +import numpy as np +import numpy._typing as np_t +import numpy.typing as npt + +AR_T_alias: TypeAlias = np.ndarray[np_t._AnyShape, np.dtypes.StringDType] +AR_TU_alias: TypeAlias = AR_T_alias | npt.NDArray[np.str_] + +AR_U: npt.NDArray[np.str_] +AR_S: npt.NDArray[np.bytes_] +AR_T: AR_T_alias + +assert_type(np.strings.equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.not_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.not_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.not_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.greater_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.greater_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.greater_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.less_equal(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.less_equal(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.less_equal(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.greater(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.greater(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.greater(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.less(AR_U, AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.less(AR_S, AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.less(AR_T, AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.add(AR_U, 
AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.add(AR_S, AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.add(AR_T, AR_T), AR_T_alias) + +assert_type(np.strings.multiply(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.strings.multiply(AR_S, [5, 4, 3]), npt.NDArray[np.bytes_]) +assert_type(np.strings.multiply(AR_T, 5), AR_T_alias) + +assert_type(np.strings.mod(AR_U, "test"), npt.NDArray[np.str_]) +assert_type(np.strings.mod(AR_S, "test"), npt.NDArray[np.bytes_]) +assert_type(np.strings.mod(AR_T, "test"), AR_T_alias) + +assert_type(np.strings.capitalize(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.capitalize(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.capitalize(AR_T), AR_T_alias) + +assert_type(np.strings.center(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.strings.center(AR_S, [2, 3, 4], b"a"), npt.NDArray[np.bytes_]) +assert_type(np.strings.center(AR_T, 5), AR_T_alias) + +assert_type(np.strings.encode(AR_U), npt.NDArray[np.bytes_]) +assert_type(np.strings.encode(AR_T), npt.NDArray[np.bytes_]) +assert_type(np.strings.decode(AR_S), npt.NDArray[np.str_]) + +assert_type(np.strings.expandtabs(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.expandtabs(AR_S, tabsize=4), npt.NDArray[np.bytes_]) +assert_type(np.strings.expandtabs(AR_T), AR_T_alias) + +assert_type(np.strings.ljust(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.strings.ljust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.strings.ljust(AR_T, 5), AR_T_alias) +assert_type(np.strings.ljust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_T_alias) + +assert_type(np.strings.rjust(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.strings.rjust(AR_S, [4, 3, 1], fillchar=[b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.strings.rjust(AR_T, 5), AR_T_alias) +assert_type(np.strings.rjust(AR_T, [4, 2, 1], fillchar=["a", "b", "c"]), AR_T_alias) + +assert_type(np.strings.lstrip(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.lstrip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.strings.lstrip(AR_T), AR_T_alias) +assert_type(np.strings.lstrip(AR_T, "_"), AR_T_alias) + +assert_type(np.strings.rstrip(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.rstrip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.strings.rstrip(AR_T), AR_T_alias) +assert_type(np.strings.rstrip(AR_T, "_"), AR_T_alias) + +assert_type(np.strings.strip(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.strip(AR_S, b"_"), npt.NDArray[np.bytes_]) +assert_type(np.strings.strip(AR_T), AR_T_alias) +assert_type(np.strings.strip(AR_T, "_"), AR_T_alias) + +assert_type(np.strings.count(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.count(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.strings.count(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.count(AR_T, ["a", "b", "c"], end=9), npt.NDArray[np.int_]) + +assert_type(np.strings.partition(AR_U, "\n"), npt.NDArray[np.str_]) +assert_type(np.strings.partition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.strings.partition(AR_T, "\n"), AR_TU_alias) + +assert_type(np.strings.rpartition(AR_U, "\n"), npt.NDArray[np.str_]) +assert_type(np.strings.rpartition(AR_S, [b"a", b"b", b"c"]), npt.NDArray[np.bytes_]) +assert_type(np.strings.rpartition(AR_T, "\n"), AR_TU_alias) + +assert_type(np.strings.replace(AR_U, "_", "-"), npt.NDArray[np.str_]) +assert_type(np.strings.replace(AR_S, [b"_", b""], [b"a", b"b"]), npt.NDArray[np.bytes_]) 
+assert_type(np.strings.replace(AR_T, "_", "_"), AR_TU_alias) + +assert_type(np.strings.lower(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.lower(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.lower(AR_T), AR_T_alias) + +assert_type(np.strings.upper(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.upper(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.upper(AR_T), AR_T_alias) + +assert_type(np.strings.swapcase(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.swapcase(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.swapcase(AR_T), AR_T_alias) + +assert_type(np.strings.title(AR_U), npt.NDArray[np.str_]) +assert_type(np.strings.title(AR_S), npt.NDArray[np.bytes_]) +assert_type(np.strings.title(AR_T), AR_T_alias) + +assert_type(np.strings.zfill(AR_U, 5), npt.NDArray[np.str_]) +assert_type(np.strings.zfill(AR_S, [2, 3, 4]), npt.NDArray[np.bytes_]) +assert_type(np.strings.zfill(AR_T, 5), AR_T_alias) + +assert_type(np.strings.endswith(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) +assert_type(np.strings.endswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) +assert_type(np.strings.endswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) + +assert_type(np.strings.startswith(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) +assert_type(np.strings.startswith(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.bool]) +assert_type(np.strings.startswith(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.bool]) + +assert_type(np.strings.find(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.find(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.strings.find(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.strings.rfind(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.rfind(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.strings.rfind(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.strings.index(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.index(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.strings.index(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.strings.rindex(AR_U, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) +assert_type(np.strings.rindex(AR_S, [b"a", b"b", b"c"], end=9), npt.NDArray[np.int_]) +assert_type(np.strings.rindex(AR_T, "a", start=[1, 2, 3]), npt.NDArray[np.int_]) + +assert_type(np.strings.isalpha(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isalpha(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.isalpha(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isalnum(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isalnum(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.isalnum(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isdecimal(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isdecimal(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isdigit(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isdigit(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.isdigit(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.islower(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.islower(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.islower(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isnumeric(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isnumeric(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isspace(AR_U), npt.NDArray[np.bool]) 
+assert_type(np.strings.isspace(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.isspace(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.istitle(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.istitle(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.istitle(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.isupper(AR_U), npt.NDArray[np.bool]) +assert_type(np.strings.isupper(AR_S), npt.NDArray[np.bool]) +assert_type(np.strings.isupper(AR_T), npt.NDArray[np.bool]) + +assert_type(np.strings.str_len(AR_U), npt.NDArray[np.int_]) +assert_type(np.strings.str_len(AR_S), npt.NDArray[np.int_]) +assert_type(np.strings.str_len(AR_T), npt.NDArray[np.int_]) + +assert_type(np.strings.translate(AR_U, ""), npt.NDArray[np.str_]) +assert_type(np.strings.translate(AR_S, ""), npt.NDArray[np.bytes_]) +assert_type(np.strings.translate(AR_T, ""), AR_T_alias) + +assert_type(np.strings.slice(AR_U, 1, 5, 2), npt.NDArray[np.str_]) +assert_type(np.strings.slice(AR_S, 1, 5, 2), npt.NDArray[np.bytes_]) +assert_type(np.strings.slice(AR_T, 1, 5, 2), AR_T_alias) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/testing.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/testing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d70bc971c15f374ead829db9aff447e361ae1e06 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/testing.pyi @@ -0,0 +1,198 @@ +import contextlib +import re +import sys +import types +import unittest +import warnings +from collections.abc import Callable +from pathlib import Path +from typing import Any, TypeVar, assert_type + +import numpy as np +import numpy.typing as npt + +AR_f8: npt.NDArray[np.float64] +AR_i8: npt.NDArray[np.int64] + +bool_obj: bool +suppress_obj: np.testing.suppress_warnings +FT = TypeVar("FT", bound=Callable[..., Any]) + +def func() -> int: ... + +def func2( + x: npt.NDArray[np.number], + y: npt.NDArray[np.number], +) -> npt.NDArray[np.bool]: ... 
+ +assert_type(np.testing.KnownFailureException(), np.testing.KnownFailureException) +assert_type(np.testing.IgnoreException(), np.testing.IgnoreException) + +assert_type( + np.testing.clear_and_catch_warnings(modules=[np.testing]), + np.testing.clear_and_catch_warnings[None], +) +assert_type( + np.testing.clear_and_catch_warnings(True), + np.testing.clear_and_catch_warnings[list[warnings.WarningMessage]], +) +assert_type( + np.testing.clear_and_catch_warnings(False), + np.testing.clear_and_catch_warnings[None], +) +assert_type( + np.testing.clear_and_catch_warnings(bool_obj), + np.testing.clear_and_catch_warnings, +) +assert_type( + np.testing.clear_and_catch_warnings.class_modules, + tuple[types.ModuleType, ...], +) +assert_type( + np.testing.clear_and_catch_warnings.modules, + set[types.ModuleType], +) + +with np.testing.clear_and_catch_warnings(True) as c1: + assert_type(c1, list[warnings.WarningMessage]) +with np.testing.clear_and_catch_warnings() as c2: + assert_type(c2, None) + +assert_type(np.testing.suppress_warnings("once"), np.testing.suppress_warnings) +assert_type(np.testing.suppress_warnings()(func), Callable[[], int]) +assert_type(suppress_obj.filter(RuntimeWarning), None) +assert_type(suppress_obj.record(RuntimeWarning), list[warnings.WarningMessage]) +with suppress_obj as c3: + assert_type(c3, np.testing.suppress_warnings) + +assert_type(np.testing.verbose, int) +assert_type(np.testing.IS_PYPY, bool) +assert_type(np.testing.HAS_REFCOUNT, bool) +assert_type(np.testing.HAS_LAPACK64, bool) + +assert_type(np.testing.assert_(1, msg="test"), None) +assert_type(np.testing.assert_(2, msg=lambda: "test"), None) + +if sys.platform == "win32" or sys.platform == "cygwin": + assert_type(np.testing.memusage(), int) +elif sys.platform == "linux": + assert_type(np.testing.memusage(), int | None) + +assert_type(np.testing.jiffies(), int) + +assert_type(np.testing.build_err_msg([0, 1, 2], "test"), str) +assert_type(np.testing.build_err_msg(range(2), "test", header="header"), str) +assert_type(np.testing.build_err_msg(np.arange(9).reshape(3, 3), "test", verbose=False), str) +assert_type(np.testing.build_err_msg("abc", "test", names=["x", "y"]), str) +assert_type(np.testing.build_err_msg([1.0, 2.0], "test", precision=5), str) + +assert_type(np.testing.assert_equal({1}, {1}), None) +assert_type(np.testing.assert_equal([1, 2, 3], [1, 2, 3], err_msg="fail"), None) +assert_type(np.testing.assert_equal(1, 1.0, verbose=True), None) + +assert_type(np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]), None) + +assert_type(np.testing.assert_almost_equal(1.0, 1.1), None) +assert_type(np.testing.assert_almost_equal([1, 2, 3], [1, 2, 3], err_msg="fail"), None) +assert_type(np.testing.assert_almost_equal(1, 1.0, verbose=True), None) +assert_type(np.testing.assert_almost_equal(1, 1.0001, decimal=2), None) + +assert_type(np.testing.assert_approx_equal(1.0, 1.1), None) +assert_type(np.testing.assert_approx_equal("1", "2", err_msg="fail"), None) +assert_type(np.testing.assert_approx_equal(1, 1.0, verbose=True), None) +assert_type(np.testing.assert_approx_equal(1, 1.0001, significant=2), None) + +assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, err_msg="test"), None) +assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, verbose=True), None) +assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, header="header"), None) +assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, precision=np.int64()), None) 
+assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, equal_nan=False), None) +assert_type(np.testing.assert_array_compare(func2, AR_i8, AR_f8, equal_inf=True), None) + +assert_type(np.testing.assert_array_equal(AR_i8, AR_f8), None) +assert_type(np.testing.assert_array_equal(AR_i8, AR_f8, err_msg="test"), None) +assert_type(np.testing.assert_array_equal(AR_i8, AR_f8, verbose=True), None) + +assert_type(np.testing.assert_array_almost_equal(AR_i8, AR_f8), None) +assert_type(np.testing.assert_array_almost_equal(AR_i8, AR_f8, err_msg="test"), None) +assert_type(np.testing.assert_array_almost_equal(AR_i8, AR_f8, verbose=True), None) +assert_type(np.testing.assert_array_almost_equal(AR_i8, AR_f8, decimal=1), None) + +assert_type(np.testing.assert_array_less(AR_i8, AR_f8), None) +assert_type(np.testing.assert_array_less(AR_i8, AR_f8, err_msg="test"), None) +assert_type(np.testing.assert_array_less(AR_i8, AR_f8, verbose=True), None) + +assert_type(np.testing.runstring("1 + 1", {}), Any) +assert_type(np.testing.runstring("int64() + 1", {"int64": np.int64}), Any) + +assert_type(np.testing.assert_string_equal("1", "1"), None) + +assert_type(np.testing.rundocs(), None) +assert_type(np.testing.rundocs("test.py"), None) +assert_type(np.testing.rundocs(Path("test.py"), raise_on_error=True), None) + +def func3(a: int) -> bool: ... + +assert_type( + np.testing.assert_raises(RuntimeWarning), + unittest.case._AssertRaisesContext[RuntimeWarning], +) +assert_type(np.testing.assert_raises(RuntimeWarning, func3, 5), None) + +assert_type( + np.testing.assert_raises_regex(RuntimeWarning, r"test"), + unittest.case._AssertRaisesContext[RuntimeWarning], +) +assert_type(np.testing.assert_raises_regex(RuntimeWarning, b"test", func3, 5), None) +assert_type(np.testing.assert_raises_regex(RuntimeWarning, re.compile(b"test"), func3, 5), None) + +class Test: ... + +def decorate(a: FT) -> FT: + return a + +assert_type(np.testing.decorate_methods(Test, decorate), None) +assert_type(np.testing.decorate_methods(Test, decorate, None), None) +assert_type(np.testing.decorate_methods(Test, decorate, "test"), None) +assert_type(np.testing.decorate_methods(Test, decorate, b"test"), None) +assert_type(np.testing.decorate_methods(Test, decorate, re.compile("test")), None) + +assert_type(np.testing.measure("for i in range(1000): np.sqrt(i**2)"), float) +assert_type(np.testing.measure(b"for i in range(1000): np.sqrt(i**2)", times=5), float) + +assert_type(np.testing.assert_allclose(AR_i8, AR_f8), None) +assert_type(np.testing.assert_allclose(AR_i8, AR_f8, rtol=0.005), None) +assert_type(np.testing.assert_allclose(AR_i8, AR_f8, atol=1), None) +assert_type(np.testing.assert_allclose(AR_i8, AR_f8, equal_nan=True), None) +assert_type(np.testing.assert_allclose(AR_i8, AR_f8, err_msg="err"), None) +assert_type(np.testing.assert_allclose(AR_i8, AR_f8, verbose=False), None) + +assert_type(np.testing.assert_array_almost_equal_nulp(AR_i8, AR_f8, nulp=2), None) + +assert_type(np.testing.assert_array_max_ulp(AR_i8, AR_f8, maxulp=2), npt.NDArray[Any]) +assert_type(np.testing.assert_array_max_ulp(AR_i8, AR_f8, dtype=np.float32), npt.NDArray[Any]) + +assert_type(np.testing.assert_warns(RuntimeWarning), contextlib._GeneratorContextManager[None]) +assert_type(np.testing.assert_warns(RuntimeWarning, func3, 5), bool) + +def func4(a: int, b: str) -> bool: ... 
+ +assert_type(np.testing.assert_no_warnings(), contextlib._GeneratorContextManager[None]) +assert_type(np.testing.assert_no_warnings(func3, 5), bool) +assert_type(np.testing.assert_no_warnings(func4, a=1, b="test"), bool) +assert_type(np.testing.assert_no_warnings(func4, 1, "test"), bool) + +assert_type(np.testing.tempdir("test_dir"), contextlib._GeneratorContextManager[str]) +assert_type(np.testing.tempdir(prefix=b"test"), contextlib._GeneratorContextManager[bytes]) +assert_type(np.testing.tempdir("test_dir", dir=Path("here")), contextlib._GeneratorContextManager[str]) + +assert_type(np.testing.temppath("test_dir", text=True), contextlib._GeneratorContextManager[str]) +assert_type(np.testing.temppath(prefix=b"test"), contextlib._GeneratorContextManager[bytes]) +assert_type(np.testing.temppath("test_dir", dir=Path("here")), contextlib._GeneratorContextManager[str]) + +assert_type(np.testing.assert_no_gc_cycles(), contextlib._GeneratorContextManager[None]) +assert_type(np.testing.assert_no_gc_cycles(func3, 5), None) + +assert_type(np.testing.break_cycles(), None) + +assert_type(np.testing.TestCase(), unittest.case.TestCase) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/twodim_base.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/twodim_base.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7e9563a386117ca75a2650095ddc03b259c0db7b --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/twodim_base.pyi @@ -0,0 +1,145 @@ +from typing import Any, TypeVar, assert_type + +import numpy as np +import numpy.typing as npt + +_ScalarT = TypeVar("_ScalarT", bound=np.generic) + +def func1(ar: npt.NDArray[_ScalarT], a: int) -> npt.NDArray[_ScalarT]: ... + +def func2(ar: npt.NDArray[np.number], a: str) -> npt.NDArray[np.float64]: ... 
+ +AR_b: npt.NDArray[np.bool] +AR_u: npt.NDArray[np.uint64] +AR_i: npt.NDArray[np.int64] +AR_f: npt.NDArray[np.float64] +AR_c: npt.NDArray[np.complex128] +AR_O: npt.NDArray[np.object_] + +AR_LIKE_b: list[bool] +AR_LIKE_c: list[complex] + +assert_type(np.fliplr(AR_b), npt.NDArray[np.bool]) +assert_type(np.fliplr(AR_LIKE_b), npt.NDArray[Any]) + +assert_type(np.flipud(AR_b), npt.NDArray[np.bool]) +assert_type(np.flipud(AR_LIKE_b), npt.NDArray[Any]) + +assert_type(np.eye(10), npt.NDArray[np.float64]) +assert_type(np.eye(10, M=20, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.eye(10, k=2, dtype=int), npt.NDArray[Any]) + +assert_type(np.diag(AR_b), npt.NDArray[np.bool]) +assert_type(np.diag(AR_LIKE_b, k=0), npt.NDArray[Any]) + +assert_type(np.diagflat(AR_b), npt.NDArray[np.bool]) +assert_type(np.diagflat(AR_LIKE_b, k=0), npt.NDArray[Any]) + +assert_type(np.tri(10), npt.NDArray[np.float64]) +assert_type(np.tri(10, M=20, dtype=np.int64), npt.NDArray[np.int64]) +assert_type(np.tri(10, k=2, dtype=int), npt.NDArray[Any]) + +assert_type(np.tril(AR_b), npt.NDArray[np.bool]) +assert_type(np.tril(AR_LIKE_b, k=0), npt.NDArray[Any]) + +assert_type(np.triu(AR_b), npt.NDArray[np.bool]) +assert_type(np.triu(AR_LIKE_b, k=0), npt.NDArray[Any]) + +assert_type(np.vander(AR_b), npt.NDArray[np.signedinteger]) +assert_type(np.vander(AR_u), npt.NDArray[np.signedinteger]) +assert_type(np.vander(AR_i, N=2), npt.NDArray[np.signedinteger]) +assert_type(np.vander(AR_f, increasing=True), npt.NDArray[np.floating]) +assert_type(np.vander(AR_c), npt.NDArray[np.complexfloating]) +assert_type(np.vander(AR_O), npt.NDArray[np.object_]) + +assert_type( + np.histogram2d(AR_LIKE_c, AR_LIKE_c), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.complex128 | np.float64], + npt.NDArray[np.complex128 | np.float64], + ], +) +assert_type( + np.histogram2d(AR_i, AR_b), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type( + np.histogram2d(AR_f, AR_i), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type( + np.histogram2d(AR_i, AR_f), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ], +) +assert_type( + np.histogram2d(AR_f, AR_c, weights=AR_LIKE_b), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.complex128], + npt.NDArray[np.complex128], + ], +) +assert_type( + np.histogram2d(AR_f, AR_c, bins=8), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.complex128], + npt.NDArray[np.complex128], + ], +) +assert_type( + np.histogram2d(AR_c, AR_f, bins=(8, 5)), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.complex128], + npt.NDArray[np.complex128], + ], +) +assert_type( + np.histogram2d(AR_c, AR_i, bins=AR_u), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.uint64], + npt.NDArray[np.uint64], + ], +) +assert_type( + np.histogram2d(AR_c, AR_c, bins=(AR_u, AR_u)), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.uint64], + npt.NDArray[np.uint64], + ], +) +assert_type( + np.histogram2d(AR_c, AR_c, bins=(AR_b, 8)), + tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.bool | np.complex128], + npt.NDArray[np.bool | np.complex128], + ], +) + +assert_type(np.mask_indices(10, func1), tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]) +assert_type(np.mask_indices(8, func2, "0"), tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]) + +assert_type(np.tril_indices(10), tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]) + +assert_type(np.tril_indices_from(AR_b), 
tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]) + +assert_type(np.triu_indices(10), tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]) + +assert_type(np.triu_indices_from(AR_b), tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/type_check.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/type_check.pyi new file mode 100644 index 0000000000000000000000000000000000000000..df95da78ffb7a4669b88310739bb710d7d41f6ae --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/type_check.pyi @@ -0,0 +1,67 @@ +from typing import Any, Literal, assert_type + +import numpy as np +import numpy.typing as npt + +f8: np.float64 +f: float + +# NOTE: Avoid importing the platform specific `np.float128` type +AR_i8: npt.NDArray[np.int64] +AR_i4: npt.NDArray[np.int32] +AR_f2: npt.NDArray[np.float16] +AR_f8: npt.NDArray[np.float64] +AR_f16: npt.NDArray[np.longdouble] +AR_c8: npt.NDArray[np.complex64] +AR_c16: npt.NDArray[np.complex128] + +AR_LIKE_f: list[float] + +class ComplexObj: + real: slice + imag: slice + +assert_type(np.mintypecode(["f8"], typeset="qfQF"), str) + +assert_type(np.real(ComplexObj()), slice) +assert_type(np.real(AR_f8), npt.NDArray[np.float64]) +assert_type(np.real(AR_c16), npt.NDArray[np.float64]) +assert_type(np.real(AR_LIKE_f), npt.NDArray[Any]) + +assert_type(np.imag(ComplexObj()), slice) +assert_type(np.imag(AR_f8), npt.NDArray[np.float64]) +assert_type(np.imag(AR_c16), npt.NDArray[np.float64]) +assert_type(np.imag(AR_LIKE_f), npt.NDArray[Any]) + +assert_type(np.iscomplex(f8), np.bool) +assert_type(np.iscomplex(AR_f8), npt.NDArray[np.bool]) +assert_type(np.iscomplex(AR_LIKE_f), npt.NDArray[np.bool]) + +assert_type(np.isreal(f8), np.bool) +assert_type(np.isreal(AR_f8), npt.NDArray[np.bool]) +assert_type(np.isreal(AR_LIKE_f), npt.NDArray[np.bool]) + +assert_type(np.iscomplexobj(f8), bool) +assert_type(np.isrealobj(f8), bool) + +assert_type(np.nan_to_num(f8), np.float64) +assert_type(np.nan_to_num(f, copy=True), Any) +assert_type(np.nan_to_num(AR_f8, nan=1.5), npt.NDArray[np.float64]) +assert_type(np.nan_to_num(AR_LIKE_f, posinf=9999), npt.NDArray[Any]) + +assert_type(np.real_if_close(AR_f8), npt.NDArray[np.float64]) +assert_type(np.real_if_close(AR_c16), npt.NDArray[np.float64 | np.complex128]) +assert_type(np.real_if_close(AR_c8), npt.NDArray[np.float32 | np.complex64]) +assert_type(np.real_if_close(AR_LIKE_f), npt.NDArray[Any]) + +assert_type(np.typename("h"), Literal["short"]) +assert_type(np.typename("B"), Literal["unsigned char"]) +assert_type(np.typename("V"), Literal["void"]) +assert_type(np.typename("S1"), Literal["character"]) + +assert_type(np.common_type(AR_i4), type[np.float64]) +assert_type(np.common_type(AR_f2), type[np.float16]) +assert_type(np.common_type(AR_f2, AR_i4), type[np.float64]) +assert_type(np.common_type(AR_f16, AR_i4), type[np.longdouble]) +assert_type(np.common_type(AR_c8, AR_f2), type[np.complex64]) +assert_type(np.common_type(AR_f2, AR_c8, AR_i4), type[np.complexfloating]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunc_config.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunc_config.pyi new file mode 100644 index 0000000000000000000000000000000000000000..748507530aa11fae503e95cf43473f1dde496f80 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunc_config.pyi @@ -0,0 +1,30 @@ +"""Typing tests for `_core._ufunc_config`.""" + +from 
collections.abc import Callable +from typing import Any, assert_type + +from _typeshed import SupportsWrite + +import numpy as np + +def func(a: str, b: int) -> None: ... + +class Write: + def write(self, value: str) -> None: ... + +assert_type(np.seterr(all=None), np._core._ufunc_config._ErrDict) +assert_type(np.seterr(divide="ignore"), np._core._ufunc_config._ErrDict) +assert_type(np.seterr(over="warn"), np._core._ufunc_config._ErrDict) +assert_type(np.seterr(under="call"), np._core._ufunc_config._ErrDict) +assert_type(np.seterr(invalid="raise"), np._core._ufunc_config._ErrDict) +assert_type(np.geterr(), np._core._ufunc_config._ErrDict) + +assert_type(np.setbufsize(4096), int) +assert_type(np.getbufsize(), int) + +assert_type(np.seterrcall(func), Callable[[str, int], Any] | SupportsWrite[str] | None) +assert_type(np.seterrcall(Write()), Callable[[str, int], Any] | SupportsWrite[str] | None) +assert_type(np.geterrcall(), Callable[[str, int], Any] | SupportsWrite[str] | None) + +assert_type(np.errstate(call=func, all="call"), np.errstate) +assert_type(np.errstate(call=Write(), divide="log", over="log"), np.errstate) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunclike.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunclike.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a0ede60e01585e73634f321d908dbe5cbe948629 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufunclike.pyi @@ -0,0 +1,31 @@ +from typing import Any, assert_type + +import numpy as np +import numpy.typing as npt + +AR_LIKE_b: list[bool] +AR_LIKE_u: list[np.uint32] +AR_LIKE_i: list[int] +AR_LIKE_f: list[float] +AR_LIKE_O: list[np.object_] + +AR_U: npt.NDArray[np.str_] + +assert_type(np.fix(AR_LIKE_b), npt.NDArray[np.floating]) +assert_type(np.fix(AR_LIKE_u), npt.NDArray[np.floating]) +assert_type(np.fix(AR_LIKE_i), npt.NDArray[np.floating]) +assert_type(np.fix(AR_LIKE_f), npt.NDArray[np.floating]) +assert_type(np.fix(AR_LIKE_O), npt.NDArray[np.object_]) +assert_type(np.fix(AR_LIKE_f, out=AR_U), npt.NDArray[np.str_]) + +assert_type(np.isposinf(AR_LIKE_b), npt.NDArray[np.bool]) +assert_type(np.isposinf(AR_LIKE_u), npt.NDArray[np.bool]) +assert_type(np.isposinf(AR_LIKE_i), npt.NDArray[np.bool]) +assert_type(np.isposinf(AR_LIKE_f), npt.NDArray[np.bool]) +assert_type(np.isposinf(AR_LIKE_f, out=AR_U), npt.NDArray[np.str_]) + +assert_type(np.isneginf(AR_LIKE_b), npt.NDArray[np.bool]) +assert_type(np.isneginf(AR_LIKE_u), npt.NDArray[np.bool]) +assert_type(np.isneginf(AR_LIKE_i), npt.NDArray[np.bool]) +assert_type(np.isneginf(AR_LIKE_f), npt.NDArray[np.bool]) +assert_type(np.isneginf(AR_LIKE_f, out=AR_U), npt.NDArray[np.str_]) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufuncs.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufuncs.pyi new file mode 100644 index 0000000000000000000000000000000000000000..93a8bfb15d06ddb3f7291f8399a4c2db2a3f6568 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/ufuncs.pyi @@ -0,0 +1,123 @@ +from typing import Any, Literal, NoReturn, assert_type + +import numpy as np +import numpy.typing as npt + +i8: np.int64 +f8: np.float64 +AR_f8: npt.NDArray[np.float64] +AR_i8: npt.NDArray[np.int64] + +assert_type(np.absolute.__doc__, str) +assert_type(np.absolute.types, list[str]) + +assert_type(np.absolute.__name__, Literal["absolute"]) +assert_type(np.absolute.__qualname__, Literal["absolute"]) 
+assert_type(np.absolute.ntypes, Literal[20]) +assert_type(np.absolute.identity, None) +assert_type(np.absolute.nin, Literal[1]) +assert_type(np.absolute.nout, Literal[1]) +assert_type(np.absolute.nargs, Literal[2]) +assert_type(np.absolute.signature, None) +assert_type(np.absolute(f8), Any) +assert_type(np.absolute(AR_f8), npt.NDArray[Any]) +assert_type(np.absolute.at(AR_f8, AR_i8), None) + +assert_type(np.add.__name__, Literal["add"]) +assert_type(np.add.__qualname__, Literal["add"]) +assert_type(np.add.ntypes, Literal[22]) +assert_type(np.add.identity, Literal[0]) +assert_type(np.add.nin, Literal[2]) +assert_type(np.add.nout, Literal[1]) +assert_type(np.add.nargs, Literal[3]) +assert_type(np.add.signature, None) +assert_type(np.add(f8, f8), Any) +assert_type(np.add(AR_f8, f8), npt.NDArray[Any]) +assert_type(np.add.at(AR_f8, AR_i8, f8), None) +assert_type(np.add.reduce(AR_f8, axis=0), Any) +assert_type(np.add.accumulate(AR_f8), npt.NDArray[Any]) +assert_type(np.add.reduceat(AR_f8, AR_i8), npt.NDArray[Any]) +assert_type(np.add.outer(f8, f8), Any) +assert_type(np.add.outer(AR_f8, f8), npt.NDArray[Any]) + +assert_type(np.frexp.__name__, Literal["frexp"]) +assert_type(np.frexp.__qualname__, Literal["frexp"]) +assert_type(np.frexp.ntypes, Literal[4]) +assert_type(np.frexp.identity, None) +assert_type(np.frexp.nin, Literal[1]) +assert_type(np.frexp.nout, Literal[2]) +assert_type(np.frexp.nargs, Literal[3]) +assert_type(np.frexp.signature, None) +assert_type(np.frexp(f8), tuple[Any, Any]) +assert_type(np.frexp(AR_f8), tuple[npt.NDArray[Any], npt.NDArray[Any]]) + +assert_type(np.divmod.__name__, Literal["divmod"]) +assert_type(np.divmod.__qualname__, Literal["divmod"]) +assert_type(np.divmod.ntypes, Literal[15]) +assert_type(np.divmod.identity, None) +assert_type(np.divmod.nin, Literal[2]) +assert_type(np.divmod.nout, Literal[2]) +assert_type(np.divmod.nargs, Literal[4]) +assert_type(np.divmod.signature, None) +assert_type(np.divmod(f8, f8), tuple[Any, Any]) +assert_type(np.divmod(AR_f8, f8), tuple[npt.NDArray[Any], npt.NDArray[Any]]) + +assert_type(np.matmul.__name__, Literal["matmul"]) +assert_type(np.matmul.__qualname__, Literal["matmul"]) +assert_type(np.matmul.ntypes, Literal[19]) +assert_type(np.matmul.identity, None) +assert_type(np.matmul.nin, Literal[2]) +assert_type(np.matmul.nout, Literal[1]) +assert_type(np.matmul.nargs, Literal[3]) +assert_type(np.matmul.signature, Literal["(n?,k),(k,m?)->(n?,m?)"]) +assert_type(np.matmul.identity, None) +assert_type(np.matmul(AR_f8, AR_f8), Any) +assert_type(np.matmul(AR_f8, AR_f8, axes=[(0, 1), (0, 1), (0, 1)]), Any) + +assert_type(np.vecdot.__name__, Literal["vecdot"]) +assert_type(np.vecdot.__qualname__, Literal["vecdot"]) +assert_type(np.vecdot.ntypes, Literal[19]) +assert_type(np.vecdot.identity, None) +assert_type(np.vecdot.nin, Literal[2]) +assert_type(np.vecdot.nout, Literal[1]) +assert_type(np.vecdot.nargs, Literal[3]) +assert_type(np.vecdot.signature, Literal["(n),(n)->()"]) +assert_type(np.vecdot.identity, None) +assert_type(np.vecdot(AR_f8, AR_f8), Any) + +assert_type(np.bitwise_count.__name__, Literal["bitwise_count"]) +assert_type(np.bitwise_count.__qualname__, Literal["bitwise_count"]) +assert_type(np.bitwise_count.ntypes, Literal[11]) +assert_type(np.bitwise_count.identity, None) +assert_type(np.bitwise_count.nin, Literal[1]) +assert_type(np.bitwise_count.nout, Literal[1]) +assert_type(np.bitwise_count.nargs, Literal[2]) +assert_type(np.bitwise_count.signature, None) 
+assert_type(np.bitwise_count.identity, None) +assert_type(np.bitwise_count(i8), Any) +assert_type(np.bitwise_count(AR_i8), npt.NDArray[Any]) + +assert_type(np.absolute.outer(), NoReturn) +assert_type(np.frexp.outer(), NoReturn) +assert_type(np.divmod.outer(), NoReturn) +assert_type(np.matmul.outer(), NoReturn) + +assert_type(np.absolute.reduceat(), NoReturn) +assert_type(np.frexp.reduceat(), NoReturn) +assert_type(np.divmod.reduceat(), NoReturn) +assert_type(np.matmul.reduceat(), NoReturn) + +assert_type(np.absolute.reduce(), NoReturn) +assert_type(np.frexp.reduce(), NoReturn) +assert_type(np.divmod.reduce(), NoReturn) +assert_type(np.matmul.reduce(), NoReturn) + +assert_type(np.absolute.accumulate(), NoReturn) +assert_type(np.frexp.accumulate(), NoReturn) +assert_type(np.divmod.accumulate(), NoReturn) +assert_type(np.matmul.accumulate(), NoReturn) + +assert_type(np.frexp.at(), NoReturn) +assert_type(np.divmod.at(), NoReturn) +assert_type(np.matmul.at(), NoReturn) diff --git a/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/warnings_and_errors.pyi b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/warnings_and_errors.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f756a8e45d46629d499612c5452932811dc95084 --- /dev/null +++ b/venv/lib/python3.13/site-packages/numpy/typing/tests/data/reveal/warnings_and_errors.pyi @@ -0,0 +1,11 @@ +from typing import assert_type + +import numpy.exceptions as ex + +assert_type(ex.ModuleDeprecationWarning(), ex.ModuleDeprecationWarning) +assert_type(ex.VisibleDeprecationWarning(), ex.VisibleDeprecationWarning) +assert_type(ex.ComplexWarning(), ex.ComplexWarning) +assert_type(ex.RankWarning(), ex.RankWarning) +assert_type(ex.TooHardError(), ex.TooHardError) +assert_type(ex.AxisError("test"), ex.AxisError) +assert_type(ex.AxisError(5, 1), ex.AxisError) diff --git a/venv/lib/python3.13/site-packages/regex-2025.11.3.dist-info/licenses/LICENSE.txt b/venv/lib/python3.13/site-packages/regex-2025.11.3.dist-info/licenses/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..99c19cf844141395a87c34325da3a537f1e95815 --- /dev/null +++ b/venv/lib/python3.13/site-packages/regex-2025.11.3.dist-info/licenses/LICENSE.txt @@ -0,0 +1,208 @@ +This work was derived from the 're' module of CPython 2.6 and CPython 3.1, +copyright (c) 1998-2001 by Secret Labs AB and licensed under CNRI's Python 1.6 +license. + +All additions and alterations are licensed under the Apache 2.0 License. + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Matthew Barnett + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/venv/lib/python3.13/site-packages/regex/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/regex/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11bb65327facaaa882afdb2b21192f4fb577205b Binary files /dev/null and b/venv/lib/python3.13/site-packages/regex/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/regex/__pycache__/_main.cpython-313.pyc b/venv/lib/python3.13/site-packages/regex/__pycache__/_main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7426a27a8d0ef49197759ff674f8017e21a37c30 Binary files /dev/null and b/venv/lib/python3.13/site-packages/regex/__pycache__/_main.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/regex/tests/test_regex.py b/venv/lib/python3.13/site-packages/regex/tests/test_regex.py new file mode 100644 index 0000000000000000000000000000000000000000..b391472b5bb4ea94924376c448a3674b8c6ff8ad --- /dev/null +++ b/venv/lib/python3.13/site-packages/regex/tests/test_regex.py @@ -0,0 +1,4540 @@ +from weakref import proxy +import copy +import pickle +import regex +import string +import sys +import unittest + +# String subclasses for issue 18468. +class StrSubclass(str): + def __getitem__(self, index): + return StrSubclass(super().__getitem__(index)) + +class BytesSubclass(bytes): + def __getitem__(self, index): + return BytesSubclass(super().__getitem__(index)) + +class RegexTests(unittest.TestCase): + PATTERN_CLASS = "" + FLAGS_WITH_COMPILED_PAT = "cannot process flags argument with a compiled pattern" + INVALID_GROUP_REF = "invalid group reference" + MISSING_GT = "missing >" + BAD_GROUP_NAME = "bad character in group name" + MISSING_GROUP_NAME = "missing group name" + MISSING_LT = "missing <" + UNKNOWN_GROUP_I = "unknown group" + UNKNOWN_GROUP = "unknown group" + BAD_ESCAPE = r"bad escape \(end of pattern\)" + BAD_OCTAL_ESCAPE = r"bad escape \\" + BAD_SET = "unterminated character set" + STR_PAT_ON_BYTES = "cannot use a string pattern on a bytes-like object" + BYTES_PAT_ON_STR = "cannot use a bytes pattern on a string-like object" + STR_PAT_BYTES_TEMPL = "expected str instance, bytes found" + BYTES_PAT_STR_TEMPL = "expected a bytes-like object, str found" + BYTES_PAT_UNI_FLAG = "cannot use UNICODE flag with a bytes pattern" + MIXED_FLAGS = "ASCII, LOCALE and UNICODE flags are mutually incompatible" + MISSING_RPAREN = "missing \\)" + TRAILING_CHARS = "unbalanced parenthesis" + BAD_CHAR_RANGE = "bad character range" + NOTHING_TO_REPEAT = "nothing to repeat" + MULTIPLE_REPEAT = "multiple repeat" + OPEN_GROUP = "cannot refer to an open group" + DUPLICATE_GROUP = "duplicate group" + CANT_TURN_OFF = "bad inline flags: cannot turn flags off" + UNDEF_CHAR_NAME = "undefined character name" + + def assertTypedEqual(self, actual, expect, msg=None): + self.assertEqual(actual, expect, msg) + + def recurse(actual, expect): + if isinstance(expect, (tuple, list)): + for x, y in zip(actual, expect): + recurse(x, y) + else: + self.assertIs(type(actual), type(expect), msg) + + recurse(actual, expect) + + def test_weakref(self): + s = 'QabbbcR' + x = regex.compile('ab+c') + y = proxy(x) + if x.findall('QabbbcR') != y.findall('QabbbcR'): + self.fail() + + def test_search_star_plus(self): + self.assertEqual(regex.search('a*', 'xxx').span(0), (0, 0)) + self.assertEqual(regex.search('x*', 'axx').span(), (0, 0)) + self.assertEqual(regex.search('x+', 'axx').span(0), (1, 3)) + 
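+        # Editor's note: an illustrative aside, not from the upstream suite.
+        # match() anchors at the start of the string while search() scans
+        # forward, and a pattern that can match the empty string succeeds
+        # with a zero-width span instead of failing:
+        assert regex.match('a*', 'xxx').span() == (0, 0)   # empty match at 0
+        assert regex.search('x+', 'axx').span() == (1, 3)  # scans past 'a'
+        assert regex.match('x+', 'axx') is None            # anchored, so fails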
+        self.assertEqual(regex.search('x+', 'axx').span(), (1, 3))
+        self.assertEqual(regex.search('x', 'aaa'), None)
+        self.assertEqual(regex.match('a*', 'xxx').span(0), (0, 0))
+        self.assertEqual(regex.match('a*', 'xxx').span(), (0, 0))
+        self.assertEqual(regex.match('x*', 'xxxa').span(0), (0, 3))
+        self.assertEqual(regex.match('x*', 'xxxa').span(), (0, 3))
+        self.assertEqual(regex.match('a+', 'xxx'), None)
+
+    def bump_num(self, matchobj):
+        int_value = int(matchobj[0])
+        return str(int_value + 1)
+
+    def test_basic_regex_sub(self):
+        self.assertEqual(regex.sub("(?i)b+", "x", "bbbb BBBB"), 'x x')
+        self.assertEqual(regex.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'),
+          '9.3 -3 24x100y')
+        self.assertEqual(regex.sub(r'\d+', self.bump_num, '08.2 -2 23x99y', 3),
+          '9.3 -3 23x99y')
+
+        self.assertEqual(regex.sub('.', lambda m: r"\n", 'x'), "\\n")
+        self.assertEqual(regex.sub('.', r"\n", 'x'), "\n")
+
+        self.assertEqual(regex.sub('(?P<a>x)', r'\g<a>\g<a>', 'xx'), 'xxxx')
+        self.assertEqual(regex.sub('(?P<a>x)', r'\g<a>\g<1>', 'xx'), 'xxxx')
+        self.assertEqual(regex.sub('(?P<unk>x)', r'\g<unk>\g<unk>', 'xx'),
+          'xxxx')
+        self.assertEqual(regex.sub('(?P<unk>x)', r'\g<1>\g<1>', 'xx'), 'xxxx')
+
+        self.assertEqual(regex.sub('a', r'\t\n\v\r\f\a\b', 'a'),
+          "\t\n\v\r\f\a\b")
+        self.assertEqual(regex.sub('a', '\t\n\v\r\f\a', 'a'), "\t\n\v\r\f\a")
+        self.assertEqual(regex.sub('a', '\t\n\v\r\f\a', 'a'), chr(9) + chr(10)
+          + chr(11) + chr(13) + chr(12) + chr(7))
+
+        self.assertEqual(regex.sub(r'^\s*', 'X', 'test'), 'Xtest')
+
+        self.assertEqual(regex.sub(r"x", r"\x0A", "x"), "\n")
+        self.assertEqual(regex.sub(r"x", r"\u000A", "x"), "\n")
+        self.assertEqual(regex.sub(r"x", r"\U0000000A", "x"), "\n")
+        self.assertEqual(regex.sub(r"x", r"\N{LATIN CAPITAL LETTER A}",
+          "x"), "A")
+
+        self.assertEqual(regex.sub(br"x", br"\x0A", b"x"), b"\n")
+
+    def test_bug_449964(self):
+        # Fails for group followed by other escape.
+        self.assertEqual(regex.sub(r'(?P<unk>x)', r'\g<1>\g<1>\b', 'xx'),
+          "xx\bxx\b")
+
+    def test_bug_449000(self):
+        # Test for sub() on escaped characters.
+        self.assertEqual(regex.sub(r'\r\n', r'\n', 'abc\r\ndef\r\n'),
+          "abc\ndef\n")
+        self.assertEqual(regex.sub('\r\n', r'\n', 'abc\r\ndef\r\n'),
+          "abc\ndef\n")
+        self.assertEqual(regex.sub(r'\r\n', '\n', 'abc\r\ndef\r\n'),
+          "abc\ndef\n")
+        self.assertEqual(regex.sub('\r\n', '\n', 'abc\r\ndef\r\n'),
+          "abc\ndef\n")
+
+    def test_bug_1661(self):
+        # Verify that flags do not get silently ignored with compiled patterns
+        pattern = regex.compile('.')
+        self.assertRaisesRegex(ValueError, self.FLAGS_WITH_COMPILED_PAT,
+          lambda: regex.match(pattern, 'A', regex.I))
+        self.assertRaisesRegex(ValueError, self.FLAGS_WITH_COMPILED_PAT,
+          lambda: regex.search(pattern, 'A', regex.I))
+        self.assertRaisesRegex(ValueError, self.FLAGS_WITH_COMPILED_PAT,
+          lambda: regex.findall(pattern, 'A', regex.I))
+        self.assertRaisesRegex(ValueError, self.FLAGS_WITH_COMPILED_PAT,
+          lambda: regex.compile(pattern, regex.I))
+
+    def test_bug_3629(self):
+        # A regex that triggered a bug in the sre-code validator
+        self.assertEqual(repr(type(regex.compile("(?P<quote>)(?(quote))"))),
+          self.PATTERN_CLASS)
+
+    def test_sub_template_numeric_escape(self):
+        # Bug 776311 and friends.
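+        # (Editor's note, illustrative: alongside the numeric escapes checked
+        # below, templates accept named references and callable replacements:)
+        assert regex.sub(r'(?P<w>\w+)', r'\g<w>!', 'hi') == 'hi!'
+        assert regex.sub(r'\d+', lambda m: str(int(m[0]) + 1), 'a9') == 'a10'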
+ self.assertEqual(regex.sub('x', r'\0', 'x'), "\0") + self.assertEqual(regex.sub('x', r'\000', 'x'), "\000") + self.assertEqual(regex.sub('x', r'\001', 'x'), "\001") + self.assertEqual(regex.sub('x', r'\008', 'x'), "\0" + "8") + self.assertEqual(regex.sub('x', r'\009', 'x'), "\0" + "9") + self.assertEqual(regex.sub('x', r'\111', 'x'), "\111") + self.assertEqual(regex.sub('x', r'\117', 'x'), "\117") + + self.assertEqual(regex.sub('x', r'\1111', 'x'), "\1111") + self.assertEqual(regex.sub('x', r'\1111', 'x'), "\111" + "1") + + self.assertEqual(regex.sub('x', r'\00', 'x'), '\x00') + self.assertEqual(regex.sub('x', r'\07', 'x'), '\x07') + self.assertEqual(regex.sub('x', r'\08', 'x'), "\0" + "8") + self.assertEqual(regex.sub('x', r'\09', 'x'), "\0" + "9") + self.assertEqual(regex.sub('x', r'\0a', 'x'), "\0" + "a") + + self.assertEqual(regex.sub('x', r'\400', 'x'), "\u0100") + self.assertEqual(regex.sub('x', r'\777', 'x'), "\u01FF") + self.assertEqual(regex.sub(b'x', br'\400', b'x'), b"\x00") + self.assertEqual(regex.sub(b'x', br'\777', b'x'), b"\xFF") + + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\1', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\8', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\9', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\11', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\18', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\1a', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\90', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\99', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\118', 'x')) # r'\11' + '8' + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\11a', 'x')) + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\181', 'x')) # r'\18' + '1' + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.sub('x', r'\800', 'x')) # r'\80' + '0' + + # In Python 2.3 (etc), these loop endlessly in sre_parser.py. + self.assertEqual(regex.sub('(((((((((((x)))))))))))', r'\11', 'x'), + 'x') + self.assertEqual(regex.sub('((((((((((y))))))))))(.)', r'\118', 'xyz'), + 'xz8') + self.assertEqual(regex.sub('((((((((((y))))))))))(.)', r'\11a', 'xyz'), + 'xza') + + def test_qualified_re_sub(self): + self.assertEqual(regex.sub('a', 'b', 'aaaaa'), 'bbbbb') + self.assertEqual(regex.sub('a', 'b', 'aaaaa', 1), 'baaaa') + + def test_bug_114660(self): + self.assertEqual(regex.sub(r'(\S)\s+(\S)', r'\1 \2', 'hello there'), + 'hello there') + + def test_bug_462270(self): + # Test for empty sub() behaviour, see SF bug #462270 + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.sub('(?V0)x*', '-', 'abxd'), '-a-b--d-') + else: + self.assertEqual(regex.sub('(?V0)x*', '-', 'abxd'), '-a-b-d-') + self.assertEqual(regex.sub('(?V1)x*', '-', 'abxd'), '-a-b--d-') + self.assertEqual(regex.sub('x+', '-', 'abxd'), 'ab-d') + + def test_bug_14462(self): + # chr(255) is a valid identifier in Python 3. 
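+        # (Editor's note, an illustrative aside on the numeric escapes above:
+        # in a template, \0 starts an octal character escape while \1..\99
+        # are group references, so r'\08' is NUL plus a literal '8' and
+        # r'\11' needs eleven groups to be valid:)
+        assert regex.sub('x', r'\08', 'x') == '\0' + '8'
+        assert regex.sub('(((((((((((x)))))))))))', r'\11', 'x') == 'x'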
+ group_name = '\xFF' + self.assertEqual(regex.search(r'(?P<' + group_name + '>a)', + 'abc').group(group_name), 'a') + + def test_symbolic_refs(self): + self.assertRaisesRegex(regex.error, self.MISSING_GT, lambda: + regex.sub('(?Px)', r'\gx)', r'\g<', 'xx')) + self.assertRaisesRegex(regex.error, self.MISSING_LT, lambda: + regex.sub('(?Px)', r'\g', 'xx')) + self.assertRaisesRegex(regex.error, self.BAD_GROUP_NAME, lambda: + regex.sub('(?Px)', r'\g', 'xx')) + self.assertRaisesRegex(regex.error, self.BAD_GROUP_NAME, lambda: + regex.sub('(?Px)', r'\g<1a1>', 'xx')) + self.assertRaisesRegex(IndexError, self.UNKNOWN_GROUP_I, lambda: + regex.sub('(?Px)', r'\g', 'xx')) + + # The new behaviour of unmatched but valid groups is to treat them like + # empty matches in the replacement template, like in Perl. + self.assertEqual(regex.sub('(?Px)|(?Py)', r'\g', 'xx'), '') + self.assertEqual(regex.sub('(?Px)|(?Py)', r'\2', 'xx'), '') + + # The old behaviour was to raise it as an IndexError. + self.assertRaisesRegex(regex.error, self.BAD_GROUP_NAME, lambda: + regex.sub('(?Px)', r'\g<-1>', 'xx')) + + def test_re_subn(self): + self.assertEqual(regex.subn("(?i)b+", "x", "bbbb BBBB"), ('x x', 2)) + self.assertEqual(regex.subn("b+", "x", "bbbb BBBB"), ('x BBBB', 1)) + self.assertEqual(regex.subn("b+", "x", "xyz"), ('xyz', 0)) + self.assertEqual(regex.subn("b*", "x", "xyz"), ('xxxyxzx', 4)) + self.assertEqual(regex.subn("b*", "x", "xyz", 2), ('xxxyz', 2)) + + def test_re_split(self): + self.assertEqual(regex.split(":", ":a:b::c"), ['', 'a', 'b', '', 'c']) + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split(":*", ":a:b::c"), ['', '', 'a', '', + 'b', '', 'c', '']) + self.assertEqual(regex.split("(:*)", ":a:b::c"), ['', ':', '', '', + 'a', ':', '', '', 'b', '::', '', '', 'c', '', '']) + self.assertEqual(regex.split("(?::*)", ":a:b::c"), ['', '', 'a', + '', 'b', '', 'c', '']) + self.assertEqual(regex.split("(:)*", ":a:b::c"), ['', ':', '', + None, 'a', ':', '', None, 'b', ':', '', None, 'c', None, '']) + else: + self.assertEqual(regex.split(":*", ":a:b::c"), ['', 'a', 'b', 'c']) + self.assertEqual(regex.split("(:*)", ":a:b::c"), ['', ':', 'a', + ':', 'b', '::', 'c']) + self.assertEqual(regex.split("(?::*)", ":a:b::c"), ['', 'a', 'b', + 'c']) + self.assertEqual(regex.split("(:)*", ":a:b::c"), ['', ':', 'a', + ':', 'b', ':', 'c']) + self.assertEqual(regex.split("([b:]+)", ":a:b::c"), ['', ':', 'a', + ':b::', 'c']) + self.assertEqual(regex.split("(b)|(:+)", ":a:b::c"), ['', None, ':', + 'a', None, ':', '', 'b', None, '', None, '::', 'c']) + self.assertEqual(regex.split("(?:b)|(?::+)", ":a:b::c"), ['', 'a', '', + '', 'c']) + + self.assertEqual(regex.split("x", "xaxbxc"), ['', 'a', 'b', 'c']) + self.assertEqual([m for m in regex.splititer("x", "xaxbxc")], ['', 'a', + 'b', 'c']) + + self.assertEqual(regex.split("(?r)x", "xaxbxc"), ['c', 'b', 'a', '']) + self.assertEqual([m for m in regex.splititer("(?r)x", "xaxbxc")], ['c', + 'b', 'a', '']) + + self.assertEqual(regex.split("(x)|(y)", "xaxbxc"), ['', 'x', None, 'a', + 'x', None, 'b', 'x', None, 'c']) + self.assertEqual([m for m in regex.splititer("(x)|(y)", "xaxbxc")], + ['', 'x', None, 'a', 'x', None, 'b', 'x', None, 'c']) + + self.assertEqual(regex.split("(?r)(x)|(y)", "xaxbxc"), ['c', 'x', None, + 'b', 'x', None, 'a', 'x', None, '']) + self.assertEqual([m for m in regex.splititer("(?r)(x)|(y)", "xaxbxc")], + ['c', 'x', None, 'b', 'x', None, 'a', 'x', None, '']) + + self.assertEqual(regex.split(r"(?V1)\b", "a b c"), ['', 'a', ' ', 'b', + ' ', 'c', '']) + 
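+        # Editor's note: an illustrative aside, not from the upstream suite.
+        # Under (?V1), splitting on zero-width matches is permitted, and the
+        # word-start \m and word-end \M boundaries give one cut per edge:
+        assert regex.split(r'(?V1)\m', 'a b') == ['', 'a ', 'b']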
self.assertEqual(regex.split(r"(?V1)\m", "a b c"), ['', 'a ', 'b ', + 'c']) + self.assertEqual(regex.split(r"(?V1)\M", "a b c"), ['a', ' b', ' c', + '']) + + def test_qualified_re_split(self): + self.assertEqual(regex.split(":", ":a:b::c", 2), ['', 'a', 'b::c']) + self.assertEqual(regex.split(':', 'a:b:c:d', 2), ['a', 'b', 'c:d']) + self.assertEqual(regex.split("(:)", ":a:b::c", 2), ['', ':', 'a', ':', + 'b::c']) + + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split("(:*)", ":a:b::c", 2), ['', ':', '', + '', 'a:b::c']) + else: + self.assertEqual(regex.split("(:*)", ":a:b::c", 2), ['', ':', 'a', + ':', 'b::c']) + + def test_re_findall(self): + self.assertEqual(regex.findall(":+", "abc"), []) + self.assertEqual(regex.findall(":+", "a:b::c:::d"), [':', '::', ':::']) + self.assertEqual(regex.findall("(:+)", "a:b::c:::d"), [':', '::', + ':::']) + self.assertEqual(regex.findall("(:)(:*)", "a:b::c:::d"), [(':', ''), + (':', ':'), (':', '::')]) + + self.assertEqual(regex.findall(r"\((?P.{0,5}?TEST)\)", + "(MY TEST)"), ["MY TEST"]) + self.assertEqual(regex.findall(r"\((?P.{0,3}?TEST)\)", + "(MY TEST)"), ["MY TEST"]) + self.assertEqual(regex.findall(r"\((?P.{0,3}?T)\)", "(MY T)"), + ["MY T"]) + + self.assertEqual(regex.findall(r"[^a]{2}[A-Z]", "\n S"), [' S']) + self.assertEqual(regex.findall(r"[^a]{2,3}[A-Z]", "\n S"), ['\n S']) + self.assertEqual(regex.findall(r"[^a]{2,3}[A-Z]", "\n S"), [' S']) + + self.assertEqual(regex.findall(r"X(Y[^Y]+?){1,2}( |Q)+DEF", + "XYABCYPPQ\nQ DEF"), [('YPPQ\n', ' ')]) + + self.assertEqual(regex.findall(r"(\nTest(\n+.+?){0,2}?)?\n+End", + "\nTest\nxyz\nxyz\nEnd"), [('\nTest\nxyz\nxyz', '\nxyz')]) + + def test_bug_117612(self): + self.assertEqual(regex.findall(r"(a|(b))", "aba"), [('a', ''), ('b', + 'b'), ('a', '')]) + + def test_re_match(self): + self.assertEqual(regex.match('a', 'a')[:], ('a',)) + self.assertEqual(regex.match('(a)', 'a')[:], ('a', 'a')) + self.assertEqual(regex.match(r'(a)', 'a')[0], 'a') + self.assertEqual(regex.match(r'(a)', 'a')[1], 'a') + self.assertEqual(regex.match(r'(a)', 'a').group(1, 1), ('a', 'a')) + + pat = regex.compile('((a)|(b))(c)?') + self.assertEqual(pat.match('a')[:], ('a', 'a', 'a', None, None)) + self.assertEqual(pat.match('b')[:], ('b', 'b', None, 'b', None)) + self.assertEqual(pat.match('ac')[:], ('ac', 'a', 'a', None, 'c')) + self.assertEqual(pat.match('bc')[:], ('bc', 'b', None, 'b', 'c')) + self.assertEqual(pat.match('bc')[:], ('bc', 'b', None, 'b', 'c')) + + # A single group. 
+        m = regex.match('(a)', 'a')
+        self.assertEqual(m.group(), 'a')
+        self.assertEqual(m.group(0), 'a')
+        self.assertEqual(m.group(1), 'a')
+        self.assertEqual(m.group(1, 1), ('a', 'a'))
+
+        pat = regex.compile('(?:(?P<a1>a)|(?P<b2>b))(?P<c3>c)?')
+        self.assertEqual(pat.match('a').group(1, 2, 3), ('a', None, None))
+        self.assertEqual(pat.match('b').group('a1', 'b2', 'c3'), (None, 'b',
+          None))
+        self.assertEqual(pat.match('ac').group(1, 'b2', 3), ('a', None, 'c'))
+
+    def test_re_groupref_exists(self):
+        self.assertEqual(regex.match(r'^(\()?([^()]+)(?(1)\))$', '(a)')[:],
+          ('(a)', '(', 'a'))
+        self.assertEqual(regex.match(r'^(\()?([^()]+)(?(1)\))$', 'a')[:], ('a',
+          None, 'a'))
+        self.assertEqual(regex.match(r'^(\()?([^()]+)(?(1)\))$', 'a)'), None)
+        self.assertEqual(regex.match(r'^(\()?([^()]+)(?(1)\))$', '(a'), None)
+        self.assertEqual(regex.match('^(?:(a)|c)((?(1)b|d))$', 'ab')[:], ('ab',
+          'a', 'b'))
+        self.assertEqual(regex.match('^(?:(a)|c)((?(1)b|d))$', 'cd')[:], ('cd',
+          None, 'd'))
+        self.assertEqual(regex.match('^(?:(a)|c)((?(1)|d))$', 'cd')[:], ('cd',
+          None, 'd'))
+        self.assertEqual(regex.match('^(?:(a)|c)((?(1)|d))$', 'a')[:], ('a',
+          'a', ''))
+
+        # Tests for bug #1177831: exercise groups other than the first group.
+        p = regex.compile('(?P<g1>a)(?P<g2>b)?((?(g2)c|d))')
+        self.assertEqual(p.match('abc')[:], ('abc', 'a', 'b', 'c'))
+        self.assertEqual(p.match('ad')[:], ('ad', 'a', None, 'd'))
+        self.assertEqual(p.match('abd'), None)
+        self.assertEqual(p.match('ac'), None)
+
+    def test_re_groupref(self):
+        self.assertEqual(regex.match(r'^(\|)?([^()]+)\1$', '|a|')[:], ('|a|',
+          '|', 'a'))
+        self.assertEqual(regex.match(r'^(\|)?([^()]+)\1?$', 'a')[:], ('a',
+          None, 'a'))
+        self.assertEqual(regex.match(r'^(\|)?([^()]+)\1$', 'a|'), None)
+        self.assertEqual(regex.match(r'^(\|)?([^()]+)\1$', '|a'), None)
+        self.assertEqual(regex.match(r'^(?:(a)|c)(\1)$', 'aa')[:], ('aa', 'a',
+          'a'))
+        self.assertEqual(regex.match(r'^(?:(a)|c)(\1)?$', 'c')[:], ('c', None,
+          None))
+
+        self.assertEqual(regex.findall(r"(?i)(.{1,40}?),(.{1,40}?)(?:;)+(.{1,80}).{1,40}?\3(\ |;)+(.{1,80}?)\1",
+          "TEST, BEST; LEST ; Lest 123 Test, Best"), [('TEST', ' BEST',
+          ' LEST', ' ', '123 ')])
+
+    def test_groupdict(self):
+        self.assertEqual(regex.match('(?P<first>first) (?P<second>second)',
+          'first second').groupdict(), {'first': 'first', 'second': 'second'})
+
+    def test_expand(self):
+        self.assertEqual(regex.match("(?P<first>first) (?P<second>second)",
+          "first second").expand(r"\2 \1 \g<second> \g<first>"),
+          'second first second first')
+
+    def test_repeat_minmax(self):
+        self.assertEqual(regex.match(r"^(\w){1}$", "abc"), None)
+        self.assertEqual(regex.match(r"^(\w){1}?$", "abc"), None)
+        self.assertEqual(regex.match(r"^(\w){1,2}$", "abc"), None)
+        self.assertEqual(regex.match(r"^(\w){1,2}?$", "abc"), None)
+
+        self.assertEqual(regex.match(r"^(\w){3}$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){1,3}$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){1,4}$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){3,4}?$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){3}?$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){1,3}?$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){1,4}?$", "abc")[1], 'c')
+        self.assertEqual(regex.match(r"^(\w){3,4}?$", "abc")[1], 'c')
+
+        self.assertEqual(regex.match("^x{1}$", "xxx"), None)
+        self.assertEqual(regex.match("^x{1}?$", "xxx"), None)
+        self.assertEqual(regex.match("^x{1,2}$", "xxx"), None)
+        self.assertEqual(regex.match("^x{1,2}?$", "xxx"), None)
+
+        self.assertEqual(regex.match("^x{1}", "xxx")[0],
'x') + self.assertEqual(regex.match("^x{1}?", "xxx")[0], 'x') + self.assertEqual(regex.match("^x{0,1}", "xxx")[0], 'x') + self.assertEqual(regex.match("^x{0,1}?", "xxx")[0], '') + + self.assertEqual(bool(regex.match("^x{3}$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{1,3}$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{1,4}$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{3,4}?$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{3}?$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{1,3}?$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{1,4}?$", "xxx")), True) + self.assertEqual(bool(regex.match("^x{3,4}?$", "xxx")), True) + + self.assertEqual(regex.match("^x{}$", "xxx"), None) + self.assertEqual(bool(regex.match("^x{}$", "x{}")), True) + + def test_getattr(self): + self.assertEqual(regex.compile("(?i)(a)(b)").pattern, '(?i)(a)(b)') + self.assertEqual(regex.compile("(?i)(a)(b)").flags, regex.I | regex.U | + regex.DEFAULT_VERSION) + self.assertEqual(regex.compile(b"(?i)(a)(b)").flags, regex.A | regex.I + | regex.DEFAULT_VERSION) + self.assertEqual(regex.compile("(?i)(a)(b)").groups, 2) + self.assertEqual(regex.compile("(?i)(a)(b)").groupindex, {}) + + self.assertEqual(regex.compile("(?i)(?Pa)(?Pb)").groupindex, + {'first': 1, 'other': 2}) + + self.assertEqual(regex.match("(a)", "a").pos, 0) + self.assertEqual(regex.match("(a)", "a").endpos, 1) + + self.assertEqual(regex.search("b(c)", "abcdef").pos, 0) + self.assertEqual(regex.search("b(c)", "abcdef").endpos, 6) + self.assertEqual(regex.search("b(c)", "abcdef").span(), (1, 3)) + self.assertEqual(regex.search("b(c)", "abcdef").span(1), (2, 3)) + + self.assertEqual(regex.match("(a)", "a").string, 'a') + self.assertEqual(regex.match("(a)", "a").regs, ((0, 1), (0, 1))) + self.assertEqual(repr(type(regex.match("(a)", "a").re)), + self.PATTERN_CLASS) + + # Issue 14260. + p = regex.compile(r'abc(?Pdef)') + p.groupindex["n"] = 0 + self.assertEqual(p.groupindex["n"], 1) + + def test_special_escapes(self): + self.assertEqual(regex.search(r"\b(b.)\b", "abcd abc bcd bx")[1], 'bx') + self.assertEqual(regex.search(r"\B(b.)\B", "abc bcd bc abxd")[1], 'bx') + self.assertEqual(regex.search(br"\b(b.)\b", b"abcd abc bcd bx", + regex.LOCALE)[1], b'bx') + self.assertEqual(regex.search(br"\B(b.)\B", b"abc bcd bc abxd", + regex.LOCALE)[1], b'bx') + self.assertEqual(regex.search(r"\b(b.)\b", "abcd abc bcd bx", + regex.UNICODE)[1], 'bx') + self.assertEqual(regex.search(r"\B(b.)\B", "abc bcd bc abxd", + regex.UNICODE)[1], 'bx') + + self.assertEqual(regex.search(r"^abc$", "\nabc\n", regex.M)[0], 'abc') + self.assertEqual(regex.search(r"^\Aabc\Z$", "abc", regex.M)[0], 'abc') + self.assertEqual(regex.search(r"^\Aabc\Z$", "\nabc\n", regex.M), None) + + self.assertEqual(regex.search(br"\b(b.)\b", b"abcd abc bcd bx")[1], + b'bx') + self.assertEqual(regex.search(br"\B(b.)\B", b"abc bcd bc abxd")[1], + b'bx') + self.assertEqual(regex.search(br"^abc$", b"\nabc\n", regex.M)[0], + b'abc') + self.assertEqual(regex.search(br"^\Aabc\Z$", b"abc", regex.M)[0], + b'abc') + self.assertEqual(regex.search(br"^\Aabc\Z$", b"\nabc\n", regex.M), + None) + + self.assertEqual(regex.search(r"\d\D\w\W\s\S", "1aa! a")[0], '1aa! a') + self.assertEqual(regex.search(br"\d\D\w\W\s\S", b"1aa! a", + regex.LOCALE)[0], b'1aa! a') + self.assertEqual(regex.search(r"\d\D\w\W\s\S", "1aa! a", + regex.UNICODE)[0], '1aa! 
a') + + def test_bigcharset(self): + self.assertEqual(regex.match(r"([\u2222\u2223])", "\u2222")[1], + '\u2222') + self.assertEqual(regex.match(r"([\u2222\u2223])", "\u2222", + regex.UNICODE)[1], '\u2222') + self.assertEqual("".join(regex.findall(".", + "e\xe8\xe9\xea\xeb\u0113\u011b\u0117", flags=regex.UNICODE)), + 'e\xe8\xe9\xea\xeb\u0113\u011b\u0117') + self.assertEqual("".join(regex.findall(r"[e\xe8\xe9\xea\xeb\u0113\u011b\u0117]", + "e\xe8\xe9\xea\xeb\u0113\u011b\u0117", flags=regex.UNICODE)), + 'e\xe8\xe9\xea\xeb\u0113\u011b\u0117') + self.assertEqual("".join(regex.findall(r"e|\xe8|\xe9|\xea|\xeb|\u0113|\u011b|\u0117", + "e\xe8\xe9\xea\xeb\u0113\u011b\u0117", flags=regex.UNICODE)), + 'e\xe8\xe9\xea\xeb\u0113\u011b\u0117') + + def test_anyall(self): + self.assertEqual(regex.match("a.b", "a\nb", regex.DOTALL)[0], "a\nb") + self.assertEqual(regex.match("a.*b", "a\n\nb", regex.DOTALL)[0], + "a\n\nb") + + def test_non_consuming(self): + self.assertEqual(regex.match(r"(a(?=\s[^a]))", "a b")[1], 'a') + self.assertEqual(regex.match(r"(a(?=\s[^a]*))", "a b")[1], 'a') + self.assertEqual(regex.match(r"(a(?=\s[abc]))", "a b")[1], 'a') + self.assertEqual(regex.match(r"(a(?=\s[abc]*))", "a bc")[1], 'a') + self.assertEqual(regex.match(r"(a)(?=\s\1)", "a a")[1], 'a') + self.assertEqual(regex.match(r"(a)(?=\s\1*)", "a aa")[1], 'a') + self.assertEqual(regex.match(r"(a)(?=\s(abc|a))", "a a")[1], 'a') + + self.assertEqual(regex.match(r"(a(?!\s[^a]))", "a a")[1], 'a') + self.assertEqual(regex.match(r"(a(?!\s[abc]))", "a d")[1], 'a') + self.assertEqual(regex.match(r"(a)(?!\s\1)", "a b")[1], 'a') + self.assertEqual(regex.match(r"(a)(?!\s(abc|a))", "a b")[1], 'a') + + def test_ignore_case(self): + self.assertEqual(regex.match("abc", "ABC", regex.I)[0], 'ABC') + self.assertEqual(regex.match(b"abc", b"ABC", regex.I)[0], b'ABC') + + self.assertEqual(regex.match(r"(a\s[^a]*)", "a bb", regex.I)[1], + 'a bb') + self.assertEqual(regex.match(r"(a\s[abc])", "a b", regex.I)[1], 'a b') + self.assertEqual(regex.match(r"(a\s[abc]*)", "a bb", regex.I)[1], + 'a bb') + self.assertEqual(regex.match(r"((a)\s\2)", "a a", regex.I)[1], 'a a') + self.assertEqual(regex.match(r"((a)\s\2*)", "a aa", regex.I)[1], + 'a aa') + self.assertEqual(regex.match(r"((a)\s(abc|a))", "a a", regex.I)[1], + 'a a') + self.assertEqual(regex.match(r"((a)\s(abc|a)*)", "a aa", regex.I)[1], + 'a aa') + + # Issue 3511. 
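+        # (Editor's note, illustrative: the Issue 3511 checks below use the
+        # range [Z-a], which also spans the ASCII punctuation sitting between
+        # the upper- and lower-case letters, so '^' and '_' fall inside it:)
+        assert regex.match(r'[Z-a]', '^')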
+ self.assertEqual(regex.match(r"[Z-a]", "_").span(), (0, 1)) + self.assertEqual(regex.match(r"(?i)[Z-a]", "_").span(), (0, 1)) + + self.assertEqual(bool(regex.match(r"(?i)nao", "nAo")), True) + self.assertEqual(bool(regex.match(r"(?i)n\xE3o", "n\xC3o")), True) + self.assertEqual(bool(regex.match(r"(?i)n\xE3o", "N\xC3O")), True) + self.assertEqual(bool(regex.match(r"(?i)s", "\u017F")), True) + + def test_case_folding(self): + self.assertEqual(regex.search(r"(?fi)ss", "SS").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)SS", "ss").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)SS", + "\N{LATIN SMALL LETTER SHARP S}").span(), (0, 1)) + self.assertEqual(regex.search(r"(?fi)\N{LATIN SMALL LETTER SHARP S}", + "SS").span(), (0, 2)) + + self.assertEqual(regex.search(r"(?fi)\N{LATIN SMALL LIGATURE ST}", + "ST").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)ST", + "\N{LATIN SMALL LIGATURE ST}").span(), (0, 1)) + self.assertEqual(regex.search(r"(?fi)ST", + "\N{LATIN SMALL LIGATURE LONG S T}").span(), (0, 1)) + + self.assertEqual(regex.search(r"(?fi)SST", + "\N{LATIN SMALL LETTER SHARP S}t").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)SST", + "s\N{LATIN SMALL LIGATURE LONG S T}").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)SST", + "s\N{LATIN SMALL LIGATURE ST}").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)\N{LATIN SMALL LIGATURE ST}", + "SST").span(), (1, 3)) + self.assertEqual(regex.search(r"(?fi)SST", + "s\N{LATIN SMALL LIGATURE ST}").span(), (0, 2)) + + self.assertEqual(regex.search(r"(?fi)FFI", + "\N{LATIN SMALL LIGATURE FFI}").span(), (0, 1)) + self.assertEqual(regex.search(r"(?fi)FFI", + "\N{LATIN SMALL LIGATURE FF}i").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)FFI", + "f\N{LATIN SMALL LIGATURE FI}").span(), (0, 2)) + self.assertEqual(regex.search(r"(?fi)\N{LATIN SMALL LIGATURE FFI}", + "FFI").span(), (0, 3)) + self.assertEqual(regex.search(r"(?fi)\N{LATIN SMALL LIGATURE FF}i", + "FFI").span(), (0, 3)) + self.assertEqual(regex.search(r"(?fi)f\N{LATIN SMALL LIGATURE FI}", + "FFI").span(), (0, 3)) + + sigma = "\u03A3\u03C3\u03C2" + for ch1 in sigma: + for ch2 in sigma: + if not regex.match(r"(?fi)" + ch1, ch2): + self.fail() + + self.assertEqual(bool(regex.search(r"(?iV1)ff", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)ff", "\uFB01\uFB00")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)fi", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)fi", "\uFB01\uFB00")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)fffi", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)f\uFB03", + "\uFB00\uFB01")), True) + self.assertEqual(bool(regex.search(r"(?iV1)ff", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)fi", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)fffi", "\uFB00\uFB01")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)f\uFB03", + "\uFB00\uFB01")), True) + self.assertEqual(bool(regex.search(r"(?iV1)f\uFB01", "\uFB00i")), + True) + self.assertEqual(bool(regex.search(r"(?iV1)f\uFB01", "\uFB00i")), + True) + + self.assertEqual(regex.findall(r"(?iV0)\m(?:word){e<=3}\M(?ne", "affine", + options=["\N{LATIN SMALL LIGATURE FFI}"]).span(), (0, 6)) + self.assertEqual(regex.search(r"(?fi)a\Lne", + "a\N{LATIN SMALL LIGATURE FFI}ne", options=["ffi"]).span(), (0, 4)) + + def test_category(self): + self.assertEqual(regex.match(r"(\s)", " ")[1], ' ') + + def test_not_literal(self): + 
self.assertEqual(regex.search(r"\s([^a])", " b")[1], 'b') + self.assertEqual(regex.search(r"\s([^a]*)", " bb")[1], 'bb') + + def test_search_coverage(self): + self.assertEqual(regex.search(r"\s(b)", " b")[1], 'b') + self.assertEqual(regex.search(r"a\s", "a ")[0], 'a ') + + def test_re_escape(self): + p = "" + self.assertEqual(regex.escape(p), p) + for i in range(0, 256): + p += chr(i) + self.assertEqual(bool(regex.match(regex.escape(chr(i)), chr(i))), + True) + self.assertEqual(regex.match(regex.escape(chr(i)), chr(i)).span(), + (0, 1)) + + pat = regex.compile(regex.escape(p)) + self.assertEqual(pat.match(p).span(), (0, 256)) + + def test_re_escape_byte(self): + p = b"" + self.assertEqual(regex.escape(p), p) + for i in range(0, 256): + b = bytes([i]) + p += b + self.assertEqual(bool(regex.match(regex.escape(b), b)), True) + self.assertEqual(regex.match(regex.escape(b), b).span(), (0, 1)) + + pat = regex.compile(regex.escape(p)) + self.assertEqual(pat.match(p).span(), (0, 256)) + + def test_constants(self): + if regex.I != regex.IGNORECASE: + self.fail() + if regex.L != regex.LOCALE: + self.fail() + if regex.M != regex.MULTILINE: + self.fail() + if regex.S != regex.DOTALL: + self.fail() + if regex.X != regex.VERBOSE: + self.fail() + + def test_flags(self): + for flag in [regex.I, regex.M, regex.X, regex.S, regex.L]: + self.assertEqual(repr(type(regex.compile('^pattern$', flag))), + self.PATTERN_CLASS) + + def test_sre_character_literals(self): + for i in [0, 8, 16, 32, 64, 127, 128, 255]: + self.assertEqual(bool(regex.match(r"\%03o" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"\%03o0" % i, chr(i) + "0")), + True) + self.assertEqual(bool(regex.match(r"\%03o8" % i, chr(i) + "8")), + True) + self.assertEqual(bool(regex.match(r"\x%02x" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"\x%02x0" % i, chr(i) + "0")), + True) + self.assertEqual(bool(regex.match(r"\x%02xz" % i, chr(i) + "z")), + True) + + self.assertRaisesRegex(regex.error, self.INVALID_GROUP_REF, lambda: + regex.match(r"\911", "")) + + def test_sre_character_class_literals(self): + for i in [0, 8, 16, 32, 64, 127, 128, 255]: + self.assertEqual(bool(regex.match(r"[\%03o]" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"[\%03o0]" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"[\%03o8]" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"[\x%02x]" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"[\x%02x0]" % i, chr(i))), True) + self.assertEqual(bool(regex.match(r"[\x%02xz]" % i, chr(i))), True) + + self.assertRaisesRegex(regex.error, self.BAD_OCTAL_ESCAPE, lambda: + regex.match(r"[\911]", "")) + + def test_bug_113254(self): + self.assertEqual(regex.match(r'(a)|(b)', 'b').start(1), -1) + self.assertEqual(regex.match(r'(a)|(b)', 'b').end(1), -1) + self.assertEqual(regex.match(r'(a)|(b)', 'b').span(1), (-1, -1)) + + def test_bug_527371(self): + # Bug described in patches 527371/672491. + self.assertEqual(regex.match(r'(a)?a','a').lastindex, None) + self.assertEqual(regex.match(r'(a)(b)?b','ab').lastindex, 1) + self.assertEqual(regex.match(r'(?Pa)(?Pb)?b','ab').lastgroup, + 'a') + self.assertEqual(regex.match("(?Pa(b))", "ab").lastgroup, 'a') + self.assertEqual(regex.match("((a))", "a").lastindex, 1) + + def test_bug_545855(self): + # Bug 545855 -- This pattern failed to cause a compile error as it + # should, instead provoking a TypeError. 
+ self.assertRaisesRegex(regex.error, self.BAD_SET, lambda: + regex.compile('foo[a-')) + + def test_bug_418626(self): + # Bugs 418626 at al. -- Testing Greg Chapman's addition of op code + # SRE_OP_MIN_REPEAT_ONE for eliminating recursion on simple uses of + # pattern '*?' on a long string. + self.assertEqual(regex.match('.*?c', 10000 * 'ab' + 'cd').end(0), + 20001) + self.assertEqual(regex.match('.*?cd', 5000 * 'ab' + 'c' + 5000 * 'ab' + + 'cde').end(0), 20003) + self.assertEqual(regex.match('.*?cd', 20000 * 'abc' + 'de').end(0), + 60001) + # Non-simple '*?' still used to hit the recursion limit, before the + # non-recursive scheme was implemented. + self.assertEqual(regex.search('(a|b)*?c', 10000 * 'ab' + 'cd').end(0), + 20001) + + def test_bug_612074(self): + pat = "[" + regex.escape("\u2039") + "]" + self.assertEqual(regex.compile(pat) and 1, 1) + + def test_stack_overflow(self): + # Nasty cases that used to overflow the straightforward recursive + # implementation of repeated groups. + self.assertEqual(regex.match('(x)*', 50000 * 'x')[1], 'x') + self.assertEqual(regex.match('(x)*y', 50000 * 'x' + 'y')[1], 'x') + self.assertEqual(regex.match('(x)*?y', 50000 * 'x' + 'y')[1], 'x') + + def test_scanner(self): + def s_ident(scanner, token): return token + def s_operator(scanner, token): return "op%s" % token + def s_float(scanner, token): return float(token) + def s_int(scanner, token): return int(token) + + scanner = regex.Scanner([(r"[a-zA-Z_]\w*", s_ident), (r"\d+\.\d*", + s_float), (r"\d+", s_int), (r"=|\+|-|\*|/", s_operator), (r"\s+", + None), ]) + + self.assertEqual(repr(type(scanner.scanner.scanner("").pattern)), + self.PATTERN_CLASS) + + self.assertEqual(scanner.scan("sum = 3*foo + 312.50 + bar"), (['sum', + 'op=', 3, 'op*', 'foo', 'op+', 312.5, 'op+', 'bar'], '')) + + def test_bug_448951(self): + # Bug 448951 (similar to 429357, but with single char match). + # (Also test greedy matches.) + for op in '', '?', '*': + self.assertEqual(regex.match(r'((.%s):)?z' % op, 'z')[:], ('z', + None, None)) + self.assertEqual(regex.match(r'((.%s):)?z' % op, 'a:z')[:], ('a:z', + 'a:', 'a')) + + def test_bug_725106(self): + # Capturing groups in alternatives in repeats. + self.assertEqual(regex.match('^((a)|b)*', 'abc')[:], ('ab', 'b', 'a')) + self.assertEqual(regex.match('^(([ab])|c)*', 'abc')[:], ('abc', 'c', + 'b')) + self.assertEqual(regex.match('^((d)|[ab])*', 'abc')[:], ('ab', 'b', + None)) + self.assertEqual(regex.match('^((a)c|[ab])*', 'abc')[:], ('ab', 'b', + None)) + self.assertEqual(regex.match('^((a)|b)*?c', 'abc')[:], ('abc', 'b', + 'a')) + self.assertEqual(regex.match('^(([ab])|c)*?d', 'abcd')[:], ('abcd', + 'c', 'b')) + self.assertEqual(regex.match('^((d)|[ab])*?c', 'abc')[:], ('abc', 'b', + None)) + self.assertEqual(regex.match('^((a)c|[ab])*?c', 'abc')[:], ('abc', 'b', + None)) + + def test_bug_725149(self): + # Mark_stack_base restoring before restoring marks. + self.assertEqual(regex.match('(a)(?:(?=(b)*)c)*', 'abb')[:], ('a', 'a', + None)) + self.assertEqual(regex.match('(a)((?!(b)*))*', 'abb')[:], ('a', 'a', + None, None)) + + def test_bug_764548(self): + # Bug 764548, regex.compile() barfs on str/unicode subclasses. 
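+        # (Editor's note, an illustrative sketch of the Scanner API exercised
+        # above -- each (pattern, callback) rule consumes input in turn, and
+        # scan() returns the tokens plus any unmatched remainder; kept as a
+        # comment so as not to disturb the bug test below:
+        #     scanner = regex.Scanner([(r'\d+', lambda s, t: int(t)),
+        #                              (r'\s+', None)])
+        #     scanner.scan('1 2')   # -> ([1, 2], '')
+        # )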
+ class my_unicode(str): pass + pat = regex.compile(my_unicode("abc")) + self.assertEqual(pat.match("xyz"), None) + + def test_finditer(self): + it = regex.finditer(r":+", "a:b::c:::d") + self.assertEqual([item[0] for item in it], [':', '::', ':::']) + + def test_bug_926075(self): + if regex.compile('bug_926075') is regex.compile(b'bug_926075'): + self.fail() + + def test_bug_931848(self): + pattern = "[\u002E\u3002\uFF0E\uFF61]" + self.assertEqual(regex.compile(pattern).split("a.b.c"), ['a', 'b', + 'c']) + + def test_bug_581080(self): + it = regex.finditer(r"\s", "a b") + self.assertEqual(next(it).span(), (1, 2)) + self.assertRaises(StopIteration, lambda: next(it)) + + scanner = regex.compile(r"\s").scanner("a b") + self.assertEqual(scanner.search().span(), (1, 2)) + self.assertEqual(scanner.search(), None) + + def test_bug_817234(self): + it = regex.finditer(r".*", "asdf") + self.assertEqual(next(it).span(), (0, 4)) + self.assertEqual(next(it).span(), (4, 4)) + self.assertRaises(StopIteration, lambda: next(it)) + + def test_empty_array(self): + # SF buf 1647541. + import array + for typecode in 'bBhHiIlLfd': + a = array.array(typecode) + self.assertEqual(regex.compile(b"bla").match(a), None) + self.assertEqual(regex.compile(b"").match(a)[1 : ], ()) + + def test_inline_flags(self): + # Bug #1700. + upper_char = chr(0x1ea0) # Latin Capital Letter A with Dot Below + lower_char = chr(0x1ea1) # Latin Small Letter A with Dot Below + + p = regex.compile(upper_char, regex.I | regex.U) + self.assertEqual(bool(p.match(lower_char)), True) + + p = regex.compile(lower_char, regex.I | regex.U) + self.assertEqual(bool(p.match(upper_char)), True) + + p = regex.compile('(?i)' + upper_char, regex.U) + self.assertEqual(bool(p.match(lower_char)), True) + + p = regex.compile('(?i)' + lower_char, regex.U) + self.assertEqual(bool(p.match(upper_char)), True) + + p = regex.compile('(?iu)' + upper_char) + self.assertEqual(bool(p.match(lower_char)), True) + + p = regex.compile('(?iu)' + lower_char) + self.assertEqual(bool(p.match(upper_char)), True) + + # Changed to positional flags in regex 2023.12.23. + self.assertEqual(bool(regex.match(r"(?i)a", "A")), True) + self.assertEqual(regex.match(r"a(?i)", "A"), None) + + def test_dollar_matches_twice(self): + # $ matches the end of string, and just before the terminating \n. + pattern = regex.compile('$') + self.assertEqual(pattern.sub('#', 'a\nb\n'), 'a\nb#\n#') + self.assertEqual(pattern.sub('#', 'a\nb\nc'), 'a\nb\nc#') + self.assertEqual(pattern.sub('#', '\n'), '#\n#') + + pattern = regex.compile('$', regex.MULTILINE) + self.assertEqual(pattern.sub('#', 'a\nb\n' ), 'a#\nb#\n#') + self.assertEqual(pattern.sub('#', 'a\nb\nc'), 'a#\nb#\nc#') + self.assertEqual(pattern.sub('#', '\n'), '#\n#') + + def test_bytes_str_mixing(self): + # Mixing str and bytes is disallowed. 
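+        # (Editor's note, illustrative: as the checks below verify, str and
+        # bytes never mix -- a str pattern refuses bytes input and vice
+        # versa, e.g. regex.compile('.').match(b'b') raises TypeError, and a
+        # bytes pattern cannot take the UNICODE flag.)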
+ pat = regex.compile('.') + bpat = regex.compile(b'.') + self.assertRaisesRegex(TypeError, self.STR_PAT_ON_BYTES, lambda: + pat.match(b'b')) + self.assertRaisesRegex(TypeError, self.BYTES_PAT_ON_STR, lambda: + bpat.match('b')) + self.assertRaisesRegex(TypeError, self.STR_PAT_BYTES_TEMPL, lambda: + pat.sub(b'b', 'c')) + self.assertRaisesRegex(TypeError, self.STR_PAT_ON_BYTES, lambda: + pat.sub('b', b'c')) + self.assertRaisesRegex(TypeError, self.STR_PAT_ON_BYTES, lambda: + pat.sub(b'b', b'c')) + self.assertRaisesRegex(TypeError, self.BYTES_PAT_ON_STR, lambda: + bpat.sub(b'b', 'c')) + self.assertRaisesRegex(TypeError, self.BYTES_PAT_STR_TEMPL, lambda: + bpat.sub('b', b'c')) + self.assertRaisesRegex(TypeError, self.BYTES_PAT_ON_STR, lambda: + bpat.sub('b', 'c')) + + self.assertRaisesRegex(ValueError, self.BYTES_PAT_UNI_FLAG, lambda: + regex.compile(br'\w', regex.UNICODE)) + self.assertRaisesRegex(ValueError, self.BYTES_PAT_UNI_FLAG, lambda: + regex.compile(br'(?u)\w')) + self.assertRaisesRegex(ValueError, self.MIXED_FLAGS, lambda: + regex.compile(r'\w', regex.UNICODE | regex.ASCII)) + self.assertRaisesRegex(ValueError, self.MIXED_FLAGS, lambda: + regex.compile(r'(?u)\w', regex.ASCII)) + self.assertRaisesRegex(ValueError, self.MIXED_FLAGS, lambda: + regex.compile(r'(?a)\w', regex.UNICODE)) + self.assertRaisesRegex(ValueError, self.MIXED_FLAGS, lambda: + regex.compile(r'(?au)\w')) + + def test_ascii_and_unicode_flag(self): + # String patterns. + for flags in (0, regex.UNICODE): + pat = regex.compile('\xc0', flags | regex.IGNORECASE) + self.assertEqual(bool(pat.match('\xe0')), True) + pat = regex.compile(r'\w', flags) + self.assertEqual(bool(pat.match('\xe0')), True) + + pat = regex.compile('\xc0', regex.ASCII | regex.IGNORECASE) + self.assertEqual(pat.match('\xe0'), None) + pat = regex.compile('(?a)\xc0', regex.IGNORECASE) + self.assertEqual(pat.match('\xe0'), None) + pat = regex.compile(r'\w', regex.ASCII) + self.assertEqual(pat.match('\xe0'), None) + pat = regex.compile(r'(?a)\w') + self.assertEqual(pat.match('\xe0'), None) + + # Bytes patterns. + for flags in (0, regex.ASCII): + pat = regex.compile(b'\xc0', flags | regex.IGNORECASE) + self.assertEqual(pat.match(b'\xe0'), None) + pat = regex.compile(br'\w') + self.assertEqual(pat.match(b'\xe0'), None) + + self.assertRaisesRegex(ValueError, self.MIXED_FLAGS, lambda: + regex.compile(r'(?au)\w')) + + def test_subscripting_match(self): + m = regex.match(r'(?\w)', 'xy') + if not m: + self.fail("Failed: expected match but returned None") + elif not m or m[0] != m.group(0) or m[1] != m.group(1): + self.fail("Failed") + if not m: + self.fail("Failed: expected match but returned None") + elif m[:] != ('x', 'x'): + self.fail("Failed: expected \"('x', 'x')\" but got {} instead".format(ascii(m[:]))) + + def test_new_named_groups(self): + m0 = regex.match(r'(?P\w)', 'x') + m1 = regex.match(r'(?\w)', 'x') + if not (m0 and m1 and m0[:] == m1[:]): + self.fail("Failed") + + def test_properties(self): + self.assertEqual(regex.match(b'(?ai)\xC0', b'\xE0'), None) + self.assertEqual(regex.match(br'(?ai)\xC0', b'\xE0'), None) + self.assertEqual(regex.match(br'(?a)\w', b'\xE0'), None) + self.assertEqual(bool(regex.match(r'\w', '\xE0')), True) + + # Dropped the following test. It's not possible to determine what the + # correct result should be in the general case. 
+# self.assertEqual(bool(regex.match(br'(?L)\w', b'\xE0')), +# b'\xE0'.isalnum()) + + self.assertEqual(bool(regex.match(br'(?L)\d', b'0')), True) + self.assertEqual(bool(regex.match(br'(?L)\s', b' ')), True) + self.assertEqual(bool(regex.match(br'(?L)\w', b'a')), True) + self.assertEqual(regex.match(br'(?L)\d', b'?'), None) + self.assertEqual(regex.match(br'(?L)\s', b'?'), None) + self.assertEqual(regex.match(br'(?L)\w', b'?'), None) + + self.assertEqual(regex.match(br'(?L)\D', b'0'), None) + self.assertEqual(regex.match(br'(?L)\S', b' '), None) + self.assertEqual(regex.match(br'(?L)\W', b'a'), None) + self.assertEqual(bool(regex.match(br'(?L)\D', b'?')), True) + self.assertEqual(bool(regex.match(br'(?L)\S', b'?')), True) + self.assertEqual(bool(regex.match(br'(?L)\W', b'?')), True) + + self.assertEqual(bool(regex.match(r'\p{Cyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'(?i)\p{Cyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{IsCyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{Script=Cyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{InCyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{Block=Cyrillic}', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:Cyrillic:]]', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:IsCyrillic:]]', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:Script=Cyrillic:]]', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:InCyrillic:]]', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:Block=Cyrillic:]]', + '\N{CYRILLIC CAPITAL LETTER A}')), True) + + self.assertEqual(bool(regex.match(r'\P{Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\P{IsCyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\P{Script=Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\P{InCyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\P{Block=Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{^Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{^IsCyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{^Script=Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{^InCyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'\p{^Block=Cyrillic}', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:^Cyrillic:]]', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:^IsCyrillic:]]', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:^Script=Cyrillic:]]', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:^InCyrillic:]]', + '\N{LATIN CAPITAL LETTER A}')), True) + self.assertEqual(bool(regex.match(r'[[:^Block=Cyrillic:]]', + '\N{LATIN CAPITAL LETTER A}')), True) + + self.assertEqual(bool(regex.match(r'\d', '0')), True) + self.assertEqual(bool(regex.match(r'\s', ' ')), True) + self.assertEqual(bool(regex.match(r'\w', 'A')), True) + 
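+        # Editor's note: an illustrative aside, not from the upstream suite.
+        # \p{...} accepts general categories, scripts and blocks, and
+        # \P{...} (or \p{^...}) matches the complement:
+        assert regex.match(r'\p{Lu}', 'A')
+        assert regex.match(r'\p{Script=Greek}', '\N{GREEK SMALL LETTER ALPHA}')
+        assert not regex.match(r'\P{Lu}', 'A')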
self.assertEqual(regex.match(r"\d", "?"), None) + self.assertEqual(regex.match(r"\s", "?"), None) + self.assertEqual(regex.match(r"\w", "?"), None) + self.assertEqual(regex.match(r"\D", "0"), None) + self.assertEqual(regex.match(r"\S", " "), None) + self.assertEqual(regex.match(r"\W", "A"), None) + self.assertEqual(bool(regex.match(r'\D', '?')), True) + self.assertEqual(bool(regex.match(r'\S', '?')), True) + self.assertEqual(bool(regex.match(r'\W', '?')), True) + + self.assertEqual(bool(regex.match(r'\p{L}', 'A')), True) + self.assertEqual(bool(regex.match(r'\p{L}', 'a')), True) + self.assertEqual(bool(regex.match(r'\p{Lu}', 'A')), True) + self.assertEqual(bool(regex.match(r'\p{Ll}', 'a')), True) + + self.assertEqual(bool(regex.match(r'(?i)a', 'a')), True) + self.assertEqual(bool(regex.match(r'(?i)a', 'A')), True) + + self.assertEqual(bool(regex.match(r'\w', '0')), True) + self.assertEqual(bool(regex.match(r'\w', 'a')), True) + self.assertEqual(bool(regex.match(r'\w', '_')), True) + + self.assertEqual(regex.match(r"\X", "\xE0").span(), (0, 1)) + self.assertEqual(regex.match(r"\X", "a\u0300").span(), (0, 2)) + self.assertEqual(regex.findall(r"\X", + "a\xE0a\u0300e\xE9e\u0301"), ['a', '\xe0', 'a\u0300', 'e', + '\xe9', 'e\u0301']) + self.assertEqual(regex.findall(r"\X{3}", + "a\xE0a\u0300e\xE9e\u0301"), ['a\xe0a\u0300', 'e\xe9e\u0301']) + self.assertEqual(regex.findall(r"\X", "\r\r\n\u0301A\u0301"), + ['\r', '\r\n', '\u0301', 'A\u0301']) + + self.assertEqual(bool(regex.match(r'\p{Ll}', 'a')), True) + + chars_u = "-09AZaz_\u0393\u03b3" + chars_b = b"-09AZaz_" + word_set = set("Ll Lm Lo Lt Lu Mc Me Mn Nd Nl No Pc".split()) + + tests = [ + (r"\w", chars_u, "09AZaz_\u0393\u03b3"), + (r"[[:word:]]", chars_u, "09AZaz_\u0393\u03b3"), + (r"\W", chars_u, "-"), + (r"[[:^word:]]", chars_u, "-"), + (r"\d", chars_u, "09"), + (r"[[:digit:]]", chars_u, "09"), + (r"\D", chars_u, "-AZaz_\u0393\u03b3"), + (r"[[:^digit:]]", chars_u, "-AZaz_\u0393\u03b3"), + (r"[[:alpha:]]", chars_u, "AZaz\u0393\u03b3"), + (r"[[:^alpha:]]", chars_u, "-09_"), + (r"[[:alnum:]]", chars_u, "09AZaz\u0393\u03b3"), + (r"[[:^alnum:]]", chars_u, "-_"), + (r"[[:xdigit:]]", chars_u, "09Aa"), + (r"[[:^xdigit:]]", chars_u, "-Zz_\u0393\u03b3"), + (r"\p{InBasicLatin}", "a\xE1", "a"), + (r"\P{InBasicLatin}", "a\xE1", "\xE1"), + (r"(?i)\p{InBasicLatin}", "a\xE1", "a"), + (r"(?i)\P{InBasicLatin}", "a\xE1", "\xE1"), + + (br"(?L)\w", chars_b, b"09AZaz_"), + (br"(?L)[[:word:]]", chars_b, b"09AZaz_"), + (br"(?L)\W", chars_b, b"-"), + (br"(?L)[[:^word:]]", chars_b, b"-"), + (br"(?L)\d", chars_b, b"09"), + (br"(?L)[[:digit:]]", chars_b, b"09"), + (br"(?L)\D", chars_b, b"-AZaz_"), + (br"(?L)[[:^digit:]]", chars_b, b"-AZaz_"), + (br"(?L)[[:alpha:]]", chars_b, b"AZaz"), + (br"(?L)[[:^alpha:]]", chars_b, b"-09_"), + (br"(?L)[[:alnum:]]", chars_b, b"09AZaz"), + (br"(?L)[[:^alnum:]]", chars_b, b"-_"), + (br"(?L)[[:xdigit:]]", chars_b, b"09Aa"), + (br"(?L)[[:^xdigit:]]", chars_b, b"-Zz_"), + + (br"(?a)\w", chars_b, b"09AZaz_"), + (br"(?a)[[:word:]]", chars_b, b"09AZaz_"), + (br"(?a)\W", chars_b, b"-"), + (br"(?a)[[:^word:]]", chars_b, b"-"), + (br"(?a)\d", chars_b, b"09"), + (br"(?a)[[:digit:]]", chars_b, b"09"), + (br"(?a)\D", chars_b, b"-AZaz_"), + (br"(?a)[[:^digit:]]", chars_b, b"-AZaz_"), + (br"(?a)[[:alpha:]]", chars_b, b"AZaz"), + (br"(?a)[[:^alpha:]]", chars_b, b"-09_"), + (br"(?a)[[:alnum:]]", chars_b, b"09AZaz"), + (br"(?a)[[:^alnum:]]", chars_b, b"-_"), + (br"(?a)[[:xdigit:]]", chars_b, b"09Aa"), + (br"(?a)[[:^xdigit:]]", chars_b, b"-Zz_"), + ] + 
for pattern, chars, expected in tests: + try: + if chars[ : 0].join(regex.findall(pattern, chars)) != expected: + self.fail("Failed: {}".format(pattern)) + except Exception as e: + self.fail("Failed: {} raised {}".format(pattern, ascii(e))) + + self.assertEqual(bool(regex.match(r"\p{NumericValue=0}", "0")), + True) + self.assertEqual(bool(regex.match(r"\p{NumericValue=1/2}", + "\N{VULGAR FRACTION ONE HALF}")), True) + self.assertEqual(bool(regex.match(r"\p{NumericValue=0.5}", + "\N{VULGAR FRACTION ONE HALF}")), True) + + def test_word_class(self): + self.assertEqual(regex.findall(r"\w+", + " \u0939\u093f\u0928\u094d\u0926\u0940,"), + ['\u0939\u093f\u0928\u094d\u0926\u0940']) + self.assertEqual(regex.findall(r"\W+", + " \u0939\u093f\u0928\u094d\u0926\u0940,"), [' ', ',']) + self.assertEqual(regex.split(r"(?V1)\b", + " \u0939\u093f\u0928\u094d\u0926\u0940,"), [' ', + '\u0939\u093f\u0928\u094d\u0926\u0940', ',']) + self.assertEqual(regex.split(r"(?V1)\B", + " \u0939\u093f\u0928\u094d\u0926\u0940,"), ['', ' \u0939', + '\u093f', '\u0928', '\u094d', '\u0926', '\u0940,', '']) + + def test_search_anchor(self): + self.assertEqual(regex.findall(r"\G\w{2}", "abcd ef"), ['ab', 'cd']) + + def test_search_reverse(self): + self.assertEqual(regex.findall(r"(?r).", "abc"), ['c', 'b', 'a']) + self.assertEqual(regex.findall(r"(?r).", "abc", overlapped=True), ['c', + 'b', 'a']) + self.assertEqual(regex.findall(r"(?r)..", "abcde"), ['de', 'bc']) + self.assertEqual(regex.findall(r"(?r)..", "abcde", overlapped=True), + ['de', 'cd', 'bc', 'ab']) + self.assertEqual(regex.findall(r"(?r)(.)(-)(.)", "a-b-c", + overlapped=True), [("b", "-", "c"), ("a", "-", "b")]) + + self.assertEqual([m[0] for m in regex.finditer(r"(?r).", "abc")], ['c', + 'b', 'a']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)..", "abcde", + overlapped=True)], ['de', 'cd', 'bc', 'ab']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r).", "abc")], ['c', + 'b', 'a']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)..", "abcde", + overlapped=True)], ['de', 'cd', 'bc', 'ab']) + + self.assertEqual(regex.findall(r"^|\w+", "foo bar"), ['', 'foo', + 'bar']) + self.assertEqual(regex.findall(r"(?V1)^|\w+", "foo bar"), ['', 'foo', + 'bar']) + self.assertEqual(regex.findall(r"(?r)^|\w+", "foo bar"), ['bar', 'foo', + '']) + self.assertEqual(regex.findall(r"(?rV1)^|\w+", "foo bar"), ['bar', + 'foo', '']) + + self.assertEqual([m[0] for m in regex.finditer(r"^|\w+", "foo bar")], + ['', 'foo', 'bar']) + self.assertEqual([m[0] for m in regex.finditer(r"(?V1)^|\w+", + "foo bar")], ['', 'foo', 'bar']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)^|\w+", + "foo bar")], ['bar', 'foo', '']) + self.assertEqual([m[0] for m in regex.finditer(r"(?rV1)^|\w+", + "foo bar")], ['bar', 'foo', '']) + + self.assertEqual(regex.findall(r"\G\w{2}", "abcd ef"), ['ab', 'cd']) + self.assertEqual(regex.findall(r".{2}(?<=\G.*)", "abcd"), ['ab', 'cd']) + self.assertEqual(regex.findall(r"(?r)\G\w{2}", "abcd ef"), []) + self.assertEqual(regex.findall(r"(?r)\w{2}\G", "abcd ef"), ['ef']) + + self.assertEqual(regex.findall(r"q*", "qqwe"), ['qq', '', '', '']) + self.assertEqual(regex.findall(r"(?V1)q*", "qqwe"), ['qq', '', '', '']) + self.assertEqual(regex.findall(r"(?r)q*", "qqwe"), ['', '', 'qq', '']) + self.assertEqual(regex.findall(r"(?rV1)q*", "qqwe"), ['', '', 'qq', + '']) + + self.assertEqual(regex.findall(".", "abcd", pos=1, endpos=3), ['b', + 'c']) + self.assertEqual(regex.findall(".", "abcd", pos=1, endpos=-1), ['b', + 'c']) + self.assertEqual([m[0] for 
m in regex.finditer(".", "abcd", pos=1, + endpos=3)], ['b', 'c']) + self.assertEqual([m[0] for m in regex.finditer(".", "abcd", pos=1, + endpos=-1)], ['b', 'c']) + + self.assertEqual([m[0] for m in regex.finditer("(?r).", "abcd", pos=1, + endpos=3)], ['c', 'b']) + self.assertEqual([m[0] for m in regex.finditer("(?r).", "abcd", pos=1, + endpos=-1)], ['c', 'b']) + self.assertEqual(regex.findall("(?r).", "abcd", pos=1, endpos=3), ['c', + 'b']) + self.assertEqual(regex.findall("(?r).", "abcd", pos=1, endpos=-1), + ['c', 'b']) + + self.assertEqual(regex.findall(r"[ab]", "aB", regex.I), ['a', 'B']) + self.assertEqual(regex.findall(r"(?r)[ab]", "aB", regex.I), ['B', 'a']) + + self.assertEqual(regex.findall(r"(?r).{2}", "abc"), ['bc']) + self.assertEqual(regex.findall(r"(?r).{2}", "abc", overlapped=True), + ['bc', 'ab']) + self.assertEqual(regex.findall(r"(\w+) (\w+)", + "first second third fourth fifth"), [('first', 'second'), ('third', + 'fourth')]) + self.assertEqual(regex.findall(r"(?r)(\w+) (\w+)", + "first second third fourth fifth"), [('fourth', 'fifth'), ('second', + 'third')]) + + self.assertEqual([m[0] for m in regex.finditer(r"(?r).{2}", "abc")], + ['bc']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r).{2}", "abc", + overlapped=True)], ['bc', 'ab']) + self.assertEqual([m[0] for m in regex.finditer(r"(\w+) (\w+)", + "first second third fourth fifth")], ['first second', + 'third fourth']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)(\w+) (\w+)", + "first second third fourth fifth")], ['fourth fifth', + 'second third']) + + self.assertEqual(regex.search("abcdef", "abcdef").span(), (0, 6)) + self.assertEqual(regex.search("(?r)abcdef", "abcdef").span(), (0, 6)) + self.assertEqual(regex.search("(?i)abcdef", "ABCDEF").span(), (0, 6)) + self.assertEqual(regex.search("(?ir)abcdef", "ABCDEF").span(), (0, 6)) + + self.assertEqual(regex.sub(r"(.)", r"\1", "abc"), 'abc') + self.assertEqual(regex.sub(r"(?r)(.)", r"\1", "abc"), 'abc') + + def test_atomic(self): + # Issue 433030. + self.assertEqual(regex.search(r"(?>a*)a", "aa"), None) + + def test_possessive(self): + # Single-character non-possessive. + self.assertEqual(regex.search(r"a?a", "a").span(), (0, 1)) + self.assertEqual(regex.search(r"a*a", "aaa").span(), (0, 3)) + self.assertEqual(regex.search(r"a+a", "aaa").span(), (0, 3)) + self.assertEqual(regex.search(r"a{1,3}a", "aaa").span(), (0, 3)) + + # Multiple-character non-possessive. + self.assertEqual(regex.search(r"(?:ab)?ab", "ab").span(), (0, 2)) + self.assertEqual(regex.search(r"(?:ab)*ab", "ababab").span(), (0, 6)) + self.assertEqual(regex.search(r"(?:ab)+ab", "ababab").span(), (0, 6)) + self.assertEqual(regex.search(r"(?:ab){1,3}ab", "ababab").span(), (0, + 6)) + + # Single-character possessive. + self.assertEqual(regex.search(r"a?+a", "a"), None) + self.assertEqual(regex.search(r"a*+a", "aaa"), None) + self.assertEqual(regex.search(r"a++a", "aaa"), None) + self.assertEqual(regex.search(r"a{1,3}+a", "aaa"), None) + + # Multiple-character possessive. + self.assertEqual(regex.search(r"(?:ab)?+ab", "ab"), None) + self.assertEqual(regex.search(r"(?:ab)*+ab", "ababab"), None) + self.assertEqual(regex.search(r"(?:ab)++ab", "ababab"), None) + self.assertEqual(regex.search(r"(?:ab){1,3}+ab", "ababab"), None) + + def test_zerowidth(self): + # Issue 3262. 
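+        # (Editor's note, illustrative: a zero-width match succeeds without
+        # consuming characters; since Python 3.7, and always under (?V1),
+        # split() honours such matches, as the checks below show:)
+        assert regex.split(r'(?V1)\b', 'ab') == ['', 'ab', '']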
+ if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split(r"\b", "a b"), ['', 'a', ' ', 'b', + '']) + else: + self.assertEqual(regex.split(r"\b", "a b"), ['a b']) + self.assertEqual(regex.split(r"(?V1)\b", "a b"), ['', 'a', ' ', 'b', + '']) + + # Issue 1647489. + self.assertEqual(regex.findall(r"^|\w+", "foo bar"), ['', 'foo', + 'bar']) + self.assertEqual([m[0] for m in regex.finditer(r"^|\w+", "foo bar")], + ['', 'foo', 'bar']) + self.assertEqual(regex.findall(r"(?r)^|\w+", "foo bar"), ['bar', + 'foo', '']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)^|\w+", + "foo bar")], ['bar', 'foo', '']) + self.assertEqual(regex.findall(r"(?V1)^|\w+", "foo bar"), ['', 'foo', + 'bar']) + self.assertEqual([m[0] for m in regex.finditer(r"(?V1)^|\w+", + "foo bar")], ['', 'foo', 'bar']) + self.assertEqual(regex.findall(r"(?rV1)^|\w+", "foo bar"), ['bar', + 'foo', '']) + self.assertEqual([m[0] for m in regex.finditer(r"(?rV1)^|\w+", + "foo bar")], ['bar', 'foo', '']) + + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split("", "xaxbxc"), ['', 'x', 'a', 'x', + 'b', 'x', 'c', '']) + self.assertEqual([m for m in regex.splititer("", "xaxbxc")], ['', + 'x', 'a', 'x', 'b', 'x', 'c', '']) + else: + self.assertEqual(regex.split("", "xaxbxc"), ['xaxbxc']) + self.assertEqual([m for m in regex.splititer("", "xaxbxc")], + ['xaxbxc']) + + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split("(?r)", "xaxbxc"), ['', 'c', 'x', 'b', + 'x', 'a', 'x', '']) + self.assertEqual([m for m in regex.splititer("(?r)", "xaxbxc")], + ['', 'c', 'x', 'b', 'x', 'a', 'x', '']) + else: + self.assertEqual(regex.split("(?r)", "xaxbxc"), ['xaxbxc']) + self.assertEqual([m for m in regex.splititer("(?r)", "xaxbxc")], + ['xaxbxc']) + + self.assertEqual(regex.split("(?V1)", "xaxbxc"), ['', 'x', 'a', 'x', + 'b', 'x', 'c', '']) + self.assertEqual([m for m in regex.splititer("(?V1)", "xaxbxc")], ['', + 'x', 'a', 'x', 'b', 'x', 'c', '']) + + self.assertEqual(regex.split("(?rV1)", "xaxbxc"), ['', 'c', 'x', 'b', + 'x', 'a', 'x', '']) + self.assertEqual([m for m in regex.splititer("(?rV1)", "xaxbxc")], ['', + 'c', 'x', 'b', 'x', 'a', 'x', '']) + + def test_scoped_and_inline_flags(self): + # Issues 433028, 433024, 433027. + self.assertEqual(regex.search(r"(?i)Ab", "ab").span(), (0, 2)) + self.assertEqual(regex.search(r"(?i:A)b", "ab").span(), (0, 2)) + # Changed to positional flags in regex 2023.12.23. + self.assertEqual(regex.search(r"A(?i)b", "ab"), None) + + self.assertEqual(regex.search(r"(?V0)Ab", "ab"), None) + self.assertEqual(regex.search(r"(?V1)Ab", "ab"), None) + self.assertEqual(regex.search(r"(?-i)Ab", "ab", flags=regex.I), None) + self.assertEqual(regex.search(r"(?-i:A)b", "ab", flags=regex.I), None) + self.assertEqual(regex.search(r"A(?-i)b", "ab", flags=regex.I).span(), + (0, 2)) + + def test_repeated_repeats(self): + # Issue 2537. + self.assertEqual(regex.search(r"(?:a+)+", "aaa").span(), (0, 3)) + self.assertEqual(regex.search(r"(?:(?:ab)+c)+", "abcabc").span(), (0, + 6)) + + # Hg issue 286. 
+ self.assertEqual(regex.search(r"(?:a+){2,}", "aaa").span(), (0, 3)) + + def test_lookbehind(self): + self.assertEqual(regex.search(r"123(?<=a\d+)", "a123").span(), (1, 4)) + self.assertEqual(regex.search(r"123(?<=a\d+)", "b123"), None) + self.assertEqual(regex.search(r"123(?= (3, 7, 0): + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "xy"), + 'y-x-') + else: + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "xy"), + 'y-x') + self.assertEqual(regex.sub(r"(?V1)(x)?(y)?", r"\2-\1", "xy"), 'y-x-') + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "x"), '-x-') + else: + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "x"), '-x') + self.assertEqual(regex.sub(r"(?V1)(x)?(y)?", r"\2-\1", "x"), '-x-') + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "y"), 'y--') + else: + self.assertEqual(regex.sub(r"(?V0)(x)?(y)?", r"\2-\1", "y"), 'y-') + self.assertEqual(regex.sub(r"(?V1)(x)?(y)?", r"\2-\1", "y"), 'y--') + + def test_bug_10328 (self): + # Issue 10328. + pat = regex.compile(r'(?mV0)(?P[ \t]+\r*$)|(?P(?<=[^\n])\Z)') + if sys.version_info >= (3, 7, 0): + self.assertEqual(pat.subn(lambda m: '<' + m.lastgroup + '>', + 'foobar '), ('foobar', 2)) + else: + self.assertEqual(pat.subn(lambda m: '<' + m.lastgroup + '>', + 'foobar '), ('foobar', 1)) + self.assertEqual([m.group() for m in pat.finditer('foobar ')], [' ', + '']) + pat = regex.compile(r'(?mV1)(?P[ \t]+\r*$)|(?P(?<=[^\n])\Z)') + self.assertEqual(pat.subn(lambda m: '<' + m.lastgroup + '>', + 'foobar '), ('foobar', 2)) + self.assertEqual([m.group() for m in pat.finditer('foobar ')], [' ', + '']) + + def test_overlapped(self): + self.assertEqual(regex.findall(r"..", "abcde"), ['ab', 'cd']) + self.assertEqual(regex.findall(r"..", "abcde", overlapped=True), ['ab', + 'bc', 'cd', 'de']) + self.assertEqual(regex.findall(r"(?r)..", "abcde"), ['de', 'bc']) + self.assertEqual(regex.findall(r"(?r)..", "abcde", overlapped=True), + ['de', 'cd', 'bc', 'ab']) + self.assertEqual(regex.findall(r"(.)(-)(.)", "a-b-c", overlapped=True), + [("a", "-", "b"), ("b", "-", "c")]) + + self.assertEqual([m[0] for m in regex.finditer(r"..", "abcde")], ['ab', + 'cd']) + self.assertEqual([m[0] for m in regex.finditer(r"..", "abcde", + overlapped=True)], ['ab', 'bc', 'cd', 'de']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)..", "abcde")], + ['de', 'bc']) + self.assertEqual([m[0] for m in regex.finditer(r"(?r)..", "abcde", + overlapped=True)], ['de', 'cd', 'bc', 'ab']) + + self.assertEqual([m.groups() for m in regex.finditer(r"(.)(-)(.)", + "a-b-c", overlapped=True)], [("a", "-", "b"), ("b", "-", "c")]) + self.assertEqual([m.groups() for m in regex.finditer(r"(?r)(.)(-)(.)", + "a-b-c", overlapped=True)], [("b", "-", "c"), ("a", "-", "b")]) + + def test_splititer(self): + self.assertEqual(regex.split(r",", "a,b,,c,"), ['a', 'b', '', 'c', '']) + self.assertEqual([m for m in regex.splititer(r",", "a,b,,c,")], ['a', + 'b', '', 'c', '']) + + def test_grapheme(self): + self.assertEqual(regex.match(r"\X", "\xE0").span(), (0, 1)) + self.assertEqual(regex.match(r"\X", "a\u0300").span(), (0, 2)) + + self.assertEqual(regex.findall(r"\X", + "a\xE0a\u0300e\xE9e\u0301"), ['a', '\xe0', 'a\u0300', 'e', + '\xe9', 'e\u0301']) + self.assertEqual(regex.findall(r"\X{3}", + "a\xE0a\u0300e\xE9e\u0301"), ['a\xe0a\u0300', 'e\xe9e\u0301']) + self.assertEqual(regex.findall(r"\X", "\r\r\n\u0301A\u0301"), + ['\r', '\r\n', '\u0301', 'A\u0301']) + + def test_word_boundary(self): + text = 
'The quick ("brown") fox can\'t jump 32.3 feet, right?' + self.assertEqual(regex.split(r'(?V1)\b', text), ['', 'The', ' ', + 'quick', ' ("', 'brown', '") ', 'fox', ' ', 'can', "'", 't', + ' ', 'jump', ' ', '32', '.', '3', ' ', 'feet', ', ', + 'right', '?']) + self.assertEqual(regex.split(r'(?V1w)\b', text), ['', 'The', ' ', + 'quick', ' ', '(', '"', 'brown', '"', ')', ' ', 'fox', ' ', + "can't", ' ', 'jump', ' ', '32.3', ' ', 'feet', ',', ' ', + 'right', '?', '']) + + text = "The fox" + self.assertEqual(regex.split(r'(?V1)\b', text), ['', 'The', ' ', + 'fox', '']) + self.assertEqual(regex.split(r'(?V1w)\b', text), ['', 'The', ' ', + 'fox', '']) + + text = "can't aujourd'hui l'objectif" + self.assertEqual(regex.split(r'(?V1)\b', text), ['', 'can', "'", + 't', ' ', 'aujourd', "'", 'hui', ' ', 'l', "'", 'objectif', + '']) + self.assertEqual(regex.split(r'(?V1w)\b', text), ['', "can't", ' ', + "aujourd'hui", ' ', "l'objectif", '']) + + def test_line_boundary(self): + self.assertEqual(regex.findall(r".+", "Line 1\nLine 2\n"), ["Line 1", + "Line 2"]) + self.assertEqual(regex.findall(r".+", "Line 1\rLine 2\r"), + ["Line 1\rLine 2\r"]) + self.assertEqual(regex.findall(r".+", "Line 1\r\nLine 2\r\n"), + ["Line 1\r", "Line 2\r"]) + self.assertEqual(regex.findall(r"(?w).+", "Line 1\nLine 2\n"), + ["Line 1", "Line 2"]) + self.assertEqual(regex.findall(r"(?w).+", "Line 1\rLine 2\r"), + ["Line 1", "Line 2"]) + self.assertEqual(regex.findall(r"(?w).+", "Line 1\r\nLine 2\r\n"), + ["Line 1", "Line 2"]) + + self.assertEqual(regex.search(r"^abc", "abc").start(), 0) + self.assertEqual(regex.search(r"^abc", "\nabc"), None) + self.assertEqual(regex.search(r"^abc", "\rabc"), None) + self.assertEqual(regex.search(r"(?w)^abc", "abc").start(), 0) + self.assertEqual(regex.search(r"(?w)^abc", "\nabc"), None) + self.assertEqual(regex.search(r"(?w)^abc", "\rabc"), None) + + self.assertEqual(regex.search(r"abc$", "abc").start(), 0) + self.assertEqual(regex.search(r"abc$", "abc\n").start(), 0) + self.assertEqual(regex.search(r"abc$", "abc\r"), None) + self.assertEqual(regex.search(r"(?w)abc$", "abc").start(), 0) + self.assertEqual(regex.search(r"(?w)abc$", "abc\n").start(), 0) + self.assertEqual(regex.search(r"(?w)abc$", "abc\r").start(), 0) + + self.assertEqual(regex.search(r"(?m)^abc", "abc").start(), 0) + self.assertEqual(regex.search(r"(?m)^abc", "\nabc").start(), 1) + self.assertEqual(regex.search(r"(?m)^abc", "\rabc"), None) + self.assertEqual(regex.search(r"(?mw)^abc", "abc").start(), 0) + self.assertEqual(regex.search(r"(?mw)^abc", "\nabc").start(), 1) + self.assertEqual(regex.search(r"(?mw)^abc", "\rabc").start(), 1) + + self.assertEqual(regex.search(r"(?m)abc$", "abc").start(), 0) + self.assertEqual(regex.search(r"(?m)abc$", "abc\n").start(), 0) + self.assertEqual(regex.search(r"(?m)abc$", "abc\r"), None) + self.assertEqual(regex.search(r"(?mw)abc$", "abc").start(), 0) + self.assertEqual(regex.search(r"(?mw)abc$", "abc\n").start(), 0) + self.assertEqual(regex.search(r"(?mw)abc$", "abc\r").start(), 0) + + def test_branch_reset(self): + self.assertEqual(regex.match(r"(?:(a)|(b))(c)", "ac").groups(), ('a', + None, 'c')) + self.assertEqual(regex.match(r"(?:(a)|(b))(c)", "bc").groups(), (None, + 'b', 'c')) + self.assertEqual(regex.match(r"(?:(?a)|(?b))(?c)", + "ac").groups(), ('a', None, 'c')) + self.assertEqual(regex.match(r"(?:(?a)|(?b))(?c)", + "bc").groups(), (None, 'b', 'c')) + + self.assertEqual(regex.match(r"(?a)(?:(?b)|(?c))(?d)", + "abd").groups(), ('a', 'b', None, 'd')) + 
self.assertEqual(regex.match(r"(?a)(?:(?b)|(?c))(?d)", + "acd").groups(), ('a', None, 'c', 'd')) + self.assertEqual(regex.match(r"(a)(?:(b)|(c))(d)", "abd").groups(), + ('a', 'b', None, 'd')) + + self.assertEqual(regex.match(r"(a)(?:(b)|(c))(d)", "acd").groups(), + ('a', None, 'c', 'd')) + self.assertEqual(regex.match(r"(a)(?|(b)|(b))(d)", "abd").groups(), + ('a', 'b', 'd')) + self.assertEqual(regex.match(r"(?|(?a)|(?b))(c)", "ac").groups(), + ('a', None, 'c')) + self.assertEqual(regex.match(r"(?|(?a)|(?b))(c)", "bc").groups(), + (None, 'b', 'c')) + self.assertEqual(regex.match(r"(?|(?a)|(?b))(c)", "ac").groups(), + ('a', 'c')) + + self.assertEqual(regex.match(r"(?|(?a)|(?b))(c)", "bc").groups(), + ('b', 'c')) + + self.assertEqual(regex.match(r"(?|(?a)(?b)|(?c)(?d))(e)", + "abe").groups(), ('a', 'b', 'e')) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(?c)(?d))(e)", + "cde").groups(), ('d', 'c', 'e')) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(?c)(d))(e)", + "abe").groups(), ('a', 'b', 'e')) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(?c)(d))(e)", + "cde").groups(), ('d', 'c', 'e')) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(d))(e)", + "abe").groups(), ('a', 'b', 'e')) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(d))(e)", + "cde").groups(), ('c', 'd', 'e')) + + # Hg issue 87: Allow duplicate names of groups + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(?d))(e)", + "abe").groups(), ("a", "b", "e")) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(?d))(e)", + "abe").capturesdict(), {"a": ["a"], "b": ["b"]}) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(?d))(e)", + "cde").groups(), ("d", None, "e")) + self.assertEqual(regex.match(r"(?|(?a)(?b)|(c)(?d))(e)", + "cde").capturesdict(), {"a": ["c", "d"], "b": []}) + + def test_set(self): + self.assertEqual(regex.match(r"[a]", "a").span(), (0, 1)) + self.assertEqual(regex.match(r"(?i)[a]", "A").span(), (0, 1)) + self.assertEqual(regex.match(r"[a-b]", r"a").span(), (0, 1)) + self.assertEqual(regex.match(r"(?i)[a-b]", r"A").span(), (0, 1)) + + self.assertEqual(regex.sub(r"(?V0)([][])", r"-", "a[b]c"), "a-b-c") + + self.assertEqual(regex.findall(r"[\p{Alpha}]", "a0"), ["a"]) + self.assertEqual(regex.findall(r"(?i)[\p{Alpha}]", "A0"), ["A"]) + + self.assertEqual(regex.findall(r"[a\p{Alpha}]", "ab0"), ["a", "b"]) + self.assertEqual(regex.findall(r"[a\P{Alpha}]", "ab0"), ["a", "0"]) + self.assertEqual(regex.findall(r"(?i)[a\p{Alpha}]", "ab0"), ["a", + "b"]) + self.assertEqual(regex.findall(r"(?i)[a\P{Alpha}]", "ab0"), ["a", + "0"]) + + self.assertEqual(regex.findall(r"[a-b\p{Alpha}]", "abC0"), ["a", + "b", "C"]) + self.assertEqual(regex.findall(r"(?i)[a-b\p{Alpha}]", "AbC0"), ["A", + "b", "C"]) + + self.assertEqual(regex.findall(r"[\p{Alpha}]", "a0"), ["a"]) + self.assertEqual(regex.findall(r"[\P{Alpha}]", "a0"), ["0"]) + self.assertEqual(regex.findall(r"[^\p{Alpha}]", "a0"), ["0"]) + self.assertEqual(regex.findall(r"[^\P{Alpha}]", "a0"), ["a"]) + + self.assertEqual("".join(regex.findall(r"[^\d-h]", "a^b12c-h")), + 'a^bc') + self.assertEqual("".join(regex.findall(r"[^\dh]", "a^b12c-h")), + 'a^bc-') + self.assertEqual("".join(regex.findall(r"[^h\s\db]", "a^b 12c-h")), + 'a^c-') + self.assertEqual("".join(regex.findall(r"[^b\w]", "a b")), ' ') + self.assertEqual("".join(regex.findall(r"[^b\S]", "a b")), ' ') + self.assertEqual("".join(regex.findall(r"[^8\d]", "a 1b2")), 'a b') + + all_chars = "".join(chr(c) for c in range(0x100)) + self.assertEqual(len(regex.findall(r"\p{ASCII}", all_chars)), 128) + 
self.assertEqual(len(regex.findall(r"\p{Letter}", all_chars)), + 117) + self.assertEqual(len(regex.findall(r"\p{Digit}", all_chars)), 10) + + # Set operators + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}&&\p{Letter}]", + all_chars)), 52) + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}&&\p{Alnum}&&\p{Letter}]", + all_chars)), 52) + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}&&\p{Alnum}&&\p{Digit}]", + all_chars)), 10) + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}&&\p{Cc}]", + all_chars)), 33) + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}&&\p{Graph}]", + all_chars)), 94) + self.assertEqual(len(regex.findall(r"(?V1)[\p{ASCII}--\p{Cc}]", + all_chars)), 95) + self.assertEqual(len(regex.findall(r"[\p{Letter}\p{Digit}]", + all_chars)), 127) + self.assertEqual(len(regex.findall(r"(?V1)[\p{Letter}||\p{Digit}]", + all_chars)), 127) + self.assertEqual(len(regex.findall(r"\p{HexDigit}", all_chars)), + 22) + self.assertEqual(len(regex.findall(r"(?V1)[\p{HexDigit}~~\p{Digit}]", + all_chars)), 12) + self.assertEqual(len(regex.findall(r"(?V1)[\p{Digit}~~\p{HexDigit}]", + all_chars)), 12) + + self.assertEqual(repr(type(regex.compile(r"(?V0)([][-])"))), + self.PATTERN_CLASS) + self.assertEqual(regex.findall(r"(?V1)[[a-z]--[aei]]", "abc"), ["b", + "c"]) + self.assertEqual(regex.findall(r"(?iV1)[[a-z]--[aei]]", "abc"), ["b", + "c"]) + self.assertEqual(regex.findall(r"(?V1)[\w--a]","abc"), ["b", "c"]) + self.assertEqual(regex.findall(r"(?iV1)[\w--a]","abc"), ["b", "c"]) + + def test_various(self): + tests = [ + # Test ?P< and ?P= extensions. + ('(?Pa)', '', '', regex.error, self.BAD_GROUP_NAME), # Begins with a digit. + ('(?Pa)', '', '', regex.error, self.BAD_GROUP_NAME), # Begins with an illegal char. + ('(?Pa)', '', '', regex.error, self.BAD_GROUP_NAME), # Begins with an illegal char. + + # Same tests, for the ?P= form. + ('(?Pa)(?P=foo_123', 'aa', '', regex.error, + self.MISSING_RPAREN), + ('(?Pa)(?P=1)', 'aa', '1', ascii('a')), + ('(?Pa)(?P=0)', 'aa', '', regex.error, + self.BAD_GROUP_NAME), + ('(?Pa)(?P=-1)', 'aa', '', regex.error, + self.BAD_GROUP_NAME), + ('(?Pa)(?P=!)', 'aa', '', regex.error, + self.BAD_GROUP_NAME), + ('(?Pa)(?P=foo_124)', 'aa', '', regex.error, + self.UNKNOWN_GROUP), # Backref to undefined group. + + ('(?Pa)', 'a', '1', ascii('a')), + ('(?Pa)(?P=foo_123)', 'aa', '1', ascii('a')), + + # Mal-formed \g in pattern treated as literal for compatibility. + (r'(?a)\ga)\g<1>', 'aa', '1', ascii('a')), + (r'(?a)\g', 'aa', '', ascii(None)), + (r'(?a)\g', 'aa', '', regex.error, + self.UNKNOWN_GROUP), # Backref to undefined group. + + ('(?a)', 'a', '1', ascii('a')), + (r'(?a)\g', 'aa', '1', ascii('a')), + + # Test octal escapes. + ('\\1', 'a', '', regex.error, self.INVALID_GROUP_REF), # Backreference. + ('[\\1]', '\1', '0', "'\\x01'"), # Character. + ('\\09', chr(0) + '9', '0', ascii(chr(0) + '9')), + ('\\141', 'a', '0', ascii('a')), + ('(a)(b)(c)(d)(e)(f)(g)(h)(i)(j)(k)(l)\\119', 'abcdefghijklk9', + '0,11', ascii(('abcdefghijklk9', 'k'))), + + # Test \0 is handled everywhere. + (r'\0', '\0', '0', ascii('\0')), + (r'[\0a]', '\0', '0', ascii('\0')), + (r'[a\0]', '\0', '0', ascii('\0')), + (r'[^a\0]', '\0', '', ascii(None)), + + # Test various letter escapes. + (r'\a[\b]\f\n\r\t\v', '\a\b\f\n\r\t\v', '0', + ascii('\a\b\f\n\r\t\v')), + (r'[\a][\b][\f][\n][\r][\t][\v]', '\a\b\f\n\r\t\v', '0', + ascii('\a\b\f\n\r\t\v')), + (r'\xff', '\377', '0', ascii(chr(255))), + + # New \x semantics. 
+ (r'\x00ffffffffffffff', '\377', '', ascii(None)), + (r'\x00f', '\017', '', ascii(None)), + (r'\x00fe', '\376', '', ascii(None)), + + (r'\x00ff', '\377', '', ascii(None)), + (r'\t\n\v\r\f\a\g', '\t\n\v\r\f\ag', '0', ascii('\t\n\v\r\f\ag')), + ('\t\n\v\r\f\a\\g', '\t\n\v\r\f\ag', '0', ascii('\t\n\v\r\f\ag')), + (r'\t\n\v\r\f\a', '\t\n\v\r\f\a', '0', ascii(chr(9) + chr(10) + + chr(11) + chr(13) + chr(12) + chr(7))), + (r'[\t][\n][\v][\r][\f][\b]', '\t\n\v\r\f\b', '0', + ascii('\t\n\v\r\f\b')), + + (r"^\w+=(\\[\000-\277]|[^\n\\])*", + "SRC=eval.c g.c blah blah blah \\\\\n\tapes.c", '0', + ascii("SRC=eval.c g.c blah blah blah \\\\")), + + # Test that . only matches \n in DOTALL mode. + ('a.b', 'acb', '0', ascii('acb')), + ('a.b', 'a\nb', '', ascii(None)), + ('a.*b', 'acc\nccb', '', ascii(None)), + ('a.{4,5}b', 'acc\nccb', '', ascii(None)), + ('a.b', 'a\rb', '0', ascii('a\rb')), + # Changed to positional flags in regex 2023.12.23. + ('a.b(?s)', 'a\nb', '', ascii(None)), + ('(?s)a.b', 'a\nb', '0', ascii('a\nb')), + ('a.*(?s)b', 'acc\nccb', '', ascii(None)), + ('(?s)a.*b', 'acc\nccb', '0', ascii('acc\nccb')), + ('(?s)a.{4,5}b', 'acc\nccb', '0', ascii('acc\nccb')), + + (')', '', '', regex.error, self.TRAILING_CHARS), # Unmatched right bracket. + ('', '', '0', "''"), # Empty pattern. + ('abc', 'abc', '0', ascii('abc')), + ('abc', 'xbc', '', ascii(None)), + ('abc', 'axc', '', ascii(None)), + ('abc', 'abx', '', ascii(None)), + ('abc', 'xabcy', '0', ascii('abc')), + ('abc', 'ababc', '0', ascii('abc')), + ('ab*c', 'abc', '0', ascii('abc')), + ('ab*bc', 'abc', '0', ascii('abc')), + + ('ab*bc', 'abbc', '0', ascii('abbc')), + ('ab*bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab+bc', 'abbc', '0', ascii('abbc')), + ('ab+bc', 'abc', '', ascii(None)), + ('ab+bc', 'abq', '', ascii(None)), + ('ab+bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab?bc', 'abbc', '0', ascii('abbc')), + ('ab?bc', 'abc', '0', ascii('abc')), + ('ab?bc', 'abbbbc', '', ascii(None)), + ('ab?c', 'abc', '0', ascii('abc')), + + ('^abc$', 'abc', '0', ascii('abc')), + ('^abc$', 'abcc', '', ascii(None)), + ('^abc', 'abcc', '0', ascii('abc')), + ('^abc$', 'aabc', '', ascii(None)), + ('abc$', 'aabc', '0', ascii('abc')), + ('^', 'abc', '0', ascii('')), + ('$', 'abc', '0', ascii('')), + ('a.c', 'abc', '0', ascii('abc')), + ('a.c', 'axc', '0', ascii('axc')), + ('a.*c', 'axyzc', '0', ascii('axyzc')), + + ('a.*c', 'axyzd', '', ascii(None)), + ('a[bc]d', 'abc', '', ascii(None)), + ('a[bc]d', 'abd', '0', ascii('abd')), + ('a[b-d]e', 'abd', '', ascii(None)), + ('a[b-d]e', 'ace', '0', ascii('ace')), + ('a[b-d]', 'aac', '0', ascii('ac')), + ('a[-b]', 'a-', '0', ascii('a-')), + ('a[\\-b]', 'a-', '0', ascii('a-')), + ('a[b-]', 'a-', '0', ascii('a-')), + ('a[]b', '-', '', regex.error, self.BAD_SET), + + ('a[', '-', '', regex.error, self.BAD_SET), + ('a\\', '-', '', regex.error, self.BAD_ESCAPE), + ('abc)', '-', '', regex.error, self.TRAILING_CHARS), + ('(abc', '-', '', regex.error, self.MISSING_RPAREN), + ('a]', 'a]', '0', ascii('a]')), + ('a[]]b', 'a]b', '0', ascii('a]b')), + ('a[]]b', 'a]b', '0', ascii('a]b')), + ('a[^bc]d', 'aed', '0', ascii('aed')), + ('a[^bc]d', 'abd', '', ascii(None)), + ('a[^-b]c', 'adc', '0', ascii('adc')), + + ('a[^-b]c', 'a-c', '', ascii(None)), + ('a[^]b]c', 'a]c', '', ascii(None)), + ('a[^]b]c', 'adc', '0', ascii('adc')), + ('\\ba\\b', 'a-', '0', ascii('a')), + ('\\ba\\b', '-a', '0', ascii('a')), + ('\\ba\\b', '-a-', '0', ascii('a')), + ('\\by\\b', 'xy', '', ascii(None)), + ('\\by\\b', 'yz', '', ascii(None)), + ('\\by\\b', 'xyz', '', 
ascii(None)), + ('x\\b', 'xyz', '', ascii(None)), + + ('x\\B', 'xyz', '0', ascii('x')), + ('\\Bz', 'xyz', '0', ascii('z')), + ('z\\B', 'xyz', '', ascii(None)), + ('\\Bx', 'xyz', '', ascii(None)), + ('\\Ba\\B', 'a-', '', ascii(None)), + ('\\Ba\\B', '-a', '', ascii(None)), + ('\\Ba\\B', '-a-', '', ascii(None)), + ('\\By\\B', 'xy', '', ascii(None)), + ('\\By\\B', 'yz', '', ascii(None)), + ('\\By\\b', 'xy', '0', ascii('y')), + + ('\\by\\B', 'yz', '0', ascii('y')), + ('\\By\\B', 'xyz', '0', ascii('y')), + ('ab|cd', 'abc', '0', ascii('ab')), + ('ab|cd', 'abcd', '0', ascii('ab')), + ('()ef', 'def', '0,1', ascii(('ef', ''))), + ('$b', 'b', '', ascii(None)), + ('a\\(b', 'a(b', '', ascii(('a(b',))), + ('a\\(*b', 'ab', '0', ascii('ab')), + ('a\\(*b', 'a((b', '0', ascii('a((b')), + ('a\\\\b', 'a\\b', '0', ascii('a\\b')), + + ('((a))', 'abc', '0,1,2', ascii(('a', 'a', 'a'))), + ('(a)b(c)', 'abc', '0,1,2', ascii(('abc', 'a', 'c'))), + ('a+b+c', 'aabbabc', '0', ascii('abc')), + ('(a+|b)*', 'ab', '0,1', ascii(('ab', 'b'))), + ('(a+|b)+', 'ab', '0,1', ascii(('ab', 'b'))), + ('(a+|b)?', 'ab', '0,1', ascii(('a', 'a'))), + (')(', '-', '', regex.error, self.TRAILING_CHARS), + ('[^ab]*', 'cde', '0', ascii('cde')), + ('abc', '', '', ascii(None)), + ('a*', '', '0', ascii('')), + + ('a|b|c|d|e', 'e', '0', ascii('e')), + ('(a|b|c|d|e)f', 'ef', '0,1', ascii(('ef', 'e'))), + ('abcd*efg', 'abcdefg', '0', ascii('abcdefg')), + ('ab*', 'xabyabbbz', '0', ascii('ab')), + ('ab*', 'xayabbbz', '0', ascii('a')), + ('(ab|cd)e', 'abcde', '0,1', ascii(('cde', 'cd'))), + ('[abhgefdc]ij', 'hij', '0', ascii('hij')), + ('^(ab|cd)e', 'abcde', '', ascii(None)), + ('(abc|)ef', 'abcdef', '0,1', ascii(('ef', ''))), + ('(a|b)c*d', 'abcd', '0,1', ascii(('bcd', 'b'))), + + ('(ab|ab*)bc', 'abc', '0,1', ascii(('abc', 'a'))), + ('a([bc]*)c*', 'abc', '0,1', ascii(('abc', 'bc'))), + ('a([bc]*)(c*d)', 'abcd', '0,1,2', ascii(('abcd', 'bc', 'd'))), + ('a([bc]+)(c*d)', 'abcd', '0,1,2', ascii(('abcd', 'bc', 'd'))), + ('a([bc]*)(c+d)', 'abcd', '0,1,2', ascii(('abcd', 'b', 'cd'))), + ('a[bcd]*dcdcde', 'adcdcde', '0', ascii('adcdcde')), + ('a[bcd]+dcdcde', 'adcdcde', '', ascii(None)), + ('(ab|a)b*c', 'abc', '0,1', ascii(('abc', 'ab'))), + ('((a)(b)c)(d)', 'abcd', '1,2,3,4', ascii(('abc', 'a', 'b', 'd'))), + ('[a-zA-Z_][a-zA-Z0-9_]*', 'alpha', '0', ascii('alpha')), + + ('^a(bc+|b[eh])g|.h$', 'abh', '0,1', ascii(('bh', None))), + ('(bc+d$|ef*g.|h?i(j|k))', 'effgz', '0,1,2', ascii(('effgz', + 'effgz', None))), + ('(bc+d$|ef*g.|h?i(j|k))', 'ij', '0,1,2', ascii(('ij', 'ij', + 'j'))), + ('(bc+d$|ef*g.|h?i(j|k))', 'effg', '', ascii(None)), + ('(bc+d$|ef*g.|h?i(j|k))', 'bcdd', '', ascii(None)), + ('(bc+d$|ef*g.|h?i(j|k))', 'reffgz', '0,1,2', ascii(('effgz', + 'effgz', None))), + ('(((((((((a)))))))))', 'a', '0', ascii('a')), + ('multiple words of text', 'uh-uh', '', ascii(None)), + ('multiple words', 'multiple words, yeah', '0', + ascii('multiple words')), + ('(.*)c(.*)', 'abcde', '0,1,2', ascii(('abcde', 'ab', 'de'))), + + ('\\((.*), (.*)\\)', '(a, b)', '2,1', ascii(('b', 'a'))), + ('[k]', 'ab', '', ascii(None)), + ('a[-]?c', 'ac', '0', ascii('ac')), + ('(abc)\\1', 'abcabc', '1', ascii('abc')), + ('([a-c]*)\\1', 'abcabc', '1', ascii('abc')), + ('^(.+)?B', 'AB', '1', ascii('A')), + ('(a+).\\1$', 'aaaaa', '0,1', ascii(('aaaaa', 'aa'))), + ('^(a+).\\1$', 'aaaa', '', ascii(None)), + ('(abc)\\1', 'abcabc', '0,1', ascii(('abcabc', 'abc'))), + ('([a-c]+)\\1', 'abcabc', '0,1', ascii(('abcabc', 'abc'))), + + ('(a)\\1', 'aa', '0,1', ascii(('aa', 'a'))), + ('(a+)\\1', 
'aa', '0,1', ascii(('aa', 'a'))), + ('(a+)+\\1', 'aa', '0,1', ascii(('aa', 'a'))), + ('(a).+\\1', 'aba', '0,1', ascii(('aba', 'a'))), + ('(a)ba*\\1', 'aba', '0,1', ascii(('aba', 'a'))), + ('(aa|a)a\\1$', 'aaa', '0,1', ascii(('aaa', 'a'))), + ('(a|aa)a\\1$', 'aaa', '0,1', ascii(('aaa', 'a'))), + ('(a+)a\\1$', 'aaa', '0,1', ascii(('aaa', 'a'))), + ('([abc]*)\\1', 'abcabc', '0,1', ascii(('abcabc', 'abc'))), + ('(a)(b)c|ab', 'ab', '0,1,2', ascii(('ab', None, None))), + + ('(a)+x', 'aaax', '0,1', ascii(('aaax', 'a'))), + ('([ac])+x', 'aacx', '0,1', ascii(('aacx', 'c'))), + ('([^/]*/)*sub1/', 'd:msgs/tdir/sub1/trial/away.cpp', '0,1', + ascii(('d:msgs/tdir/sub1/', 'tdir/'))), + ('([^.]*)\\.([^:]*):[T ]+(.*)', 'track1.title:TBlah blah blah', + '0,1,2,3', ascii(('track1.title:TBlah blah blah', 'track1', + 'title', 'Blah blah blah'))), + ('([^N]*N)+', 'abNNxyzN', '0,1', ascii(('abNNxyzN', 'xyzN'))), + ('([^N]*N)+', 'abNNxyz', '0,1', ascii(('abNN', 'N'))), + ('([abc]*)x', 'abcx', '0,1', ascii(('abcx', 'abc'))), + ('([abc]*)x', 'abc', '', ascii(None)), + ('([xyz]*)x', 'abcx', '0,1', ascii(('x', ''))), + ('(a)+b|aac', 'aac', '0,1', ascii(('aac', None))), + + # Test symbolic groups. + ('(?Paaa)a', 'aaaa', '', regex.error, self.BAD_GROUP_NAME), + ('(?Paaa)a', 'aaaa', '0,id', ascii(('aaaa', 'aaa'))), + ('(?Paa)(?P=id)', 'aaaa', '0,id', ascii(('aaaa', 'aa'))), + ('(?Paa)(?P=xd)', 'aaaa', '', regex.error, self.UNKNOWN_GROUP), + + # Character properties. + (r"\g", "g", '0', ascii('g')), + (r"\g<1>", "g", '', regex.error, self.INVALID_GROUP_REF), + (r"(.)\g<1>", "gg", '0', ascii('gg')), + (r"(.)\g<1>", "gg", '', ascii(('gg', 'g'))), + (r"\N", "N", '0', ascii('N')), + (r"\N{LATIN SMALL LETTER A}", "a", '0', ascii('a')), + (r"\p", "p", '0', ascii('p')), + (r"\p{Ll}", "a", '0', ascii('a')), + (r"\P", "P", '0', ascii('P')), + (r"\P{Lu}", "p", '0', ascii('p')), + + # All tests from Perl. 
+ ('abc', 'abc', '0', ascii('abc')), + ('abc', 'xbc', '', ascii(None)), + ('abc', 'axc', '', ascii(None)), + ('abc', 'abx', '', ascii(None)), + ('abc', 'xabcy', '0', ascii('abc')), + ('abc', 'ababc', '0', ascii('abc')), + + ('ab*c', 'abc', '0', ascii('abc')), + ('ab*bc', 'abc', '0', ascii('abc')), + ('ab*bc', 'abbc', '0', ascii('abbc')), + ('ab*bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab{0,}bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab+bc', 'abbc', '0', ascii('abbc')), + ('ab+bc', 'abc', '', ascii(None)), + ('ab+bc', 'abq', '', ascii(None)), + ('ab{1,}bc', 'abq', '', ascii(None)), + ('ab+bc', 'abbbbc', '0', ascii('abbbbc')), + + ('ab{1,}bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab{1,3}bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab{3,4}bc', 'abbbbc', '0', ascii('abbbbc')), + ('ab{4,5}bc', 'abbbbc', '', ascii(None)), + ('ab?bc', 'abbc', '0', ascii('abbc')), + ('ab?bc', 'abc', '0', ascii('abc')), + ('ab{0,1}bc', 'abc', '0', ascii('abc')), + ('ab?bc', 'abbbbc', '', ascii(None)), + ('ab?c', 'abc', '0', ascii('abc')), + ('ab{0,1}c', 'abc', '0', ascii('abc')), + + ('^abc$', 'abc', '0', ascii('abc')), + ('^abc$', 'abcc', '', ascii(None)), + ('^abc', 'abcc', '0', ascii('abc')), + ('^abc$', 'aabc', '', ascii(None)), + ('abc$', 'aabc', '0', ascii('abc')), + ('^', 'abc', '0', ascii('')), + ('$', 'abc', '0', ascii('')), + ('a.c', 'abc', '0', ascii('abc')), + ('a.c', 'axc', '0', ascii('axc')), + ('a.*c', 'axyzc', '0', ascii('axyzc')), + + ('a.*c', 'axyzd', '', ascii(None)), + ('a[bc]d', 'abc', '', ascii(None)), + ('a[bc]d', 'abd', '0', ascii('abd')), + ('a[b-d]e', 'abd', '', ascii(None)), + ('a[b-d]e', 'ace', '0', ascii('ace')), + ('a[b-d]', 'aac', '0', ascii('ac')), + ('a[-b]', 'a-', '0', ascii('a-')), + ('a[b-]', 'a-', '0', ascii('a-')), + ('a[b-a]', '-', '', regex.error, self.BAD_CHAR_RANGE), + ('a[]b', '-', '', regex.error, self.BAD_SET), + + ('a[', '-', '', regex.error, self.BAD_SET), + ('a]', 'a]', '0', ascii('a]')), + ('a[]]b', 'a]b', '0', ascii('a]b')), + ('a[^bc]d', 'aed', '0', ascii('aed')), + ('a[^bc]d', 'abd', '', ascii(None)), + ('a[^-b]c', 'adc', '0', ascii('adc')), + ('a[^-b]c', 'a-c', '', ascii(None)), + ('a[^]b]c', 'a]c', '', ascii(None)), + ('a[^]b]c', 'adc', '0', ascii('adc')), + ('ab|cd', 'abc', '0', ascii('ab')), + + ('ab|cd', 'abcd', '0', ascii('ab')), + ('()ef', 'def', '0,1', ascii(('ef', ''))), + ('*a', '-', '', regex.error, self.NOTHING_TO_REPEAT), + ('(*)b', '-', '', regex.error, self.NOTHING_TO_REPEAT), + ('$b', 'b', '', ascii(None)), + ('a\\', '-', '', regex.error, self.BAD_ESCAPE), + ('a\\(b', 'a(b', '', ascii(('a(b',))), + ('a\\(*b', 'ab', '0', ascii('ab')), + ('a\\(*b', 'a((b', '0', ascii('a((b')), + ('a\\\\b', 'a\\b', '0', ascii('a\\b')), + + ('abc)', '-', '', regex.error, self.TRAILING_CHARS), + ('(abc', '-', '', regex.error, self.MISSING_RPAREN), + ('((a))', 'abc', '0,1,2', ascii(('a', 'a', 'a'))), + ('(a)b(c)', 'abc', '0,1,2', ascii(('abc', 'a', 'c'))), + ('a+b+c', 'aabbabc', '0', ascii('abc')), + ('a{1,}b{1,}c', 'aabbabc', '0', ascii('abc')), + ('a**', '-', '', regex.error, self.MULTIPLE_REPEAT), + ('a.+?c', 'abcabc', '0', ascii('abc')), + ('(a+|b)*', 'ab', '0,1', ascii(('ab', 'b'))), + ('(a+|b){0,}', 'ab', '0,1', ascii(('ab', 'b'))), + + ('(a+|b)+', 'ab', '0,1', ascii(('ab', 'b'))), + ('(a+|b){1,}', 'ab', '0,1', ascii(('ab', 'b'))), + ('(a+|b)?', 'ab', '0,1', ascii(('a', 'a'))), + ('(a+|b){0,1}', 'ab', '0,1', ascii(('a', 'a'))), + (')(', '-', '', regex.error, self.TRAILING_CHARS), + ('[^ab]*', 'cde', '0', ascii('cde')), + ('abc', '', '', ascii(None)), + ('a*', '', '0', 
ascii('')), + ('([abc])*d', 'abbbcd', '0,1', ascii(('abbbcd', 'c'))), + ('([abc])*bcd', 'abcd', '0,1', ascii(('abcd', 'a'))), + + ('a|b|c|d|e', 'e', '0', ascii('e')), + ('(a|b|c|d|e)f', 'ef', '0,1', ascii(('ef', 'e'))), + ('abcd*efg', 'abcdefg', '0', ascii('abcdefg')), + ('ab*', 'xabyabbbz', '0', ascii('ab')), + ('ab*', 'xayabbbz', '0', ascii('a')), + ('(ab|cd)e', 'abcde', '0,1', ascii(('cde', 'cd'))), + ('[abhgefdc]ij', 'hij', '0', ascii('hij')), + ('^(ab|cd)e', 'abcde', '', ascii(None)), + ('(abc|)ef', 'abcdef', '0,1', ascii(('ef', ''))), + ('(a|b)c*d', 'abcd', '0,1', ascii(('bcd', 'b'))), + + ('(ab|ab*)bc', 'abc', '0,1', ascii(('abc', 'a'))), + ('a([bc]*)c*', 'abc', '0,1', ascii(('abc', 'bc'))), + ('a([bc]*)(c*d)', 'abcd', '0,1,2', ascii(('abcd', 'bc', 'd'))), + ('a([bc]+)(c*d)', 'abcd', '0,1,2', ascii(('abcd', 'bc', 'd'))), + ('a([bc]*)(c+d)', 'abcd', '0,1,2', ascii(('abcd', 'b', 'cd'))), + ('a[bcd]*dcdcde', 'adcdcde', '0', ascii('adcdcde')), + ('a[bcd]+dcdcde', 'adcdcde', '', ascii(None)), + ('(ab|a)b*c', 'abc', '0,1', ascii(('abc', 'ab'))), + ('((a)(b)c)(d)', 'abcd', '1,2,3,4', ascii(('abc', 'a', 'b', 'd'))), + ('[a-zA-Z_][a-zA-Z0-9_]*', 'alpha', '0', ascii('alpha')), + + ('^a(bc+|b[eh])g|.h$', 'abh', '0,1', ascii(('bh', None))), + ('(bc+d$|ef*g.|h?i(j|k))', 'effgz', '0,1,2', ascii(('effgz', + 'effgz', None))), + ('(bc+d$|ef*g.|h?i(j|k))', 'ij', '0,1,2', ascii(('ij', 'ij', + 'j'))), + ('(bc+d$|ef*g.|h?i(j|k))', 'effg', '', ascii(None)), + ('(bc+d$|ef*g.|h?i(j|k))', 'bcdd', '', ascii(None)), + ('(bc+d$|ef*g.|h?i(j|k))', 'reffgz', '0,1,2', ascii(('effgz', + 'effgz', None))), + ('((((((((((a))))))))))', 'a', '10', ascii('a')), + ('((((((((((a))))))))))\\10', 'aa', '0', ascii('aa')), + + # Python does not have the same rules for \\41 so this is a syntax error + # ('((((((((((a))))))))))\\41', 'aa', '', ascii(None)), + # ('((((((((((a))))))))))\\41', 'a!', '0', ascii('a!')), + ('((((((((((a))))))))))\\41', '', '', regex.error, + self.INVALID_GROUP_REF), + ('(?i)((((((((((a))))))))))\\41', '', '', regex.error, + self.INVALID_GROUP_REF), + + ('(((((((((a)))))))))', 'a', '0', ascii('a')), + ('multiple words of text', 'uh-uh', '', ascii(None)), + ('multiple words', 'multiple words, yeah', '0', + ascii('multiple words')), + ('(.*)c(.*)', 'abcde', '0,1,2', ascii(('abcde', 'ab', 'de'))), + ('\\((.*), (.*)\\)', '(a, b)', '2,1', ascii(('b', 'a'))), + ('[k]', 'ab', '', ascii(None)), + ('a[-]?c', 'ac', '0', ascii('ac')), + ('(abc)\\1', 'abcabc', '1', ascii('abc')), + ('([a-c]*)\\1', 'abcabc', '1', ascii('abc')), + ('(?i)abc', 'ABC', '0', ascii('ABC')), + + ('(?i)abc', 'XBC', '', ascii(None)), + ('(?i)abc', 'AXC', '', ascii(None)), + ('(?i)abc', 'ABX', '', ascii(None)), + ('(?i)abc', 'XABCY', '0', ascii('ABC')), + ('(?i)abc', 'ABABC', '0', ascii('ABC')), + ('(?i)ab*c', 'ABC', '0', ascii('ABC')), + ('(?i)ab*bc', 'ABC', '0', ascii('ABC')), + ('(?i)ab*bc', 'ABBC', '0', ascii('ABBC')), + ('(?i)ab*?bc', 'ABBBBC', '0', ascii('ABBBBC')), + ('(?i)ab{0,}?bc', 'ABBBBC', '0', ascii('ABBBBC')), + + ('(?i)ab+?bc', 'ABBC', '0', ascii('ABBC')), + ('(?i)ab+bc', 'ABC', '', ascii(None)), + ('(?i)ab+bc', 'ABQ', '', ascii(None)), + ('(?i)ab{1,}bc', 'ABQ', '', ascii(None)), + ('(?i)ab+bc', 'ABBBBC', '0', ascii('ABBBBC')), + ('(?i)ab{1,}?bc', 'ABBBBC', '0', ascii('ABBBBC')), + ('(?i)ab{1,3}?bc', 'ABBBBC', '0', ascii('ABBBBC')), + ('(?i)ab{3,4}?bc', 'ABBBBC', '0', ascii('ABBBBC')), + ('(?i)ab{4,5}?bc', 'ABBBBC', '', ascii(None)), + ('(?i)ab??bc', 'ABBC', '0', ascii('ABBC')), + + ('(?i)ab??bc', 'ABC', '0', ascii('ABC')), + 
('(?i)ab{0,1}?bc', 'ABC', '0', ascii('ABC')), + ('(?i)ab??bc', 'ABBBBC', '', ascii(None)), + ('(?i)ab??c', 'ABC', '0', ascii('ABC')), + ('(?i)ab{0,1}?c', 'ABC', '0', ascii('ABC')), + ('(?i)^abc$', 'ABC', '0', ascii('ABC')), + ('(?i)^abc$', 'ABCC', '', ascii(None)), + ('(?i)^abc', 'ABCC', '0', ascii('ABC')), + ('(?i)^abc$', 'AABC', '', ascii(None)), + ('(?i)abc$', 'AABC', '0', ascii('ABC')), + + ('(?i)^', 'ABC', '0', ascii('')), + ('(?i)$', 'ABC', '0', ascii('')), + ('(?i)a.c', 'ABC', '0', ascii('ABC')), + ('(?i)a.c', 'AXC', '0', ascii('AXC')), + ('(?i)a.*?c', 'AXYZC', '0', ascii('AXYZC')), + ('(?i)a.*c', 'AXYZD', '', ascii(None)), + ('(?i)a[bc]d', 'ABC', '', ascii(None)), + ('(?i)a[bc]d', 'ABD', '0', ascii('ABD')), + ('(?i)a[b-d]e', 'ABD', '', ascii(None)), + ('(?i)a[b-d]e', 'ACE', '0', ascii('ACE')), + + ('(?i)a[b-d]', 'AAC', '0', ascii('AC')), + ('(?i)a[-b]', 'A-', '0', ascii('A-')), + ('(?i)a[b-]', 'A-', '0', ascii('A-')), + ('(?i)a[b-a]', '-', '', regex.error, self.BAD_CHAR_RANGE), + ('(?i)a[]b', '-', '', regex.error, self.BAD_SET), + ('(?i)a[', '-', '', regex.error, self.BAD_SET), + ('(?i)a]', 'A]', '0', ascii('A]')), + ('(?i)a[]]b', 'A]B', '0', ascii('A]B')), + ('(?i)a[^bc]d', 'AED', '0', ascii('AED')), + ('(?i)a[^bc]d', 'ABD', '', ascii(None)), + + ('(?i)a[^-b]c', 'ADC', '0', ascii('ADC')), + ('(?i)a[^-b]c', 'A-C', '', ascii(None)), + ('(?i)a[^]b]c', 'A]C', '', ascii(None)), + ('(?i)a[^]b]c', 'ADC', '0', ascii('ADC')), + ('(?i)ab|cd', 'ABC', '0', ascii('AB')), + ('(?i)ab|cd', 'ABCD', '0', ascii('AB')), + ('(?i)()ef', 'DEF', '0,1', ascii(('EF', ''))), + ('(?i)*a', '-', '', regex.error, self.NOTHING_TO_REPEAT), + ('(?i)(*)b', '-', '', regex.error, self.NOTHING_TO_REPEAT), + ('(?i)$b', 'B', '', ascii(None)), + + ('(?i)a\\', '-', '', regex.error, self.BAD_ESCAPE), + ('(?i)a\\(b', 'A(B', '', ascii(('A(B',))), + ('(?i)a\\(*b', 'AB', '0', ascii('AB')), + ('(?i)a\\(*b', 'A((B', '0', ascii('A((B')), + ('(?i)a\\\\b', 'A\\B', '0', ascii('A\\B')), + ('(?i)abc)', '-', '', regex.error, self.TRAILING_CHARS), + ('(?i)(abc', '-', '', regex.error, self.MISSING_RPAREN), + ('(?i)((a))', 'ABC', '0,1,2', ascii(('A', 'A', 'A'))), + ('(?i)(a)b(c)', 'ABC', '0,1,2', ascii(('ABC', 'A', 'C'))), + ('(?i)a+b+c', 'AABBABC', '0', ascii('ABC')), + + ('(?i)a{1,}b{1,}c', 'AABBABC', '0', ascii('ABC')), + ('(?i)a**', '-', '', regex.error, self.MULTIPLE_REPEAT), + ('(?i)a.+?c', 'ABCABC', '0', ascii('ABC')), + ('(?i)a.*?c', 'ABCABC', '0', ascii('ABC')), + ('(?i)a.{0,5}?c', 'ABCABC', '0', ascii('ABC')), + ('(?i)(a+|b)*', 'AB', '0,1', ascii(('AB', 'B'))), + ('(?i)(a+|b){0,}', 'AB', '0,1', ascii(('AB', 'B'))), + ('(?i)(a+|b)+', 'AB', '0,1', ascii(('AB', 'B'))), + ('(?i)(a+|b){1,}', 'AB', '0,1', ascii(('AB', 'B'))), + ('(?i)(a+|b)?', 'AB', '0,1', ascii(('A', 'A'))), + + ('(?i)(a+|b){0,1}', 'AB', '0,1', ascii(('A', 'A'))), + ('(?i)(a+|b){0,1}?', 'AB', '0,1', ascii(('', None))), + ('(?i))(', '-', '', regex.error, self.TRAILING_CHARS), + ('(?i)[^ab]*', 'CDE', '0', ascii('CDE')), + ('(?i)abc', '', '', ascii(None)), + ('(?i)a*', '', '0', ascii('')), + ('(?i)([abc])*d', 'ABBBCD', '0,1', ascii(('ABBBCD', 'C'))), + ('(?i)([abc])*bcd', 'ABCD', '0,1', ascii(('ABCD', 'A'))), + ('(?i)a|b|c|d|e', 'E', '0', ascii('E')), + ('(?i)(a|b|c|d|e)f', 'EF', '0,1', ascii(('EF', 'E'))), + + ('(?i)abcd*efg', 'ABCDEFG', '0', ascii('ABCDEFG')), + ('(?i)ab*', 'XABYABBBZ', '0', ascii('AB')), + ('(?i)ab*', 'XAYABBBZ', '0', ascii('A')), + ('(?i)(ab|cd)e', 'ABCDE', '0,1', ascii(('CDE', 'CD'))), + ('(?i)[abhgefdc]ij', 'HIJ', '0', ascii('HIJ')), + 
('(?i)^(ab|cd)e', 'ABCDE', '', ascii(None)), + ('(?i)(abc|)ef', 'ABCDEF', '0,1', ascii(('EF', ''))), + ('(?i)(a|b)c*d', 'ABCD', '0,1', ascii(('BCD', 'B'))), + ('(?i)(ab|ab*)bc', 'ABC', '0,1', ascii(('ABC', 'A'))), + ('(?i)a([bc]*)c*', 'ABC', '0,1', ascii(('ABC', 'BC'))), + + ('(?i)a([bc]*)(c*d)', 'ABCD', '0,1,2', ascii(('ABCD', 'BC', 'D'))), + ('(?i)a([bc]+)(c*d)', 'ABCD', '0,1,2', ascii(('ABCD', 'BC', 'D'))), + ('(?i)a([bc]*)(c+d)', 'ABCD', '0,1,2', ascii(('ABCD', 'B', 'CD'))), + ('(?i)a[bcd]*dcdcde', 'ADCDCDE', '0', ascii('ADCDCDE')), + ('(?i)a[bcd]+dcdcde', 'ADCDCDE', '', ascii(None)), + ('(?i)(ab|a)b*c', 'ABC', '0,1', ascii(('ABC', 'AB'))), + ('(?i)((a)(b)c)(d)', 'ABCD', '1,2,3,4', ascii(('ABC', 'A', 'B', + 'D'))), + ('(?i)[a-zA-Z_][a-zA-Z0-9_]*', 'ALPHA', '0', ascii('ALPHA')), + ('(?i)^a(bc+|b[eh])g|.h$', 'ABH', '0,1', ascii(('BH', None))), + ('(?i)(bc+d$|ef*g.|h?i(j|k))', 'EFFGZ', '0,1,2', ascii(('EFFGZ', + 'EFFGZ', None))), + + ('(?i)(bc+d$|ef*g.|h?i(j|k))', 'IJ', '0,1,2', ascii(('IJ', 'IJ', + 'J'))), + ('(?i)(bc+d$|ef*g.|h?i(j|k))', 'EFFG', '', ascii(None)), + ('(?i)(bc+d$|ef*g.|h?i(j|k))', 'BCDD', '', ascii(None)), + ('(?i)(bc+d$|ef*g.|h?i(j|k))', 'REFFGZ', '0,1,2', ascii(('EFFGZ', + 'EFFGZ', None))), + ('(?i)((((((((((a))))))))))', 'A', '10', ascii('A')), + ('(?i)((((((((((a))))))))))\\10', 'AA', '0', ascii('AA')), + #('(?i)((((((((((a))))))))))\\41', 'AA', '', ascii(None)), + #('(?i)((((((((((a))))))))))\\41', 'A!', '0', ascii('A!')), + ('(?i)(((((((((a)))))))))', 'A', '0', ascii('A')), + ('(?i)(?:(?:(?:(?:(?:(?:(?:(?:(?:(a))))))))))', 'A', '1', + ascii('A')), + ('(?i)(?:(?:(?:(?:(?:(?:(?:(?:(?:(a|b|c))))))))))', 'C', '1', + ascii('C')), + ('(?i)multiple words of text', 'UH-UH', '', ascii(None)), + + ('(?i)multiple words', 'MULTIPLE WORDS, YEAH', '0', + ascii('MULTIPLE WORDS')), + ('(?i)(.*)c(.*)', 'ABCDE', '0,1,2', ascii(('ABCDE', 'AB', 'DE'))), + ('(?i)\\((.*), (.*)\\)', '(A, B)', '2,1', ascii(('B', 'A'))), + ('(?i)[k]', 'AB', '', ascii(None)), + # ('(?i)abcd', 'ABCD', SUCCEED, 'found+"-"+\\found+"-"+\\\\found', ascii(ABCD-$&-\\ABCD)), + # ('(?i)a(bc)d', 'ABCD', SUCCEED, 'g1+"-"+\\g1+"-"+\\\\g1', ascii(BC-$1-\\BC)), + ('(?i)a[-]?c', 'AC', '0', ascii('AC')), + ('(?i)(abc)\\1', 'ABCABC', '1', ascii('ABC')), + ('(?i)([a-c]*)\\1', 'ABCABC', '1', ascii('ABC')), + ('a(?!b).', 'abad', '0', ascii('ad')), + ('a(?=d).', 'abad', '0', ascii('ad')), + ('a(?=c|d).', 'abad', '0', ascii('ad')), + + ('a(?:b|c|d)(.)', 'ace', '1', ascii('e')), + ('a(?:b|c|d)*(.)', 'ace', '1', ascii('e')), + ('a(?:b|c|d)+?(.)', 'ace', '1', ascii('e')), + ('a(?:b|(c|e){1,2}?|d)+?(.)', 'ace', '1,2', ascii(('c', 'e'))), + + # Lookbehind: split by : but not if it is escaped by -. + ('(?]*?b', 'a>b', '', ascii(None)), + # Bug 490573: minimizing repeat problem. + (r'^a*?$', 'foo', '', ascii(None)), + # Bug 470582: nested groups problem. + (r'^((a)c)?(ab)$', 'ab', '1,2,3', ascii((None, None, 'ab'))), + # Another minimizing repeat problem (capturing groups in assertions). 
+ ('^([ab]*?)(?=(b)?)c', 'abc', '1,2', ascii(('ab', None))), + ('^([ab]*?)(?!(b))c', 'abc', '1,2', ascii(('ab', None))), + ('^([ab]*?)(?(.){0,2})d", "abcd").captures(1), + ['b', 'c']) + self.assertEqual(regex.search(r"(.)+", "a").captures(1), ['a']) + + def test_guards(self): + m = regex.search(r"(X.*?Y\s*){3}(X\s*)+AB:", + "XY\nX Y\nX Y\nXY\nXX AB:") + self.assertEqual(m.span(0, 1, 2), ((3, 21), (12, 15), (16, 18))) + + m = regex.search(r"(X.*?Y\s*){3,}(X\s*)+AB:", + "XY\nX Y\nX Y\nXY\nXX AB:") + self.assertEqual(m.span(0, 1, 2), ((0, 21), (12, 15), (16, 18))) + + m = regex.search(r'\d{4}(\s*\w)?\W*((?!\d)\w){2}', "9999XX") + self.assertEqual(m.span(0, 1, 2), ((0, 6), (-1, -1), (5, 6))) + + m = regex.search(r'A\s*?.*?(\n+.*?\s*?){0,2}\(X', 'A\n1\nS\n1 (X') + self.assertEqual(m.span(0, 1), ((0, 10), (5, 8))) + + m = regex.search(r'Derde\s*:', 'aaaaaa:\nDerde:') + self.assertEqual(m.span(), (8, 14)) + m = regex.search(r'Derde\s*:', 'aaaaa:\nDerde:') + self.assertEqual(m.span(), (7, 13)) + + def test_turkic(self): + # Turkish has dotted and dotless I/i. + pairs = "I=i;I=\u0131;i=\u0130" + + all_chars = set() + matching = set() + for pair in pairs.split(";"): + ch1, ch2 = pair.split("=") + all_chars.update((ch1, ch2)) + matching.add((ch1, ch1)) + matching.add((ch1, ch2)) + matching.add((ch2, ch1)) + matching.add((ch2, ch2)) + + for ch1 in all_chars: + for ch2 in all_chars: + m = regex.match(r"(?i)\A" + ch1 + r"\Z", ch2) + if m: + if (ch1, ch2) not in matching: + self.fail("{} matching {}".format(ascii(ch1), + ascii(ch2))) + else: + if (ch1, ch2) in matching: + self.fail("{} not matching {}".format(ascii(ch1), + ascii(ch2))) + + def test_named_lists(self): + options = ["one", "two", "three"] + self.assertEqual(regex.match(r"333\L444", "333one444", + bar=options).group(), "333one444") + self.assertEqual(regex.match(r"(?i)333\L444", "333TWO444", + bar=options).group(), "333TWO444") + self.assertEqual(regex.match(r"333\L444", "333four444", + bar=options), None) + + options = [b"one", b"two", b"three"] + self.assertEqual(regex.match(br"333\L444", b"333one444", + bar=options).group(), b"333one444") + self.assertEqual(regex.match(br"(?i)333\L444", b"333TWO444", + bar=options).group(), b"333TWO444") + self.assertEqual(regex.match(br"333\L444", b"333four444", + bar=options), None) + + self.assertEqual(repr(type(regex.compile(r"3\L4\L+5", + bar=["one", "two", "three"]))), self.PATTERN_CLASS) + + self.assertEqual(regex.findall(r"^\L", "solid QWERT", + options=set(['good', 'brilliant', '+s\\ol[i}d'])), []) + self.assertEqual(regex.findall(r"^\L", "+solid QWERT", + options=set(['good', 'brilliant', '+solid'])), ['+solid']) + + options = ["STRASSE"] + self.assertEqual(regex.match(r"(?fi)\L", + "stra\N{LATIN SMALL LETTER SHARP S}e", words=options).span(), (0, + 6)) + + options = ["STRASSE", "stress"] + self.assertEqual(regex.match(r"(?fi)\L", + "stra\N{LATIN SMALL LETTER SHARP S}e", words=options).span(), (0, + 6)) + + options = ["stra\N{LATIN SMALL LETTER SHARP S}e"] + self.assertEqual(regex.match(r"(?fi)\L", "STRASSE", + words=options).span(), (0, 7)) + + options = ["kit"] + self.assertEqual(regex.search(r"(?i)\L", "SKITS", + words=options).span(), (1, 4)) + self.assertEqual(regex.search(r"(?i)\L", + "SK\N{LATIN CAPITAL LETTER I WITH DOT ABOVE}TS", + words=options).span(), (1, 4)) + + self.assertEqual(regex.search(r"(?fi)\b(\w+) +\1\b", + " stra\N{LATIN SMALL LETTER SHARP S}e STRASSE ").span(), (1, 15)) + self.assertEqual(regex.search(r"(?fi)\b(\w+) +\1\b", + " STRASSE stra\N{LATIN SMALL LETTER SHARP S}e 
").span(), (1, 15)) + + self.assertEqual(regex.search(r"^\L$", "", options=[]).span(), + (0, 0)) + + def test_fuzzy(self): + # Some tests borrowed from TRE library tests. + self.assertEqual(repr(type(regex.compile('(fou){s,e<=1}'))), + self.PATTERN_CLASS) + self.assertEqual(repr(type(regex.compile('(fuu){s}'))), + self.PATTERN_CLASS) + self.assertEqual(repr(type(regex.compile('(fuu){s,e}'))), + self.PATTERN_CLASS) + self.assertEqual(repr(type(regex.compile('(anaconda){1i+1d<1,s<=1}'))), + self.PATTERN_CLASS) + self.assertEqual(repr(type(regex.compile('(anaconda){1i+1d<1,s<=1,e<=10}'))), + self.PATTERN_CLASS) + self.assertEqual(repr(type(regex.compile('(anaconda){s<=1,e<=1,1i+1d<1}'))), + self.PATTERN_CLASS) + + text = 'molasses anaconda foo bar baz smith anderson ' + self.assertEqual(regex.search('(znacnda){s<=1,e<=3,1i+1d<1}', text), + None) + self.assertEqual(regex.search('(znacnda){s<=1,e<=3,1i+1d<2}', + text).span(0, 1), ((9, 17), (9, 17))) + self.assertEqual(regex.search('(ananda){1i+1d<2}', text), None) + self.assertEqual(regex.search(r"(?:\bznacnda){e<=2}", text)[0], + "anaconda") + self.assertEqual(regex.search(r"(?:\bnacnda){e<=2}", text)[0], + "anaconda") + + text = 'anaconda foo bar baz smith anderson' + self.assertEqual(regex.search('(fuu){i<=3,d<=3,e<=5}', text).span(0, + 1), ((0, 0), (0, 0))) + self.assertEqual(regex.search('(?b)(fuu){i<=3,d<=3,e<=5}', + text).span(0, 1), ((9, 10), (9, 10))) + self.assertEqual(regex.search('(fuu){i<=2,d<=2,e<=5}', text).span(0, + 1), ((7, 10), (7, 10))) + self.assertEqual(regex.search('(?e)(fuu){i<=2,d<=2,e<=5}', + text).span(0, 1), ((9, 10), (9, 10))) + self.assertEqual(regex.search('(fuu){i<=3,d<=3,e}', text).span(0, 1), + ((0, 0), (0, 0))) + self.assertEqual(regex.search('(?b)(fuu){i<=3,d<=3,e}', text).span(0, + 1), ((9, 10), (9, 10))) + + self.assertEqual(repr(type(regex.compile('(approximate){s<=3,1i+1d<3}'))), + self.PATTERN_CLASS) + + # No cost limit. + self.assertEqual(regex.search('(foobar){e}', + 'xirefoabralfobarxie').span(0, 1), ((0, 6), (0, 6))) + self.assertEqual(regex.search('(?e)(foobar){e}', + 'xirefoabralfobarxie').span(0, 1), ((0, 3), (0, 3))) + self.assertEqual(regex.search('(?b)(foobar){e}', + 'xirefoabralfobarxie').span(0, 1), ((11, 16), (11, 16))) + + # At most two errors. + self.assertEqual(regex.search('(foobar){e<=2}', + 'xirefoabrzlfd').span(0, 1), ((4, 9), (4, 9))) + self.assertEqual(regex.search('(foobar){e<=2}', 'xirefoabzlfd'), None) + + # At most two inserts or substitutions and max two errors total. + self.assertEqual(regex.search('(foobar){i<=2,s<=2,e<=2}', + 'oobargoobaploowap').span(0, 1), ((5, 11), (5, 11))) + + # Find best whole word match for "foobar". + self.assertEqual(regex.search('\\b(foobar){e}\\b', 'zfoobarz').span(0, + 1), ((0, 8), (0, 8))) + self.assertEqual(regex.search('\\b(foobar){e}\\b', + 'boing zfoobarz goobar woop').span(0, 1), ((0, 6), (0, 6))) + self.assertEqual(regex.search('(?b)\\b(foobar){e}\\b', + 'boing zfoobarz goobar woop').span(0, 1), ((15, 21), (15, 21))) + + # Match whole string, allow only 1 error. 
+ self.assertEqual(regex.search('^(foobar){e<=1}$', 'foobar').span(0, 1), + ((0, 6), (0, 6))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'xfoobar').span(0, + 1), ((0, 7), (0, 7))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'foobarx').span(0, + 1), ((0, 7), (0, 7))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'fooxbar').span(0, + 1), ((0, 7), (0, 7))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'foxbar').span(0, 1), + ((0, 6), (0, 6))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'xoobar').span(0, 1), + ((0, 6), (0, 6))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'foobax').span(0, 1), + ((0, 6), (0, 6))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'oobar').span(0, 1), + ((0, 5), (0, 5))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'fobar').span(0, 1), + ((0, 5), (0, 5))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'fooba').span(0, 1), + ((0, 5), (0, 5))) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'xfoobarx'), None) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'foobarxx'), None) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'xxfoobar'), None) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'xfoxbar'), None) + self.assertEqual(regex.search('^(foobar){e<=1}$', 'foxbarx'), None) + + # At most one insert, two deletes, and three substitutions. + # Additionally, deletes cost two and substitutes one, and total + # cost must be less than 4. + self.assertEqual(regex.search('(foobar){i<=1,d<=2,s<=3,2d+1s<4}', + '3oifaowefbaoraofuiebofasebfaobfaorfeoaro').span(0, 1), ((6, 13), (6, + 13))) + self.assertEqual(regex.search('(?b)(foobar){i<=1,d<=2,s<=3,2d+1s<4}', + '3oifaowefbaoraofuiebofasebfaobfaorfeoaro').span(0, 1), ((34, 39), + (34, 39))) + + # Partially fuzzy matches. + self.assertEqual(regex.search('foo(bar){e<=1}zap', 'foobarzap').span(0, + 1), ((0, 9), (3, 6))) + self.assertEqual(regex.search('foo(bar){e<=1}zap', 'fobarzap'), None) + self.assertEqual(regex.search('foo(bar){e<=1}zap', 'foobrzap').span(0, + 1), ((0, 8), (3, 5))) + + text = ('www.cnn.com 64.236.16.20\nwww.slashdot.org 66.35.250.150\n' + 'For useful information, use www.slashdot.org\nthis is demo data!\n') + self.assertEqual(regex.search(r'(?s)^.*(dot.org){e}.*$', text).span(0, + 1), ((0, 120), (120, 120))) + self.assertEqual(regex.search(r'(?es)^.*(dot.org){e}.*$', text).span(0, + 1), ((0, 120), (93, 100))) + self.assertEqual(regex.search(r'^.*(dot.org){e}.*$', text).span(0, 1), + ((0, 119), (24, 101))) + + # Behaviour is unexpected, but arguably not wrong. It first finds the + # best match, then the best in what follows, etc. 
+ self.assertEqual(regex.findall(r"\b\L{e<=1}\b", + " book cot dog desk ", words="cat dog".split()), ["cot", "dog"]) + self.assertEqual(regex.findall(r"\b\L{e<=1}\b", + " book dog cot desk ", words="cat dog".split()), [" dog", "cot"]) + self.assertEqual(regex.findall(r"(?e)\b\L{e<=1}\b", + " book dog cot desk ", words="cat dog".split()), ["dog", "cot"]) + self.assertEqual(regex.findall(r"(?r)\b\L{e<=1}\b", + " book cot dog desk ", words="cat dog".split()), ["dog ", "cot"]) + self.assertEqual(regex.findall(r"(?er)\b\L{e<=1}\b", + " book cot dog desk ", words="cat dog".split()), ["dog", "cot"]) + self.assertEqual(regex.findall(r"(?r)\b\L{e<=1}\b", + " book dog cot desk ", words="cat dog".split()), ["cot", "dog"]) + self.assertEqual(regex.findall(br"\b\L{e<=1}\b", + b" book cot dog desk ", words=b"cat dog".split()), [b"cot", b"dog"]) + self.assertEqual(regex.findall(br"\b\L{e<=1}\b", + b" book dog cot desk ", words=b"cat dog".split()), [b" dog", b"cot"]) + self.assertEqual(regex.findall(br"(?e)\b\L{e<=1}\b", + b" book dog cot desk ", words=b"cat dog".split()), [b"dog", b"cot"]) + self.assertEqual(regex.findall(br"(?r)\b\L{e<=1}\b", + b" book cot dog desk ", words=b"cat dog".split()), [b"dog ", b"cot"]) + self.assertEqual(regex.findall(br"(?er)\b\L{e<=1}\b", + b" book cot dog desk ", words=b"cat dog".split()), [b"dog", b"cot"]) + self.assertEqual(regex.findall(br"(?r)\b\L{e<=1}\b", + b" book dog cot desk ", words=b"cat dog".split()), [b"cot", b"dog"]) + + self.assertEqual(regex.search(r"(\w+) (\1{e<=1})", "foo fou").groups(), + ("foo", "fou")) + self.assertEqual(regex.search(r"(?r)(\2{e<=1}) (\w+)", + "foo fou").groups(), ("foo", "fou")) + self.assertEqual(regex.search(br"(\w+) (\1{e<=1})", + b"foo fou").groups(), (b"foo", b"fou")) + + self.assertEqual(regex.findall(r"(?:(?:QR)+){e}", "abcde"), ["abcde", + ""]) + self.assertEqual(regex.findall(r"(?:Q+){e}", "abc"), ["abc", ""]) + + # Hg issue 41: = for fuzzy matches + self.assertEqual(regex.match(r"(?:service detection){0[^()]+)|(?R))*\)", "(ab(cd)ef)")[ + : ], ("(ab(cd)ef)", "ef")) + self.assertEqual(regex.search(r"\(((?>[^()]+)|(?R))*\)", + "(ab(cd)ef)").captures(1), ["ab", "cd", "(cd)", "ef"]) + + self.assertEqual(regex.search(r"(?r)\(((?R)|(?>[^()]+))*\)", + "(ab(cd)ef)")[ : ], ("(ab(cd)ef)", "ab")) + self.assertEqual(regex.search(r"(?r)\(((?R)|(?>[^()]+))*\)", + "(ab(cd)ef)").captures(1), ["ef", "cd", "(cd)", "ab"]) + + self.assertEqual(regex.search(r"\(([^()]+|(?R))*\)", + "some text (a(b(c)d)e) more text")[ : ], ("(a(b(c)d)e)", "e")) + + self.assertEqual(regex.search(r"(?r)\(((?R)|[^()]+)*\)", + "some text (a(b(c)d)e) more text")[ : ], ("(a(b(c)d)e)", "a")) + + self.assertEqual(regex.search(r"(foo(\(((?:(?>[^()]+)|(?2))*)\)))", + "foo(bar(baz)+baz(bop))")[ : ], ("foo(bar(baz)+baz(bop))", + "foo(bar(baz)+baz(bop))", "(bar(baz)+baz(bop))", + "bar(baz)+baz(bop)")) + + self.assertEqual(regex.search(r"(?r)(foo(\(((?:(?2)|(?>[^()]+))*)\)))", + "foo(bar(baz)+baz(bop))")[ : ], ("foo(bar(baz)+baz(bop))", + "foo(bar(baz)+baz(bop))", "(bar(baz)+baz(bop))", + "bar(baz)+baz(bop)")) + + rgx = regex.compile(r"""^\s*(<\s*([a-zA-Z:]+)(?:\s*[a-zA-Z:]*\s*=\s*(?:'[^']*'|"[^"]*"))*\s*(/\s*)?>(?:[^<>]*|(?1))*(?(3)|<\s*/\s*\2\s*>))\s*$""") + self.assertEqual(bool(rgx.search('')), True) + self.assertEqual(bool(rgx.search('')), False) + self.assertEqual(bool(rgx.search('')), True) + self.assertEqual(bool(rgx.search('')), False) + self.assertEqual(bool(rgx.search('')), False) + + self.assertEqual(bool(rgx.search('')), False) + 
self.assertEqual(bool(rgx.search('')), True) + self.assertEqual(bool(rgx.search('< fooo / >')), True) + # The next regex should and does match. Perl 5.14 agrees. + #self.assertEqual(bool(rgx.search('foo')), False) + self.assertEqual(bool(rgx.search('foo')), False) + + self.assertEqual(bool(rgx.search('foo')), True) + self.assertEqual(bool(rgx.search('foo')), True) + self.assertEqual(bool(rgx.search('')), True) + + def test_copy(self): + # PatternObjects are immutable, therefore there's no need to clone them. + r = regex.compile("a") + self.assertTrue(copy.copy(r) is r) + self.assertTrue(copy.deepcopy(r) is r) + + # MatchObjects are normally mutable because the target string can be + # detached. However, after the target string has been detached, a + # MatchObject becomes immutable, so there's no need to clone it. + m = r.match("a") + self.assertTrue(copy.copy(m) is not m) + self.assertTrue(copy.deepcopy(m) is not m) + + self.assertTrue(m.string is not None) + m2 = copy.copy(m) + m2.detach_string() + self.assertTrue(m.string is not None) + self.assertTrue(m2.string is None) + + # The following behaviour matches that of the re module. + it = regex.finditer(".", "ab") + it2 = copy.copy(it) + self.assertEqual(next(it).group(), "a") + self.assertEqual(next(it2).group(), "b") + + # The following behaviour matches that of the re module. + it = regex.finditer(".", "ab") + it2 = copy.deepcopy(it) + self.assertEqual(next(it).group(), "a") + self.assertEqual(next(it2).group(), "b") + + # The following behaviour is designed to match that of copying 'finditer'. + it = regex.splititer(" ", "a b") + it2 = copy.copy(it) + self.assertEqual(next(it), "a") + self.assertEqual(next(it2), "b") + + # The following behaviour is designed to match that of copying 'finditer'. 
+ it = regex.splititer(" ", "a b") + it2 = copy.deepcopy(it) + self.assertEqual(next(it), "a") + self.assertEqual(next(it2), "b") + + def test_format(self): + self.assertEqual(regex.subf(r"(\w+) (\w+)", "{0} => {2} {1}", + "foo bar"), "foo bar => bar foo") + self.assertEqual(regex.subf(r"(?\w+) (?\w+)", + "{word2} {word1}", "foo bar"), "bar foo") + + self.assertEqual(regex.subfn(r"(\w+) (\w+)", "{0} => {2} {1}", + "foo bar"), ("foo bar => bar foo", 1)) + self.assertEqual(regex.subfn(r"(?\w+) (?\w+)", + "{word2} {word1}", "foo bar"), ("bar foo", 1)) + + self.assertEqual(regex.match(r"(\w+) (\w+)", + "foo bar").expandf("{0} => {2} {1}"), "foo bar => bar foo") + + def test_fullmatch(self): + self.assertEqual(bool(regex.fullmatch(r"abc", "abc")), True) + self.assertEqual(bool(regex.fullmatch(r"abc", "abcx")), False) + self.assertEqual(bool(regex.fullmatch(r"abc", "abcx", endpos=3)), True) + + self.assertEqual(bool(regex.fullmatch(r"abc", "xabc", pos=1)), True) + self.assertEqual(bool(regex.fullmatch(r"abc", "xabcy", pos=1)), False) + self.assertEqual(bool(regex.fullmatch(r"abc", "xabcy", pos=1, + endpos=4)), True) + + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "abc")), True) + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "abcx")), False) + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "abcx", endpos=3)), + True) + + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "xabc", pos=1)), + True) + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "xabcy", pos=1)), + False) + self.assertEqual(bool(regex.fullmatch(r"(?r)abc", "xabcy", pos=1, + endpos=4)), True) + + def test_issue_18468(self): + self.assertTypedEqual(regex.sub('y', 'a', 'xyz'), 'xaz') + self.assertTypedEqual(regex.sub('y', StrSubclass('a'), + StrSubclass('xyz')), 'xaz') + self.assertTypedEqual(regex.sub(b'y', b'a', b'xyz'), b'xaz') + self.assertTypedEqual(regex.sub(b'y', BytesSubclass(b'a'), + BytesSubclass(b'xyz')), b'xaz') + self.assertTypedEqual(regex.sub(b'y', bytearray(b'a'), + bytearray(b'xyz')), b'xaz') + self.assertTypedEqual(regex.sub(b'y', memoryview(b'a'), + memoryview(b'xyz')), b'xaz') + + for string in ":a:b::c", StrSubclass(":a:b::c"): + self.assertTypedEqual(regex.split(":", string), ['', 'a', 'b', '', + 'c']) + if sys.version_info >= (3, 7, 0): + self.assertTypedEqual(regex.split(":*", string), ['', '', 'a', + '', 'b', '', 'c', '']) + self.assertTypedEqual(regex.split("(:*)", string), ['', ':', + '', '', 'a', ':', '', '', 'b', '::', '', '', 'c', '', '']) + else: + self.assertTypedEqual(regex.split(":*", string), ['', 'a', 'b', + 'c']) + self.assertTypedEqual(regex.split("(:*)", string), ['', ':', + 'a', ':', 'b', '::', 'c']) + + for string in (b":a:b::c", BytesSubclass(b":a:b::c"), + bytearray(b":a:b::c"), memoryview(b":a:b::c")): + self.assertTypedEqual(regex.split(b":", string), [b'', b'a', b'b', + b'', b'c']) + if sys.version_info >= (3, 7, 0): + self.assertTypedEqual(regex.split(b":*", string), [b'', b'', + b'a', b'', b'b', b'', b'c', b'']) + self.assertTypedEqual(regex.split(b"(:*)", string), [b'', b':', + b'', b'', b'a', b':', b'', b'', b'b', b'::', b'', b'', b'c', + b'', b'']) + else: + self.assertTypedEqual(regex.split(b":*", string), [b'', b'a', + b'b', b'c']) + self.assertTypedEqual(regex.split(b"(:*)", string), [b'', b':', + b'a', b':', b'b', b'::', b'c']) + + for string in "a:b::c:::d", StrSubclass("a:b::c:::d"): + self.assertTypedEqual(regex.findall(":+", string), [":", "::", + ":::"]) + self.assertTypedEqual(regex.findall("(:+)", string), [":", "::", + ":::"]) + 
self.assertTypedEqual(regex.findall("(:)(:*)", string), [(":", ""), + (":", ":"), (":", "::")]) + + for string in (b"a:b::c:::d", BytesSubclass(b"a:b::c:::d"), + bytearray(b"a:b::c:::d"), memoryview(b"a:b::c:::d")): + self.assertTypedEqual(regex.findall(b":+", string), [b":", b"::", + b":::"]) + self.assertTypedEqual(regex.findall(b"(:+)", string), [b":", b"::", + b":::"]) + self.assertTypedEqual(regex.findall(b"(:)(:*)", string), [(b":", + b""), (b":", b":"), (b":", b"::")]) + + for string in 'a', StrSubclass('a'): + self.assertEqual(regex.match('a', string).groups(), ()) + self.assertEqual(regex.match('(a)', string).groups(), ('a',)) + self.assertEqual(regex.match('(a)', string).group(0), 'a') + self.assertEqual(regex.match('(a)', string).group(1), 'a') + self.assertEqual(regex.match('(a)', string).group(1, 1), ('a', + 'a')) + + for string in (b'a', BytesSubclass(b'a'), bytearray(b'a'), + memoryview(b'a')): + self.assertEqual(regex.match(b'a', string).groups(), ()) + self.assertEqual(regex.match(b'(a)', string).groups(), (b'a',)) + self.assertEqual(regex.match(b'(a)', string).group(0), b'a') + self.assertEqual(regex.match(b'(a)', string).group(1), b'a') + self.assertEqual(regex.match(b'(a)', string).group(1, 1), (b'a', + b'a')) + + def test_partial(self): + self.assertEqual(regex.match('ab', 'a', partial=True).partial, True) + self.assertEqual(regex.match('ab', 'a', partial=True).span(), (0, 1)) + self.assertEqual(regex.match(r'cats', 'cat', partial=True).partial, + True) + self.assertEqual(regex.match(r'cats', 'cat', partial=True).span(), (0, + 3)) + self.assertEqual(regex.match(r'cats', 'catch', partial=True), None) + self.assertEqual(regex.match(r'abc\w{3}', 'abcdef', + partial=True).partial, False) + self.assertEqual(regex.match(r'abc\w{3}', 'abcdef', + partial=True).span(), (0, 6)) + self.assertEqual(regex.match(r'abc\w{3}', 'abcde', + partial=True).partial, True) + self.assertEqual(regex.match(r'abc\w{3}', 'abcde', + partial=True).span(), (0, 5)) + + self.assertEqual(regex.match(r'\d{4}$', '1234', partial=True).partial, + False) + + self.assertEqual(regex.match(r'\L', 'post', partial=True, + words=['post']).partial, False) + self.assertEqual(regex.match(r'\L', 'post', partial=True, + words=['post']).span(), (0, 4)) + self.assertEqual(regex.match(r'\L', 'pos', partial=True, + words=['post']).partial, True) + self.assertEqual(regex.match(r'\L', 'pos', partial=True, + words=['post']).span(), (0, 3)) + + self.assertEqual(regex.match(r'(?fi)\L', 'POST', partial=True, + words=['po\uFB06']).partial, False) + self.assertEqual(regex.match(r'(?fi)\L', 'POST', partial=True, + words=['po\uFB06']).span(), (0, 4)) + self.assertEqual(regex.match(r'(?fi)\L', 'POS', partial=True, + words=['po\uFB06']).partial, True) + self.assertEqual(regex.match(r'(?fi)\L', 'POS', partial=True, + words=['po\uFB06']).span(), (0, 3)) + self.assertEqual(regex.match(r'(?fi)\L', 'po\uFB06', + partial=True, words=['POS']), None) + + self.assertEqual(regex.match(r'[a-z]*4R$', 'a', partial=True).span(), + (0, 1)) + self.assertEqual(regex.match(r'[a-z]*4R$', 'ab', partial=True).span(), + (0, 2)) + self.assertEqual(regex.match(r'[a-z]*4R$', 'ab4', partial=True).span(), + (0, 3)) + self.assertEqual(regex.match(r'[a-z]*4R$', 'a4', partial=True).span(), + (0, 2)) + self.assertEqual(regex.match(r'[a-z]*4R$', 'a4R', partial=True).span(), + (0, 3)) + self.assertEqual(regex.match(r'[a-z]*4R$', '4a', partial=True), None) + self.assertEqual(regex.match(r'[a-z]*4R$', 'a44', partial=True), None) + + def test_hg_bugs(self): + # Hg 
issue 28: regex.compile("(?>b)") causes "TypeError: 'Character' + # object is not subscriptable" + self.assertEqual(bool(regex.compile("(?>b)", flags=regex.V1)), True) + + # Hg issue 29: regex.compile("^((?>\w+)|(?>\s+))*$") causes + # "TypeError: 'GreedyRepeat' object is not iterable" + self.assertEqual(bool(regex.compile(r"^((?>\w+)|(?>\s+))*$", + flags=regex.V1)), True) + + # Hg issue 31: atomic and normal groups in recursive patterns + self.assertEqual(regex.findall(r"\((?:(?>[^()]+)|(?R))*\)", + "a(bcd(e)f)g(h)"), ['(bcd(e)f)', '(h)']) + self.assertEqual(regex.findall(r"\((?:(?:[^()]+)|(?R))*\)", + "a(bcd(e)f)g(h)"), ['(bcd(e)f)', '(h)']) + self.assertEqual(regex.findall(r"\((?:(?>[^()]+)|(?R))*\)", + "a(b(cd)e)f)g)h"), ['(b(cd)e)']) + self.assertEqual(regex.findall(r"\((?:(?>[^()]+)|(?R))*\)", + "a(bc(d(e)f)gh"), ['(d(e)f)']) + self.assertEqual(regex.findall(r"(?r)\((?:(?>[^()]+)|(?R))*\)", + "a(bc(d(e)f)gh"), ['(d(e)f)']) + self.assertEqual([m.group() for m in + regex.finditer(r"\((?:[^()]*+|(?0))*\)", "a(b(c(de)fg)h")], + ['(c(de)fg)']) + + # Hg issue 32: regex.search("a(bc)d", "abcd", regex.I|regex.V1) returns + # None + self.assertEqual(regex.search("a(bc)d", "abcd", regex.I | + regex.V1).group(0), "abcd") + + # Hg issue 33: regex.search("([\da-f:]+)$", "E", regex.I|regex.V1) + # returns None + self.assertEqual(regex.search(r"([\da-f:]+)$", "E", regex.I | + regex.V1).group(0), "E") + self.assertEqual(regex.search(r"([\da-f:]+)$", "e", regex.I | + regex.V1).group(0), "e") + + # Hg issue 34: regex.search("^(?=ab(de))(abd)(e)", "abde").groups() + # returns (None, 'abd', 'e') instead of ('de', 'abd', 'e') + self.assertEqual(regex.search("^(?=ab(de))(abd)(e)", "abde").groups(), + ('de', 'abd', 'e')) + + # Hg issue 35: regex.compile("\ ", regex.X) causes "_regex_core.error: + # bad escape" + self.assertEqual(bool(regex.match(r"\ ", " ", flags=regex.X)), True) + + # Hg issue 36: regex.search("^(a|)\1{2}b", "b") returns None + self.assertEqual(regex.search(r"^(a|)\1{2}b", "b").group(0, 1), ('b', + '')) + + # Hg issue 37: regex.search("^(a){0,0}", "abc").group(0,1) returns + # ('a', 'a') instead of ('', None) + self.assertEqual(regex.search("^(a){0,0}", "abc").group(0, 1), ('', + None)) + + # Hg issue 38: regex.search("(?>.*/)b", "a/b") returns None + self.assertEqual(regex.search("(?>.*/)b", "a/b").group(0), "a/b") + + # Hg issue 39: regex.search("((?i)blah)\\s+\\1", "blah BLAH") doesn't + # return None + # Changed to positional flags in regex 2023.12.23. 
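+        # Since flags became positional, the (?i) inside the group below no
+        # longer reaches the backreference outside it; with a leading (?i)
+        # the backreference is case-insensitive (illustrative aside, not
+        # from the tracker):
+        self.assertEqual(bool(regex.search(r"(?i)(blah)\s+\1",
+          "blah BLAH")), True)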
+ self.assertEqual(regex.search(r"((?i)blah)\s+\1", "blah BLAH"), None) + + # Hg issue 40: regex.search("(\()?[^()]+(?(1)\)|)", "(abcd").group(0) + # returns "bcd" instead of "abcd" + self.assertEqual(regex.search(r"(\()?[^()]+(?(1)\)|)", + "(abcd").group(0), "abcd") + + # Hg issue 42: regex.search("(a*)*", "a", flags=regex.V1).span(1) + # returns (0, 1) instead of (1, 1) + self.assertEqual(regex.search("(a*)*", "a").span(1), (1, 1)) + self.assertEqual(regex.search("(a*)*", "aa").span(1), (2, 2)) + self.assertEqual(regex.search("(a*)*", "aaa").span(1), (3, 3)) + + # Hg issue 43: regex.compile("a(?#xxx)*") causes "_regex_core.error: + # nothing to repeat" + self.assertEqual(regex.search("a(?#xxx)*", "aaa").group(), "aaa") + + # Hg issue 44: regex.compile("(?=abc){3}abc") causes + # "_regex_core.error: nothing to repeat" + self.assertEqual(regex.search("(?=abc){3}abc", "abcabcabc").span(), (0, + 3)) + + # Hg issue 45: regex.compile("^(?:a(?:(?:))+)+") causes + # "_regex_core.error: nothing to repeat" + self.assertEqual(regex.search("^(?:a(?:(?:))+)+", "a").span(), (0, 1)) + self.assertEqual(regex.search("^(?:a(?:(?:))+)+", "aa").span(), (0, 2)) + + # Hg issue 46: regex.compile("a(?x: b c )d") causes + # "_regex_core.error: missing )" + self.assertEqual(regex.search("a(?x: b c )d", "abcd").group(0), "abcd") + + # Hg issue 47: regex.compile("a#comment\n*", flags=regex.X) causes + # "_regex_core.error: nothing to repeat" + self.assertEqual(regex.search("a#comment\n*", "aaa", + flags=regex.X).group(0), "aaa") + + # Hg issue 48: regex.search("(a(?(1)\\1)){4}", "a"*10, + # flags=regex.V1).group(0,1) returns ('aaaaa', 'a') instead of ('aaaaaaaaaa', 'aaaa') + self.assertEqual(regex.search(r"(?V1)(a(?(1)\1)){1}", + "aaaaaaaaaa").span(0, 1), ((0, 1), (0, 1))) + self.assertEqual(regex.search(r"(?V1)(a(?(1)\1)){2}", + "aaaaaaaaaa").span(0, 1), ((0, 3), (1, 3))) + self.assertEqual(regex.search(r"(?V1)(a(?(1)\1)){3}", + "aaaaaaaaaa").span(0, 1), ((0, 6), (3, 6))) + self.assertEqual(regex.search(r"(?V1)(a(?(1)\1)){4}", + "aaaaaaaaaa").span(0, 1), ((0, 10), (6, 10))) + + # Hg issue 49: regex.search("(a)(?<=b(?1))", "baz", regex.V1) returns + # None incorrectly + self.assertEqual(regex.search("(?V1)(a)(?<=b(?1))", "baz").group(0), + "a") + + # Hg issue 50: not all keywords are found by named list with + # overlapping keywords when full Unicode casefolding is required + self.assertEqual(regex.findall(r'(?fi)\L', + 'POST, Post, post, po\u017Ft, po\uFB06, and po\uFB05', + keywords=['post','pos']), ['POST', 'Post', 'post', 'po\u017Ft', + 'po\uFB06', 'po\uFB05']) + self.assertEqual(regex.findall(r'(?fi)pos|post', + 'POST, Post, post, po\u017Ft, po\uFB06, and po\uFB05'), ['POS', + 'Pos', 'pos', 'po\u017F', 'po\uFB06', 'po\uFB05']) + self.assertEqual(regex.findall(r'(?fi)post|pos', + 'POST, Post, post, po\u017Ft, po\uFB06, and po\uFB05'), ['POST', + 'Post', 'post', 'po\u017Ft', 'po\uFB06', 'po\uFB05']) + self.assertEqual(regex.findall(r'(?fi)post|another', + 'POST, Post, post, po\u017Ft, po\uFB06, and po\uFB05'), ['POST', + 'Post', 'post', 'po\u017Ft', 'po\uFB06', 'po\uFB05']) + + # Hg issue 51: regex.search("((a)(?1)|(?2))", "a", flags=regex.V1) + # returns None incorrectly + self.assertEqual(regex.search("(?V1)((a)(?1)|(?2))", "a").group(0, 1, + 2), ('a', 'a', None)) + + # Hg issue 52: regex.search("(\\1xx|){6}", "xx", + # flags=regex.V1).span(0,1) returns incorrect value + self.assertEqual(regex.search(r"(?V1)(\1xx|){6}", "xx").span(0, 1), + ((0, 2), (2, 2))) + + # Hg issue 53: regex.search("(a|)+", "a") causes 
+        # Hg issue 53: regex.search("(a|)+", "a") causes MemoryError
+        self.assertEqual(regex.search("(a|)+", "a").group(0, 1), ("a", ""))
+
+        # Hg issue 54: regex.search("(a|)*\\d", "a"*80) causes MemoryError
+        self.assertEqual(regex.search(r"(a|)*\d", "a" * 80), None)
+
+        # Hg issue 55: regex.search("^(?:a?b?)*$", "ac") takes a very long
+        # time.
+        self.assertEqual(regex.search("^(?:a?b?)*$", "ac"), None)
+
+        # Hg issue 58: bad named character escape sequences like "\\N{1}"
+        # treated as "N"
+        self.assertRaisesRegex(regex.error, self.UNDEF_CHAR_NAME, lambda:
+          regex.compile("\\N{1}"))
+
+        # Hg issue 59: regex.search("\\Z", "a\na\n") returns None incorrectly
+        self.assertEqual(regex.search("\\Z", "a\na\n").span(0), (4, 4))
+
+        # Hg issue 60: regex.search("(q1|.)*(q2|.)*(x(a|bc)*y){2,}", "xayxay")
+        # returns None incorrectly
+        self.assertEqual(regex.search("(q1|.)*(q2|.)*(x(a|bc)*y){2,}",
+          "xayxay").group(0), "xayxay")
+
+        # Hg issue 61: regex.search("[^a]", "A", regex.I).group(0) returns ''
+        # incorrectly
+        self.assertEqual(regex.search("(?i)[^a]", "A"), None)
+
+        # Hg issue 63: regex.search("[[:ascii:]]", "\N{KELVIN SIGN}",
+        # flags=regex.I|regex.V1) doesn't return None
+        self.assertEqual(regex.search("(?i)[[:ascii:]]", "\N{KELVIN SIGN}"),
+          None)
+
+        # Hg issue 66: regex.search("((a|b(?1)c){3,5})", "baaaaca",
+        # flags=regex.V1).groups() returns ('baaaac', 'baaaac') instead of
+        # ('aaaa', 'a')
+        self.assertEqual(regex.search("((a|b(?1)c){3,5})", "baaaaca").group(0,
+          1, 2), ('aaaa', 'aaaa', 'a'))
+
+        # Hg issue 71: non-greedy quantifier in lookbehind
+        self.assertEqual(regex.findall(r"(?<=:\S+ )\w+", ":9 abc :10 def"),
+          ['abc', 'def'])
+        self.assertEqual(regex.findall(r"(?<=:\S* )\w+", ":9 abc :10 def"),
+          ['abc', 'def'])
+        self.assertEqual(regex.findall(r"(?<=:\S+? )\w+", ":9 abc :10 def"),
+          ['abc', 'def'])
+        self.assertEqual(regex.findall(r"(?<=:\S*? )\w+", ":9 abc :10 def"),
+          ['abc', 'def'])
+
+        # Hg issue 73: conditional patterns
+        self.assertEqual(regex.search(r"(?:fe)?male", "female").group(),
+          "female")
+        self.assertEqual([m.group() for m in
+          regex.finditer(r"(fe)?male: h(?(1)(er)|(is)) (\w+)",
+          "female: her dog; male: his cat. asdsasda")], ['female: her dog',
+          'male: his cat'])
+
+        # Hg issue 78: "Captures" doesn't work for recursive calls
+        self.assertEqual(regex.search(r'(?<rec>\((?:[^()]++|(?&rec))*\))',
+          'aaa(((1+0)+1)+1)bbb').captures('rec'), ['(1+0)', '((1+0)+1)',
+          '(((1+0)+1)+1)'])
+
+        # Hg issue 80: Escape characters throws an exception
+        self.assertRaisesRegex(regex.error, self.BAD_ESCAPE, lambda:
+          regex.sub('x', '\\', 'x'), )
+
+        # Hg issue 82: error range does not work
+        fz = "(CAGCCTCCCATTTCAGAATATACATCC){1<e<=2}"
+
+        self.assertEqual(regex.search(r'(?<x>a(?<x>b))', "ab").spans("x"),
+          [(1, 2), (0, 2)])
+
+        # Hg issue 91: match.expand is extremely slow
+        # Check that the replacement cache works.
+ self.assertEqual(regex.sub(r'(-)', lambda m: m.expand(r'x'), 'a-b-c'), + 'axbxc') + + # Hg issue 94: Python crashes when executing regex updates + # pattern.findall + rx = regex.compile(r'\bt(est){i<2}', flags=regex.V1) + self.assertEqual(rx.search("Some text"), None) + self.assertEqual(rx.findall("Some text"), []) + + # Hg issue 95: 'pos' for regex.error + self.assertRaisesRegex(regex.error, self.MULTIPLE_REPEAT, lambda: + regex.compile(r'.???')) + + # Hg issue 97: behaviour of regex.escape's special_only is wrong + # + # Hg issue 244: Make `special_only=True` the default in + # `regex.escape()` + self.assertEqual(regex.escape('foo!?', special_only=False), 'foo\\!\\?') + self.assertEqual(regex.escape('foo!?', special_only=True), 'foo!\\?') + self.assertEqual(regex.escape('foo!?'), 'foo!\\?') + + self.assertEqual(regex.escape(b'foo!?', special_only=False), b'foo\\!\\?') + self.assertEqual(regex.escape(b'foo!?', special_only=True), + b'foo!\\?') + self.assertEqual(regex.escape(b'foo!?'), b'foo!\\?') + + # Hg issue 100: strange results from regex.search + self.assertEqual(regex.search('^([^z]*(?:WWWi|W))?$', + 'WWWi').groups(), ('WWWi', )) + self.assertEqual(regex.search('^([^z]*(?:WWWi|w))?$', + 'WWWi').groups(), ('WWWi', )) + self.assertEqual(regex.search('^([^z]*?(?:WWWi|W))?$', + 'WWWi').groups(), ('WWWi', )) + + # Hg issue 101: findall() broken (seems like memory corruption) + pat = regex.compile(r'xxx', flags=regex.FULLCASE | regex.UNICODE) + self.assertEqual([x.group() for x in pat.finditer('yxxx')], ['xxx']) + self.assertEqual(pat.findall('yxxx'), ['xxx']) + + raw = 'yxxx' + self.assertEqual([x.group() for x in pat.finditer(raw)], ['xxx']) + self.assertEqual(pat.findall(raw), ['xxx']) + + pat = regex.compile(r'xxx', flags=regex.FULLCASE | regex.IGNORECASE | + regex.UNICODE) + self.assertEqual([x.group() for x in pat.finditer('yxxx')], ['xxx']) + self.assertEqual(pat.findall('yxxx'), ['xxx']) + + raw = 'yxxx' + self.assertEqual([x.group() for x in pat.finditer(raw)], ['xxx']) + self.assertEqual(pat.findall(raw), ['xxx']) + + # Hg issue 106: * operator not working correctly with sub() + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.sub('(?V0).*', 'x', 'test'), 'xx') + else: + self.assertEqual(regex.sub('(?V0).*', 'x', 'test'), 'x') + self.assertEqual(regex.sub('(?V1).*', 'x', 'test'), 'xx') + + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.sub('(?V0).*?', '|', 'test'), '|||||||||') + else: + self.assertEqual(regex.sub('(?V0).*?', '|', 'test'), '|t|e|s|t|') + self.assertEqual(regex.sub('(?V1).*?', '|', 'test'), '|||||||||') + + # Hg issue 112: re: OK, but regex: SystemError + self.assertEqual(regex.sub(r'^(@)\n(?!.*?@)(.*)', + r'\1\n==========\n\2', '@\n', flags=regex.DOTALL), '@\n==========\n') + + # Hg issue 109: Edit distance of fuzzy match + self.assertEqual(regex.match(r'(?:cats|cat){e<=1}', + 'caz').fuzzy_counts, (1, 0, 0)) + self.assertEqual(regex.match(r'(?e)(?:cats|cat){e<=1}', + 'caz').fuzzy_counts, (1, 0, 0)) + self.assertEqual(regex.match(r'(?b)(?:cats|cat){e<=1}', + 'caz').fuzzy_counts, (1, 0, 0)) + + self.assertEqual(regex.match(r'(?:cat){e<=1}', 'caz').fuzzy_counts, + (1, 0, 0)) + self.assertEqual(regex.match(r'(?e)(?:cat){e<=1}', + 'caz').fuzzy_counts, (1, 0, 0)) + self.assertEqual(regex.match(r'(?b)(?:cat){e<=1}', + 'caz').fuzzy_counts, (1, 0, 0)) + + self.assertEqual(regex.match(r'(?:cats){e<=2}', 'c ats').fuzzy_counts, + (1, 1, 0)) + self.assertEqual(regex.match(r'(?e)(?:cats){e<=2}', + 'c ats').fuzzy_counts, (0, 1, 0)) + 
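+        # Illustrative aside (not from the tracker): fuzzy_counts is the
+        # tuple (substitutions, insertions, deletions), so an exact match
+        # reports (0, 0, 0):
+        self.assertEqual(regex.match(r'(?:cats){e<=2}', 'cats').fuzzy_counts,
+          (0, 0, 0))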
self.assertEqual(regex.match(r'(?b)(?:cats){e<=2}', + 'c ats').fuzzy_counts, (0, 1, 0)) + + self.assertEqual(regex.match(r'(?:cats){e<=2}', + 'c a ts').fuzzy_counts, (0, 2, 0)) + self.assertEqual(regex.match(r'(?e)(?:cats){e<=2}', + 'c a ts').fuzzy_counts, (0, 2, 0)) + self.assertEqual(regex.match(r'(?b)(?:cats){e<=2}', + 'c a ts').fuzzy_counts, (0, 2, 0)) + + self.assertEqual(regex.match(r'(?:cats){e<=1}', 'c ats').fuzzy_counts, + (0, 1, 0)) + self.assertEqual(regex.match(r'(?e)(?:cats){e<=1}', + 'c ats').fuzzy_counts, (0, 1, 0)) + self.assertEqual(regex.match(r'(?b)(?:cats){e<=1}', + 'c ats').fuzzy_counts, (0, 1, 0)) + + # Hg issue 115: Infinite loop when processing backreferences + self.assertEqual(regex.findall(r'\bof ([a-z]+) of \1\b', + 'To make use of one of these modules'), []) + + # Hg issue 125: Reference to entire match (\g<0>) in + # Pattern.sub() doesn't work as of 2014.09.22 release. + self.assertEqual(regex.sub(r'x', r'\g<0>', 'x'), 'x') + + # Unreported issue: no such builtin as 'ascii' in Python 2. + self.assertEqual(bool(regex.match(r'a', 'a', regex.DEBUG)), True) + + # Hg issue 131: nested sets behaviour + self.assertEqual(regex.findall(r'(?V1)[[b-e]--cd]', 'abcdef'), ['b', + 'e']) + self.assertEqual(regex.findall(r'(?V1)[b-e--cd]', 'abcdef'), ['b', + 'e']) + self.assertEqual(regex.findall(r'(?V1)[[bcde]--cd]', 'abcdef'), ['b', + 'e']) + self.assertEqual(regex.findall(r'(?V1)[bcde--cd]', 'abcdef'), ['b', + 'e']) + + # Hg issue 132: index out of range on null property \p{} + self.assertRaisesRegex(regex.error, '^unknown property at position 4$', + lambda: regex.compile(r'\p{}')) + + # Issue 23692. + self.assertEqual(regex.match('(?:()|(?(1)()|z)){2}(?(2)a|z)', + 'a').group(0, 1, 2), ('a', '', '')) + self.assertEqual(regex.match('(?:()|(?(1)()|z)){0,2}(?(2)a|z)', + 'a').group(0, 1, 2), ('a', '', '')) + + # Hg issue 137: Posix character class :punct: does not seem to be + # supported. + + # Posix compatibility as recommended here: + # http://www.unicode.org/reports/tr18/#Compatibility_Properties + + # Posix in Unicode. 
+ chars = ''.join(chr(c) for c in range(0x10000)) + + self.assertEqual(ascii(''.join(regex.findall(r'''[[:alnum:]]+''', + chars))), ascii(''.join(regex.findall(r'''[\p{Alpha}\p{PosixDigit}]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:alpha:]]+''', + chars))), ascii(''.join(regex.findall(r'''\p{Alpha}+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:ascii:]]+''', + chars))), ascii(''.join(regex.findall(r'''[\p{InBasicLatin}]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:blank:]]+''', + chars))), ascii(''.join(regex.findall(r'''[\p{gc=Space_Separator}\t]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:cntrl:]]+''', + chars))), ascii(''.join(regex.findall(r'''\p{gc=Control}+''', chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:digit:]]+''', + chars))), ascii(''.join(regex.findall(r'''[0-9]+''', chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:graph:]]+''', + chars))), ascii(''.join(regex.findall(r'''[^\p{Space}\p{gc=Control}\p{gc=Surrogate}\p{gc=Unassigned}]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:lower:]]+''', + chars))), ascii(''.join(regex.findall(r'''\p{Lower}+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:print:]]+''', + chars))), ascii(''.join(regex.findall(r'''(?V1)[\p{Graph}\p{Blank}--\p{Cntrl}]+''', chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:punct:]]+''', + chars))), + ascii(''.join(regex.findall(r'''(?V1)[\p{gc=Punctuation}\p{gc=Symbol}--\p{Alpha}]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:space:]]+''', + chars))), ascii(''.join(regex.findall(r'''\p{Whitespace}+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:upper:]]+''', + chars))), ascii(''.join(regex.findall(r'''\p{Upper}+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:word:]]+''', + chars))), ascii(''.join(regex.findall(r'''[\p{Alpha}\p{gc=Mark}\p{Digit}\p{gc=Connector_Punctuation}\p{Join_Control}]+''', + chars)))) + self.assertEqual(ascii(''.join(regex.findall(r'''[[:xdigit:]]+''', + chars))), ascii(''.join(regex.findall(r'''[0-9A-Fa-f]+''', + chars)))) + + # Posix in ASCII. 
+ chars = bytes(range(0x100)) + + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:alnum:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[\p{Alpha}\p{PosixDigit}]+''', + chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:alpha:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)\p{Alpha}+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:ascii:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[\x00-\x7F]+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:blank:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[\p{gc=Space_Separator}\t]+''', + chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:cntrl:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)\p{gc=Control}+''', + chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:digit:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[0-9]+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:graph:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[^\p{Space}\p{gc=Control}\p{gc=Surrogate}\p{gc=Unassigned}]+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:lower:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)\p{Lower}+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:print:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?aV1)[\p{Graph}\p{Blank}--\p{Cntrl}]+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:punct:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?aV1)[\p{gc=Punctuation}\p{gc=Symbol}--\p{Alpha}]+''', + chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:space:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)\p{Whitespace}+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:upper:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)\p{Upper}+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:word:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[\p{Alpha}\p{gc=Mark}\p{Digit}\p{gc=Connector_Punctuation}\p{Join_Control}]+''', chars)))) + self.assertEqual(ascii(b''.join(regex.findall(br'''(?a)[[:xdigit:]]+''', + chars))), ascii(b''.join(regex.findall(br'''(?a)[0-9A-Fa-f]+''', chars)))) + + # Hg issue 138: grapheme anchored search not working properly. + self.assertEqual(ascii(regex.search(r'\X$', 'ab\u2103').group()), + ascii('\u2103')) + + # Hg issue 139: Regular expression with multiple wildcards where first + # should match empty string does not always work. + self.assertEqual(regex.search("([^L]*)([^R]*R)", "LtR").groups(), ('', + 'LtR')) + + # Hg issue 140: Replace with REVERSE and groups has unexpected + # behavior. + self.assertEqual(regex.sub(r'(.)', r'x\1y', 'ab'), 'xayxby') + self.assertEqual(regex.sub(r'(?r)(.)', r'x\1y', 'ab'), 'xayxby') + self.assertEqual(regex.subf(r'(.)', 'x{1}y', 'ab'), 'xayxby') + self.assertEqual(regex.subf(r'(?r)(.)', 'x{1}y', 'ab'), 'xayxby') + + # Hg issue 141: Crash on a certain partial match. + self.assertEqual(regex.fullmatch('(a)*abc', 'ab', + partial=True).span(), (0, 2)) + self.assertEqual(regex.fullmatch('(a)*abc', 'ab', + partial=True).partial, True) + + # Hg issue 143: Partial matches have incorrect span if prefix is '.' + # wildcard. 
        self.assertEqual(regex.search('OXRG', 'OOGOX', partial=True).span(),
+          (3, 5))
+        self.assertEqual(regex.search('.XRG', 'OOGOX', partial=True).span(),
+          (3, 5))
+        self.assertEqual(regex.search('.{1,3}XRG', 'OOGOX',
+          partial=True).span(), (1, 5))
+
+        # Hg issue 144: Latest version problem with matching 'R|R'.
+        self.assertEqual(regex.match('R|R', 'R').span(), (0, 1))
+
+        # Hg issue 146: Forced-fail (?!) works improperly in conditional.
+        self.assertEqual(regex.match(r'(.)(?(1)(?!))', 'xy'), None)
+
+        # Groups cleared after failure.
+        self.assertEqual(regex.findall(r'(y)?(\d)(?(1)\b\B)', 'ax1y2z3b'),
+          [('', '1'), ('', '2'), ('', '3')])
+        self.assertEqual(regex.findall(r'(y)?+(\d)(?(1)\b\B)', 'ax1y2z3b'),
+          [('', '1'), ('', '2'), ('', '3')])
+
+        # Hg issue 147: Fuzzy match can return match points beyond buffer end.
+        self.assertEqual([m.span() for m in regex.finditer(r'(?i)(?:error){e}',
+          'regex failure')], [(0, 5), (5, 10), (10, 13), (13, 13)])
+        self.assertEqual([m.span() for m in
+          regex.finditer(r'(?fi)(?:error){e}', 'regex failure')], [(0, 5), (5,
+          10), (10, 13), (13, 13)])
+
+        # Hg issue 150: Have an option for POSIX-compatible longest match of
+        # alternates.
+        self.assertEqual(regex.search(r'(?p)\d+(\w(\d*)?|[eE]([+-]\d+))',
+          '10b12')[0], '10b12')
+        self.assertEqual(regex.search(r'(?p)\d+(\w(\d*)?|[eE]([+-]\d+))',
+          '10E+12')[0], '10E+12')
+
+        self.assertEqual(regex.search(r'(?p)(\w|ae|oe|ue|ss)', 'ae')[0], 'ae')
+        self.assertEqual(regex.search(r'(?p)one(self)?(selfsufficient)?',
+          'oneselfsufficient')[0], 'oneselfsufficient')
+
+        # Hg issue 151: Request: \K.
+        self.assertEqual(regex.search(r'(ab\Kcd)', 'abcd').group(0, 1), ('cd',
+          'abcd'))
+        self.assertEqual(regex.findall(r'\w\w\K\w\w', 'abcdefgh'), ['cd',
+          'gh'])
+        self.assertEqual(regex.findall(r'(\w\w\K\w\w)', 'abcdefgh'), ['abcd',
+          'efgh'])
+
+        self.assertEqual(regex.search(r'(?r)(ab\Kcd)', 'abcd').group(0, 1),
+          ('ab', 'abcd'))
+        self.assertEqual(regex.findall(r'(?r)\w\w\K\w\w', 'abcdefgh'), ['ef',
+          'ab'])
+        self.assertEqual(regex.findall(r'(?r)(\w\w\K\w\w)', 'abcdefgh'),
+          ['efgh', 'abcd'])
+
+        # Hg issue 152: Request: (?(DEFINE)...).
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<quant>\d+)(?<item>\w+))(?&quant) (?&item)',
+          '5 elephants')[0], '5 elephants')
+
+        self.assertEqual(regex.search(r'(?&routine)(?(DEFINE)(?<routine>.))',
+          'a').group('routine'), None)
+        self.assertEqual(regex.search(r'(?&routine)(?(DEFINE)(?<routine>.))',
+          'a').captures('routine'), ['a'])
+
+        # Hg issue 153: Request: (*SKIP).
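+        # As exercised below: (*FAIL) always fails, (*PRUNE) discards the
+        # backtracking positions to its left, and (*SKIP) additionally
+        # prevents retrying from before the point where it was reached
+        # (illustrative summary).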
+ self.assertEqual(regex.search(r'12(*FAIL)|3', '123')[0], '3') + self.assertEqual(regex.search(r'(?r)12(*FAIL)|3', '123')[0], '3') + + self.assertEqual(regex.search(r'\d+(*PRUNE)\d', '123'), None) + self.assertEqual(regex.search(r'\d+(?=(*PRUNE))\d', '123')[0], '123') + self.assertEqual(regex.search(r'\d+(*PRUNE)bcd|[3d]', '123bcd')[0], + '123bcd') + self.assertEqual(regex.search(r'\d+(*PRUNE)bcd|[3d]', '123zzd')[0], + 'd') + self.assertEqual(regex.search(r'\d+?(*PRUNE)bcd|[3d]', '123bcd')[0], + '3bcd') + self.assertEqual(regex.search(r'\d+?(*PRUNE)bcd|[3d]', '123zzd')[0], + 'd') + self.assertEqual(regex.search(r'\d++(?<=3(*PRUNE))zzd|[4d]$', + '123zzd')[0], '123zzd') + self.assertEqual(regex.search(r'\d++(?<=3(*PRUNE))zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'\d++(?<=(*PRUNE)3)zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'\d++(?<=2(*PRUNE)3)zzd|[3d]$', + '124zzd')[0], 'd') + + self.assertEqual(regex.search(r'(?r)\d(*PRUNE)\d+', '123'), None) + self.assertEqual(regex.search(r'(?r)\d(?<=(*PRUNE))\d+', '123')[0], + '123') + self.assertEqual(regex.search(r'(?r)\d+(*PRUNE)bcd|[3d]', + '123bcd')[0], '123bcd') + self.assertEqual(regex.search(r'(?r)\d+(*PRUNE)bcd|[3d]', + '123zzd')[0], 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=3(*PRUNE))zzd|[4d]$', + '123zzd')[0], '123zzd') + self.assertEqual(regex.search(r'(?r)\d++(?<=3(*PRUNE))zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=(*PRUNE)3)zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=2(*PRUNE)3)zzd|[3d]$', + '124zzd')[0], 'd') + + self.assertEqual(regex.search(r'\d+(*SKIP)bcd|[3d]', '123bcd')[0], + '123bcd') + self.assertEqual(regex.search(r'\d+(*SKIP)bcd|[3d]', '123zzd')[0], + 'd') + self.assertEqual(regex.search(r'\d+?(*SKIP)bcd|[3d]', '123bcd')[0], + '3bcd') + self.assertEqual(regex.search(r'\d+?(*SKIP)bcd|[3d]', '123zzd')[0], + 'd') + self.assertEqual(regex.search(r'\d++(?<=3(*SKIP))zzd|[4d]$', + '123zzd')[0], '123zzd') + self.assertEqual(regex.search(r'\d++(?<=3(*SKIP))zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'\d++(?<=(*SKIP)3)zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'\d++(?<=2(*SKIP)3)zzd|[3d]$', + '124zzd')[0], 'd') + + self.assertEqual(regex.search(r'(?r)\d+(*SKIP)bcd|[3d]', '123bcd')[0], + '123bcd') + self.assertEqual(regex.search(r'(?r)\d+(*SKIP)bcd|[3d]', '123zzd')[0], + 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=3(*SKIP))zzd|[4d]$', + '123zzd')[0], '123zzd') + self.assertEqual(regex.search(r'(?r)\d++(?<=3(*SKIP))zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=(*SKIP)3)zzd|[4d]$', + '124zzd')[0], 'd') + self.assertEqual(regex.search(r'(?r)\d++(?<=2(*SKIP)3)zzd|[3d]$', + '124zzd')[0], 'd') + + # Hg issue 154: Segmentation fault 11 when working with an atomic group + text = """June 30, December 31, 2013 2012 +some words follow: +more words and numbers 1,234,567 9,876,542 +more words and numbers 1,234,567 9,876,542""" + self.assertEqual(len(regex.findall(r'(?2014|2013 ?2012)', text)), 1) + + # Hg issue 156: regression on atomic grouping + self.assertEqual(regex.match('1(?>2)', '12').span(), (0, 2)) + + # Hg issue 157: regression: segfault on complex lookaround + self.assertEqual(regex.match(r'(?V1w)(?=(?=[^A-Z]*+[A-Z])(?=[^a-z]*+[a-z]))(?=\D*+\d)(?=\p{Alphanumeric}*+\P{Alphanumeric})\A(?s:.){8,255}+\Z', + 'AAaa11!!')[0], 'AAaa11!!') + + # Hg issue 158: Group issue with (?(DEFINE)...) 
+        TEST_REGEX = regex.compile(r'''(?smx)
+(?(DEFINE)
+    (?<subcat>
+    ^,[^,]+,
+    )
+)
+
+# Group 2 is defined on this line
+^,([^,]+),
+
+(?:(?!(?&subcat)[\r\n]+(?&subcat)).)+
+''')
+
+        TEST_DATA = '''
+,Cat 1,
+,Brand 1,
+some
+thing
+,Brand 2,
+other
+things
+,Cat 2,
+,Brand,
+Some
+thing
+'''
+
+        self.assertEqual([m.span(1, 2) for m in
+          TEST_REGEX.finditer(TEST_DATA)], [((-1, -1), (2, 7)), ((-1, -1), (54,
+          59))])
+
+        # Hg issue 161: Unexpected fuzzy match results
+        self.assertEqual(regex.search('(abcdefgh){e}',
+          '******abcdefghijklmnopqrtuvwxyz', regex.BESTMATCH).span(), (6, 14))
+        self.assertEqual(regex.search('(abcdefghi){e}',
+          '******abcdefghijklmnopqrtuvwxyz', regex.BESTMATCH).span(), (6, 15))
+
+        # Hg issue 163: allow lookarounds in conditionals.
+        self.assertEqual(regex.match(r'(?:(?=\d)\d+\b|\w+)', '123abc').span(),
+          (0, 6))
+        self.assertEqual(regex.match(r'(?(?=\d)\d+\b|\w+)', '123abc'), None)
+        self.assertEqual(regex.search(r'(?(?<=love\s)you|(?<=hate\s)her)',
+          "I love you").span(), (7, 10))
+        self.assertEqual(regex.findall(r'(?(?<=love\s)you|(?<=hate\s)her)',
+          "I love you but I don't hate her either"), ['you', 'her'])
+
+        # Hg issue 180: bug of POSIX matching.
+        self.assertEqual(regex.search(r'(?p)a*(.*?)', 'aaabbb').group(0, 1),
+          ('aaabbb', 'bbb'))
+        self.assertEqual(regex.search(r'(?p)a*(.*)', 'aaabbb').group(0, 1),
+          ('aaabbb', 'bbb'))
+        self.assertEqual(regex.sub(r'(?p)a*(.*?)', r'\1', 'aaabbb'), 'bbb')
+        self.assertEqual(regex.sub(r'(?p)a*(.*)', r'\1', 'aaabbb'), 'bbb')
+
+        # Hg issue 192: Named lists reverse matching doesn't work with
+        # IGNORECASE and V1
+        self.assertEqual(regex.match(r'(?irV0)\L<kw>', '21', kw=['1']).span(),
+          (1, 2))
+        self.assertEqual(regex.match(r'(?irV1)\L<kw>', '21', kw=['1']).span(),
+          (1, 2))
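+        # \L<name> is regex's named-list syntax: the literal alternatives are
+        # supplied as a keyword argument of the same name (illustrative
+        # sketch, not part of the original suite):
+        self.assertEqual(bool(regex.match(r'\L<words>', 'post',
+          words=['post', 'pos'])), True)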
+        # Hg issue 193: Alternation and .REVERSE flag.
+        self.assertEqual(regex.search('a|b', '111a222').span(), (3, 4))
+        self.assertEqual(regex.search('(?r)a|b', '111a222').span(), (3, 4))
+
+        # Hg issue 194: .FULLCASE and Backreference
+        self.assertEqual(regex.search(r'(?if)<(CLI)><\1>',
+          '<cli><CLI>').span(), (0, 10))
+        self.assertEqual(regex.search(r'(?if)<(CLI)><\1>',
+          '<CLI><cli>').span(), (0, 10))
+        self.assertEqual(regex.search(r'(?ifr)<\1><(CLI)>',
+          '<cli><CLI>').span(), (0, 10))
+
+        # Hg issue 195: Pickle (or otherwise serial) the compiled regex
+        r = regex.compile(r'\L<options>', options=['foo', 'bar'])
+        p = pickle.dumps(r)
+        r = pickle.loads(p)
+        self.assertEqual(r.match('foo').span(), (0, 3))
+
+        # Hg issue 196: Fuzzy matching on repeated regex not working as
+        # expected
+        self.assertEqual(regex.match('(x{6}){e<=1}', 'xxxxxx',
+          flags=regex.BESTMATCH).span(), (0, 6))
+        self.assertEqual(regex.match('(x{6}){e<=1}', 'xxxxx',
+          flags=regex.BESTMATCH).span(), (0, 5))
+        self.assertEqual(regex.match('(x{6}){e<=1}', 'x',
+          flags=regex.BESTMATCH), None)
+        self.assertEqual(regex.match('(?r)(x{6}){e<=1}', 'xxxxxx',
+          flags=regex.BESTMATCH).span(), (0, 6))
+        self.assertEqual(regex.match('(?r)(x{6}){e<=1}', 'xxxxx',
+          flags=regex.BESTMATCH).span(), (0, 5))
+        self.assertEqual(regex.match('(?r)(x{6}){e<=1}', 'x',
+          flags=regex.BESTMATCH), None)
+
+        # Hg issue 197: ValueError in regex.compile
+        self.assertRaises(regex.error, lambda:
+          regex.compile(b'00000\\0\\00\\^\50\\00\\U05000000'))
+
+        # Hg issue 198: ValueError in regex.compile
+        self.assertRaises(regex.error, lambda: regex.compile(b"{e<"))
+
+        self.assertEqual(bool(regex.search(r'(?i)\L<aa>', '22', aa=['121',
+          '22'])), True)
+        self.assertEqual(bool(regex.search(r'(?ri)\L<aa>', '22', aa=['121',
+          '22'])), True)
+        self.assertEqual(bool(regex.search(r'(?fi)\L<aa>', '22', aa=['121',
+          '22'])), True)
+        self.assertEqual(bool(regex.search(r'(?fri)\L<aa>', '22', aa=['121',
+          '22'])), True)
+
+        # Hg issue 208: Named list, (?ri) flags, Backreference
+        self.assertEqual(regex.search(r'(?r)\1dog..(?<=(\L<aa>))$', 'ccdogcc',
+          aa=['bcb', 'cc']).span(), (0, 7))
+        self.assertEqual(regex.search(r'(?ir)\1dog..(?<=(\L<aa>))$',
+          'ccdogcc', aa=['bcb', 'cc']).span(), (0, 7))
+
+        # Hg issue 210: Fuzzy matching and Backreference
+        self.assertEqual(regex.search(r'(2)(?:\1{5}){e<=1}',
+          '3222212').span(), (1, 7))
+        self.assertEqual(regex.search(r'(\d)(?:\1{5}){e<=1}',
+          '3222212').span(), (1, 7))
+
+        # Hg issue 211: Segmentation fault with recursive matches and atomic
+        # groups
+        self.assertEqual(regex.match(r'''\A(?P<whole>(?>\((?&whole)\)|[+\-]))\Z''',
+          '((-))').span(), (0, 5))
+        self.assertEqual(regex.match(r'''\A(?P<whole>(?>\((?&whole)\)|[+\-]))\Z''',
+          '((-)+)'), None)
+
+        # Hg issue 212: Unexpected matching difference with .*? between re and
+        # regex
+        self.assertEqual(regex.match(r"x.*? (.).*\1(.*)\1",
+          'x |y| z|').span(), (0, 9))
+        self.assertEqual(regex.match(r"\.sr (.*?) 
(.)(.*)\2(.*)\2(.*)", + r'.sr h |||').span(), (0, 35)) + + # Hg issue 213: Segmentation Fault + a = '"\\xF9\\x80\\xAEqdz\\x95L\\xA7\\x89[\\xFE \\x91)\\xF9]\\xDB\'\\x99\\x09=\\x00\\xFD\\x98\\x22\\xDD\\xF1\\xB6\\xC3 Z\\xB6gv\\xA5x\\x93P\\xE1r\\x14\\x8Cv\\x0C\\xC0w\\x15r\\xFFc%" ' + py_regex_pattern = r'''(?P((?>(?"(?>\\.|[^\\"]+)+"|""|(?>'(?>\\.|[^\\']+)+')|''|(?>`(?>\\.|[^\\`]+)+`)|``)))) (?P((?>(?"(?>\\.|[^\\"]+)+"|""|(?>'(?>\\.|[^\\']+)+')|''|(?>`(?>\\.|[^\\`]+)+`)|``))))''' + self.assertEqual(bool(regex.search(py_regex_pattern, a)), False) + + # Hg Issue 216: Invalid match when using negative lookbehind and pipe + self.assertEqual(bool(regex.match('foo(?<=foo)', 'foo')), True) + self.assertEqual(bool(regex.match('foo(?.*\!\w*\:.*)|(?P.*))', + '!')), False) + + # Hg issue 220: Misbehavior of group capture with OR operand + self.assertEqual(regex.match(r'\w*(ea)\w*|\w*e(?!a)\w*', + 'easier').groups(), ('ea', )) + + # Hg issue 225: BESTMATCH in fuzzy match not working + self.assertEqual(regex.search('(^1234$){i,d}', '12234', + regex.BESTMATCH).span(), (0, 5)) + self.assertEqual(regex.search('(^1234$){i,d}', '12234', + regex.BESTMATCH).fuzzy_counts, (0, 1, 0)) + + self.assertEqual(regex.search('(^1234$){s,i,d}', '12234', + regex.BESTMATCH).span(), (0, 5)) + self.assertEqual(regex.search('(^1234$){s,i,d}', '12234', + regex.BESTMATCH).fuzzy_counts, (0, 1, 0)) + + # Hg issue 226: Error matching at start of string + self.assertEqual(regex.search('(^123$){s,i,d}', 'xxxxxxxx123', + regex.BESTMATCH).span(), (0, 11)) + self.assertEqual(regex.search('(^123$){s,i,d}', 'xxxxxxxx123', + regex.BESTMATCH).fuzzy_counts, (0, 8, 0)) + + # Hg issue 227: Incorrect behavior for ? operator with UNICODE + + # IGNORECASE + self.assertEqual(regex.search(r'a?yz', 'xxxxyz', flags=regex.FULLCASE | + regex.IGNORECASE).span(), (4, 6)) + + # Hg issue 230: Is it a bug of (?(DEFINE)...) + self.assertEqual(regex.findall(r'(?:(?![a-d]).)+', 'abcdefgh'), + ['efgh']) + self.assertEqual(regex.findall(r'''(?(DEFINE)(?P(?:(?![a-d]).)))(?&mydef)+''', + 'abcdefgh'), ['efgh']) + + # Hg issue 238: Not fully re backward compatible + self.assertEqual(regex.findall(r'((\w{1,3})(\.{2,10})){1,3}', + '"Erm....yes. T..T...Thank you for that."'), [('Erm....', 'Erm', + '....'), ('T...', 'T', '...')]) + self.assertEqual(regex.findall(r'((\w{1,3})(\.{2,10})){3}', + '"Erm....yes. T..T...Thank you for that."'), []) + self.assertEqual(regex.findall(r'((\w{1,3})(\.{2,10})){2}', + '"Erm....yes. T..T...Thank you for that."'), [('T...', 'T', '...')]) + self.assertEqual(regex.findall(r'((\w{1,3})(\.{2,10})){1}', + '"Erm....yes. 
T..T...Thank you for that."'), [('Erm....', 'Erm',
+          '....'), ('T..', 'T', '..'), ('T...', 'T', '...')])
+
+        # Hg issue 247: Unexpected result with fuzzy matching and lookahead
+        # expression
+        self.assertEqual(regex.search(r'(?:ESTONIA(?!\w)){e<=1}',
+          'ESTONIAN WORKERS').group(), 'ESTONIAN')
+        self.assertEqual(regex.search(r'(?:ESTONIA(?=\W)){e<=1}',
+          'ESTONIAN WORKERS').group(), 'ESTONIAN')
+
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?&func)',
+          'abc').groups(), (None, ))
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?&func)',
+          'abc').groupdict(), {'func': None})
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?&func)',
+          'abc').capturesdict(), {'func': ['a']})
+
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?=(?&func))',
+          'abc').groups(), (None, ))
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?=(?&func))',
+          'abc').groupdict(), {'func': None})
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.))(?=(?&func))',
+          'abc').capturesdict(), {'func': ['a']})
+
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.)).(?<=(?&func))',
+          'abc').groups(), (None, ))
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.)).(?<=(?&func))',
+          'abc').groupdict(), {'func': None})
+        self.assertEqual(regex.search(r'(?(DEFINE)(?<func>.)).(?<=(?&func))',
+          'abc').capturesdict(), {'func': ['a']})
+
+        # Hg issue 271: Comment logic different between Re and Regex
+        self.assertEqual(bool(regex.match(r'ab(?#comment\))cd', 'abcd')), True)
+
+        # Hg issue 276: Partial Matches yield incorrect matches and bounds
+        self.assertEqual(regex.search(r'[a-z]+ [a-z]*?:', 'foo bar',
+          partial=True).span(), (0, 7))
+        self.assertEqual(regex.search(r'(?r):[a-z]*? [a-z]+', 'foo bar',
+          partial=True).span(), (0, 7))
+
+        # Hg issue 291: Include Script Extensions as a supported Unicode
+        # property
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script:Beng}',
+          '\u09EF')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script:Bengali}',
+          '\u09EF')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script_Extensions:Bengali}',
+          '\u09EF')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script_Extensions:Beng}',
+          '\u09EF')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script_Extensions:Cakm}',
+          '\u09EF')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{Script_Extensions:Sylo}',
+          '\u09EF')), True)
+
+        # Hg issue #293: scx (Script Extensions) property currently matches
+        # incorrectly
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Latin}', 'P')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Ahom}', 'P')), False)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Common}', '4')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Caucasian_Albanian}', '4')),
+          False)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Arabic}', '\u062A')), True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Balinese}', '\u062A')),
+          False)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Devanagari}', '\u091C')),
+          True)
+        self.assertEqual(bool(regex.match(r'(?u)\p{scx:Batak}', '\u091C')), False)
+
+        # Hg issue 296: Group references are not taken into account when
+        # group is reporting the last match
+        self.assertEqual(regex.fullmatch('(?P<x>.)*(?&x)', 'abc').captures('x'),
+          ['a', 'b', 'c'])
+        self.assertEqual(regex.fullmatch('(?P<x>.)*(?&x)', 'abc').group('x'),
+          'b')
+
+        self.assertEqual(regex.fullmatch('(?P<x>.)(?P<x>.)(?P<x>.)',
+          'abc').captures('x'), ['a', 'b', 'c'])
+        self.assertEqual(regex.fullmatch('(?P<x>.)(?P<x>.)(?P<x>.)',
+          'abc').group('x'), 'c')
+
+        # Hg issue 299: Partial gives misleading results with "open ended"
+        # regexp
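+        # A partial match means the text ran out while the pattern could
+        # still have succeeded given more input; sketch (not from the
+        # tracker):
+        self.assertEqual(regex.search(r'\d{4}', '12', partial=True).partial,
+          True)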
self.assertEqual(regex.match('(?:ab)*', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)*', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)*?', '', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)*+', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)*+', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)+', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)+', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)+?', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)++', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?:ab)++', 'abab', partial=True).partial, + False) + + self.assertEqual(regex.match('(?r)(?:ab)*', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)*', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)*?', '', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)*+', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)*+', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)+', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)+', 'abab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)+?', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)++', 'ab', partial=True).partial, + False) + self.assertEqual(regex.match('(?r)(?:ab)++', 'abab', partial=True).partial, + False) + + self.assertEqual(regex.match('a*', '', partial=True).partial, False) + self.assertEqual(regex.match('a*?', '', partial=True).partial, False) + self.assertEqual(regex.match('a*+', '', partial=True).partial, False) + self.assertEqual(regex.match('a+', '', partial=True).partial, True) + self.assertEqual(regex.match('a+?', '', partial=True).partial, True) + self.assertEqual(regex.match('a++', '', partial=True).partial, True) + self.assertEqual(regex.match('a+', 'a', partial=True).partial, False) + self.assertEqual(regex.match('a+?', 'a', partial=True).partial, False) + self.assertEqual(regex.match('a++', 'a', partial=True).partial, False) + + self.assertEqual(regex.match('(?r)a*', '', partial=True).partial, False) + self.assertEqual(regex.match('(?r)a*?', '', partial=True).partial, False) + self.assertEqual(regex.match('(?r)a*+', '', partial=True).partial, False) + self.assertEqual(regex.match('(?r)a+', '', partial=True).partial, True) + self.assertEqual(regex.match('(?r)a+?', '', partial=True).partial, True) + self.assertEqual(regex.match('(?r)a++', '', partial=True).partial, True) + self.assertEqual(regex.match('(?r)a+', 'a', partial=True).partial, False) + self.assertEqual(regex.match('(?r)a+?', 'a', partial=True).partial, False) + self.assertEqual(regex.match('(?r)a++', 'a', partial=True).partial, False) + + self.assertEqual(regex.match(r"(?:\s*\w+'*)+", 'whatever', partial=True).partial, + False) + + # Hg issue 300: segmentation fault + pattern = ('(?PGGCGTCACACTTTGCTATGCCATAGCAT[AG]TTTATCCATAAGA' + 'TTAGCGGATCCTACCTGACGCTTTTTATCGCAACTCTCTACTGTTTCTCCATAACAGAACATATTGA' + 'CTATCCGGTATTACCCGGCATGACAGGAGTAAAA){e<=1}' + '(?P[ACGT]{1059}){e<=2}' + '(?PTAATCGTCTTGTTTGATACACAAGGGTCGCATCTGCGGCCCTTTTGCTTTTTTAAG' + 'TTGTAAGGATATGCCATTCTAGA){e<=0}' + '(?P[ACGT]{18}){e<=0}' + '(?PAGATCGG[CT]AGAGCGTCGTGTAGGGAAAGAGTGTGG){e<=1}') + + text = 
('GCACGGCGTCACACTTTGCTATGCCATAGCATATTTATCCATAAGATTAGCGGATCCTACC' + 'TGACGCTTTTTATCGCAACTCTCTACTGTTTCTCCATAACAGAACATATTGACTATCCGGTATTACC' + 'CGGCATGACAGGAGTAAAAATGGCTATCGACGAAAACAAACAGAAAGCGTTGGCGGCAGCACTGGGC' + 'CAGATTGAGAAACAATTTGGTAAAGGCTCCATCATGCGCCTGGGTGAAGACCGTTCCATGGATGTGG' + 'AAACCATCTCTACCGGTTCGCTTTCACTGGATATCGCGCTTGGGGCAGGTGGTCTGCCGATGGGCCG' + 'TATCGTCGAAATCTACGGACCGGAATCTTCCGGTAAAACCACGCTGACGCTGCAGGTGATCGCCGCA' + 'GCGCAGCGTGAAGGTAAAACCTGTGCGTTTATCGATGCTGAACACGCGCTGGACCCAATCTACGCAC' + 'GTAAACTGGGCGTCGATATCGACAACCTGCTGTGCTCCCAGCCGGACACCGGCGAGCAGGCACTGGA' + 'AATCTGTGACGCCCTGGCGCGTTCTGGCGCAGTAGACGTTATCGTCGTTGACTCCGTGGCGGCACTG' + 'ACGCCGAAAGCGGAAATCGAAGGCGAAATCGGCGACTCTCATATGGGCCTTGCGGCACGTATGATGA' + 'GCCAGGCGATGCGTAAGCTGGCGGGTAACCTGAAGCAGTCCAACACGCTGCTGATCTTCATCAACCC' + 'CATCCGTATGAAAATTGGTGTGATGTTCGGCAACCCGGAAACCACTTACCGGTGGTAACGCGCTGAA' + 'ATTCTACGCCTCTGTTCGTCTCGACATCCGTTAAATCGGCGCGGTGAAAGAGGGCGAAAACGTGGTG' + 'GGTAGCGAAACCCGCGTGAAAGTGGTGAAGAACAAAATCGCTGCGCCGTTTAAACAGGCTGAATTCC' + 'AGATCCTCTACGGCGAAGGTATCAACTTCTACCCCGAACTGGTTGACCTGGGCGTAAAAGAGAAGCT' + 'GATCGAGAAAGCAGGCGCGTGGTACAGCTACAAAGGTGAGAAGATCGGTCAGGGTAAAGCGAATGCG' + 'ACTGCCTGGCTGAAATTTAACCCGGAAACCGCGAAAGAGATCGAGTGAAAAGTACGTGAGTTGCTGC' + 'TGAGCAACCCGAACTCAACGCCGGATTTCTCTGTAGATGATAGCGAAGGCGTAGCAGAAACTAACGA' + 'AGATTTTTAATCGTCTTGTTTGATACACAAGGGTCGCATCTGCGGCCCTTTTGCTTTTTTAAGTTGT' + 'AAGGATATGCCATTCTAGACAGTTAACACACCAACAAAGATCGGTAGAGCGTCGTGTAGGGAAAGAG' + 'TGTGGTACC') + + m = regex.search(pattern, text, flags=regex.BESTMATCH) + self.assertEqual(m.fuzzy_counts, (0, 1, 0)) + self.assertEqual(m.fuzzy_changes, ([], [1206], [])) + + # Hg issue 306: Fuzzy match parameters not respecting quantifier scope + self.assertEqual(regex.search(r'(?e)(dogf(((oo){e<1})|((00){e<1}))d){e<2}', + 'dogfood').fuzzy_counts, (0, 0, 0)) + self.assertEqual(regex.search(r'(?e)(dogf(((oo){e<1})|((00){e<1}))d){e<2}', + 'dogfoot').fuzzy_counts, (1, 0, 0)) + + # Hg issue 312: \X not matching graphemes with zero-width-joins + self.assertEqual(regex.findall(r'\X', + '\U0001F468\u200D\U0001F469\u200D\U0001F467\u200D\U0001F466'), + ['\U0001F468\u200D\U0001F469\u200D\U0001F467\u200D\U0001F466']) + + # Hg issue 320: Abnormal performance + self.assertEqual(bool(regex.search(r'(?=a)a', 'a')), True) + self.assertEqual(bool(regex.search(r'(?!b)a', 'a')), True) + + # Hg issue 327: .fullmatch() causes MemoryError + self.assertEqual(regex.fullmatch(r'((\d)*?)*?', '123').span(), (0, 3)) + + # Hg issue 329: Wrong group matches when question mark quantifier is used within a look behind + self.assertEqual(regex.search(r'''(?(DEFINE)(?(?THIS_SHOULD_NOT_MATCHx?)|(?right))).*(?<=(?&mydef).*)''', + 'x right').capturesdict(), {'mydef': ['right'], 'wrong': [], 'right': + ['right']}) + + # Hg issue 338: specifying allowed characters when fuzzy-matching + self.assertEqual(bool(regex.match(r'(?:cat){e<=1:[u]}', 'cut')), True) + self.assertEqual(bool(regex.match(r'(?:cat){e<=1:u}', 'cut')), True) + + # Hg issue 353: fuzzy changes negative indexes + self.assertEqual(regex.search(r'(?be)(AGTGTTCCCCGCGCCAGCGGGGATAAACCG){s<=5,i<=5,d<=5,s+i+d<=10}', + 'TTCCCCGCGCCAGCGGGGATAAACCG').fuzzy_changes, ([], [], [0, 1, 3, 5])) + + # Git issue 364: Contradictory values in fuzzy_counts and fuzzy_changes + self.assertEqual(regex.match(r'(?:bc){e}', 'c').fuzzy_counts, (1, 0, + 1)) + self.assertEqual(regex.match(r'(?:bc){e}', 'c').fuzzy_changes, ([0], + [], [1])) + self.assertEqual(regex.match(r'(?e)(?:bc){e}', 'c').fuzzy_counts, (0, + 0, 1)) + self.assertEqual(regex.match(r'(?e)(?:bc){e}', 
'c').fuzzy_changes,
+          ([], [], [0]))
+        self.assertEqual(regex.match(r'(?b)(?:bc){e}', 'c').fuzzy_counts, (0,
+          0, 1))
+        self.assertEqual(regex.match(r'(?b)(?:bc){e}', 'c').fuzzy_changes,
+          ([], [], [0]))
+
+        # Git issue 370: Confusions about Fuzzy matching behavior
+        self.assertEqual(regex.match('(?e)(?:^(\\$ )?\\d{1,3}(,\\d{3})*(\\.\\d{2})$){e}',
+          '$ 10,112.111.12').fuzzy_counts, (6, 0, 5))
+        self.assertEqual(regex.match('(?e)(?:^(\\$ )?\\d{1,3}(,\\d{3})*(\\.\\d{2})$){s<=1}',
+          '$ 10,112.111.12').fuzzy_counts, (1, 0, 0))
+        self.assertEqual(regex.match('(?e)(?:^(\\$ )?\\d{1,3}(,\\d{3})*(\\.\\d{2})$){s<=1,i<=1,d<=1}',
+          '$ 10,112.111.12').fuzzy_counts, (1, 0, 0))
+        self.assertEqual(regex.match('(?e)(?:^(\\$ )?\\d{1,3}(,\\d{3})*(\\.\\d{2})$){s<=3}',
+          '$ 10,1a2.111.12').fuzzy_counts, (2, 0, 0))
+        self.assertEqual(regex.match('(?e)(?:^(\\$ )?\\d{1,3}(,\\d{3})*(\\.\\d{2})$){s<=2}',
+          '$ 10,1a2.111.12').fuzzy_counts, (2, 0, 0))
+
+        self.assertEqual(regex.fullmatch(r'(?e)(?:0?,0(?:,0)?){s<=1,d<=1}',
+          ',0;0').fuzzy_counts, (1, 0, 0))
+        self.assertEqual(regex.fullmatch(r'(?e)(?:0??,0(?:,0)?){s<=1,d<=1}',
+          ',0;0').fuzzy_counts, (1, 0, 0))
+
+        # Git issue 371: Specifying character set when fuzzy-matching allows
+        # characters not in the set
+        self.assertEqual(regex.search(r"\b(?e)(?:\d{6,20}){i<=5:[\-\\\/]}\b",
+          "cat dog starting at 00:01132.000. hello world"), None)
+
+        # Git issue 385: Comments in expressions
+        self.assertEqual(bool(regex.compile('(?#)')), True)
+        self.assertEqual(bool(regex.compile('(?x)(?#)')), True)
+
+        # Git issue 394: Unexpected behaviour in fuzzy matching with limited
+        # character set with IGNORECASE flag
+        self.assertEqual(regex.findall(r'(\d+){i<=2:[ab]}', '123X4Y5'),
+          ['123', '4', '5'])
+        self.assertEqual(regex.findall(r'(?i)(\d+){i<=2:[ab]}', '123X4Y5'),
+          ['123', '4', '5'])
+
+        # Git issue 403: Fuzzy matching with wrong distance (unnecessary
+        # substitutions)
+        self.assertEqual(regex.match(r'^(test){e<=5}$', 'terstin',
+          flags=regex.B).fuzzy_counts, (0, 3, 0))
+
+        # Git issue 408: regex fails with a quantified backreference but
+        # succeeds with repeated backref
+        self.assertEqual(bool(regex.match(r"(?:(x*)\1\1\1)*x$", "x" * 5)), True)
+        self.assertEqual(bool(regex.match(r"(?:(x*)\1{3})*x$", "x" * 5)), True)
+
+        # Git issue 415: Fuzzy character restrictions don't apply to
+        # insertions at "right edge"
+        self.assertEqual(regex.match(r't(?:es){s<=1:\d}t', 'te5t').group(),
+          'te5t')
+        self.assertEqual(regex.match(r't(?:es){s<=1:\d}t', 'tezt'), None)
+        self.assertEqual(regex.match(r't(?:es){i<=1:\d}t', 'tes5t').group(),
+          'tes5t')
+        self.assertEqual(regex.match(r't(?:es){i<=1:\d}t', 'teszt'), None)
+        self.assertEqual(regex.match(r't(?:es){i<=1:\d}t',
+          'tes5t').fuzzy_changes, ([], [3], []))
+
+        sequence = 'TTCAGACGTGTGCTCTTCCGATCTCAATACCGACTCCTCACTGTGTGTCT'
+        pattern = r'(?P<insert>.*)(?P<anchor>CTTCC){e<=1}(?P<umi>([ACGT]){4,6})(?P<sid>CAATACCGACTCCTCACTGTGT){e<=2}(?P<end>([ACGT]){0,6}$)'
+
+        m = regex.match(pattern, sequence, flags=regex.BESTMATCH)
+        self.assertEqual(m.span(), (0, 50))
+        self.assertEqual(m.groupdict(), {'insert': 'TTCAGACGTGTGCT', 'anchor': 'CTTCC', 'umi': 'GATCT', 'sid': 'CAATACCGACTCCTCACTGTGT', 'end': 'GTCT'})
+
+        m = regex.match(pattern, sequence, flags=regex.ENHANCEMATCH)
+        self.assertEqual(m.span(), (0, 50))
+        self.assertEqual(m.groupdict(), {'insert': 'TTCAGACGTGTGCT', 'anchor': 'CTTCC', 'umi': 'GATCT', 'sid': 'CAATACCGACTCCTCACTGTGT', 'end': 'GTCT'})
+
+        # Git issue 433: Disagreement between fuzzy_counts and fuzzy_changes
+        pattern = r'(?P<insert>.*)(?P<anchor>AACACTGG){e<=1}(?P<umi>([AT][CG]){5}){e<=2}(?P<sid>GTAACCGAAG){e<=2}(?P<end>([ACGT]){0,6}$)'
+
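+        # Note: each {e<=n} budget above applies only to the group it
+        # follows, not to the pattern as a whole.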
sequence = 'GGAAAACACTGGTCTCAGTCTCGTAACCGAAGTGGTCG' + m = regex.match(pattern, sequence, flags=regex.BESTMATCH) + self.assertEqual(m.fuzzy_counts, (0, 0, 0)) + self.assertEqual(m.fuzzy_changes, ([], [], [])) + + sequence = 'GGAAAACACTGGTCTCAGTCTCGTCCCCGAAGTGGTCG' + m = regex.match(pattern, sequence, flags=regex.BESTMATCH) + self.assertEqual(m.fuzzy_counts, (2, 0, 0)) + self.assertEqual(m.fuzzy_changes, ([24, 25], [], [])) + + # Git issue 439: Unmatched groups: sub vs subf + self.assertEqual(regex.sub(r'(test1)|(test2)', r'matched: \1\2', 'test1'), 'matched: test1') + self.assertEqual(regex.subf(r'(test1)|(test2)', r'matched: {1}{2}', 'test1'), 'matched: test1') + self.assertEqual(regex.search(r'(test1)|(test2)', 'matched: test1').expand(r'matched: \1\2'), 'matched: test1'), + self.assertEqual(regex.search(r'(test1)|(test2)', 'matched: test1').expandf(r'matched: {1}{2}'), 'matched: test1') + + # Git issue 442: Fuzzy regex matching doesn't seem to test insertions correctly + self.assertEqual(regex.search(r"(?:\bha\b){i:[ ]}", "having"), None) + self.assertEqual(regex.search(r"(?:\bha\b){i:[ ]}", "having", flags=regex.I), None) + + # Git issue 467: Scoped inline flags 'a', 'u' and 'L' affect global flags + self.assertEqual(regex.match(r'(?a:\w)\w', 'd\N{CYRILLIC SMALL LETTER ZHE}').span(), (0, 2)) + self.assertEqual(regex.match(r'(?a:\w)(?u:\w)', 'd\N{CYRILLIC SMALL LETTER ZHE}').span(), (0, 2)) + + # Git issue 473: Emoji classified as letter + self.assertEqual(regex.match(r'^\p{LC}+$', '\N{SMILING CAT FACE WITH OPEN MOUTH}'), None) + self.assertEqual(regex.match(r'^\p{So}+$', '\N{SMILING CAT FACE WITH OPEN MOUTH}').span(), (0, 1)) + + # Git issue 474: regex has no equivalent to `re.Match.groups()` for captures + self.assertEqual(regex.match(r'(.)+', 'abc').allcaptures(), (['abc'], ['a', 'b', 'c'])) + self.assertEqual(regex.match(r'(.)+', 'abc').allspans(), ([(0, 3)], [(0, 1), (1, 2), (2, 3)])) + + # Git issue 477: \v for vertical spacing + self.assertEqual(bool(regex.fullmatch(r'\p{HorizSpace}+', '\t \xA0\u1680\u180E\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u202F\u205F\u3000')), True) + self.assertEqual(bool(regex.fullmatch(r'\p{VertSpace}+', '\n\v\f\r\x85\u2028\u2029')), True) + + # Git issue 479: Segmentation fault when using conditional pattern + self.assertEqual(regex.match(r'(?(?<=A)|(?(?![^B])C|D))', 'A'), None) + self.assertEqual(regex.search(r'(?(?<=A)|(?(?![^B])C|D))', 'A').span(), (1, 1)) + + # Git issue 494: Backtracking failure matching regex ^a?(a?)b?c\1$ against string abca + self.assertEqual(regex.search(r"^a?(a?)b?c\1$", "abca").span(), (0, 4)) + + # Git issue 498: Conditional negative lookahead inside positive lookahead fails to match + self.assertEqual(regex.match(r'(?(?=a).|..)', 'ab').span(), (0, 1)) + self.assertEqual(regex.match(r'(?(?=b).|..)', 'ab').span(), (0, 2)) + self.assertEqual(regex.match(r'(?(?!a).|..)', 'ab').span(), (0, 2)) + self.assertEqual(regex.match(r'(?(?!b).|..)', 'ab').span(), (0, 1)) + + # Git issue 525: segfault when fuzzy matching empty list + self.assertEqual(regex.match(r"(\L){e<=5}", "blah", foo=[]).span(), (0, 0)) + + # Git issue 527: `VERBOSE`/`X` flag breaks `\N` escapes + self.assertEqual(regex.compile(r'\N{LATIN SMALL LETTER A}').match('a').span(), (0, 1)) + self.assertEqual(regex.compile(r'\N{LATIN SMALL LETTER A}', flags=regex.X).match('a').span(), (0, 1)) + + # Git issue 539: Bug: Partial matching fails on a simple example + self.assertEqual(regex.match(r"[^/]*b/ccc", "b/ccc", partial=True).span(), (0, 5)) + 
self.assertEqual(regex.match(r"[^/]*b/ccc", "b/ccb", partial=True), None) + self.assertEqual(regex.match(r"[^/]*b/ccc", "b/cc", partial=True).span(), (0, 4)) + self.assertEqual(regex.match(r"[^/]*b/xyz", "b/xy", partial=True).span(), (0, 4)) + self.assertEqual(regex.match(r"[^/]*b/xyz", "b/yz", partial=True), None) + + self.assertEqual(regex.match(r"(?i)[^/]*b/ccc", "b/ccc", partial=True).span(), (0, 5)) + self.assertEqual(regex.match(r"(?i)[^/]*b/ccc", "b/ccb", partial=True), None) + self.assertEqual(regex.match(r"(?i)[^/]*b/ccc", "b/cc", partial=True).span(), (0, 4)) + self.assertEqual(regex.match(r"(?i)[^/]*b/xyz", "b/xy", partial=True).span(), (0, 4)) + self.assertEqual(regex.match(r"(?i)[^/]*b/xyz", "b/yz", partial=True), None) + + # Git issue 546: Partial match not working in some instances with non-greedy capture + self.assertEqual(bool(regex.match(r'.*?', '<', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', '.*?', '', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', 'x', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', 'xyz abc', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', 'xyz abc foo', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', 'xyz abc foo ', partial=True)), True) + self.assertEqual(bool(regex.match(r'.*?', 'xyz abc foo bar', partial=True)), True) + + # Git issue 551: + self.assertEqual(bool(regex.match(r'(?V1)[[\s\S]]', 'a')), True) + self.assertEqual(bool(regex.match(r'(?V1)[[\s\S]-a]', 'a')), True) + self.assertEqual(bool(regex.match(r'(?V1)[[\s\S]--a]', 'a')), False) + self.assertEqual(bool(regex.match(r'(?V1)[[a-z]--b]', 'a')), True) + self.assertEqual(bool(regex.match(r'(?V1)[[\s\S]--b]', 'a')), True) + self.assertEqual(bool(regex.match(r'(?V1)[a-[\s\S]]', 'a')), True) + self.assertEqual(bool(regex.match(r'(?V1)[a--[\s\S]]', 'a')), False) + + self.assertEqual(regex.search(r'(?ifu)(H\N{LATIN SMALL LETTER O WITH DIAERESIS}gskolan?)[\\s\\S]*p', + 'Yrkesh\N{LATIN SMALL LETTER O WITH DIAERESIS}gskola . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen . 
Studie\N{LATIN SMALL LETTER A WITH DIAERESIS}mnen'), + None) + + # Git issue 572: Inline ASCII modifier doesn't seem to affect anything + self.assertEqual(bool(regex.match(r'\d', '\uFF19')), True) + self.assertEqual(bool(regex.match(r'(?a:\d)', '\uFF19')), False) + + # Git issue 575: Issues with ASCII/Unicode modifiers + self.assertEqual(regex.findall('\\d', '9\uFF19'), ['9', '\uff19']) + self.assertEqual(regex.findall('(?u:\\d)', '9\uFF19'), ['9', '\uff19']) + self.assertEqual(regex.findall('(?a:\\d)', '9\uFF19'), ['9']) + + self.assertEqual(regex.findall('\\d', '9\uFF19', flags=regex.U), ['9', '\uff19']) + self.assertEqual(regex.findall('(?u:\\d)', '9\uFF19', flags=regex.U), ['9', '\uff19']) + self.assertEqual(regex.findall('(?a:\\d)', '9\uFF19', flags=regex.U), ['9']) + + self.assertEqual(regex.findall('\\d', '9\uFF19', flags=regex.A), ['9']) + self.assertEqual(regex.findall('(?u:\\d)', '9\uFF19', flags=regex.A), ['9', '\uff19']) + self.assertEqual(regex.findall('(?a:\\d)', '9\uFF19', flags=regex.A), ['9']) + + self.assertEqual(len(regex.findall(r'\p{L}', ''.join(chr(c) for c in range(0x100)), flags=0)), 117) + self.assertEqual(len(regex.findall(r'\p{L}', ''.join(chr(c) for c in range(0x100)), flags=regex.A)), 52) + self.assertEqual(len(regex.findall(r'\p{L}', ''.join(chr(c) for c in range(0x100)), flags=regex.U)), 117) + + self.assertEqual(len(regex.findall(r'(?a:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=0)), 52) + self.assertEqual(len(regex.findall(r'(?a:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=regex.A)), 52) + self.assertEqual(len(regex.findall(r'(?a:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=regex.U)), 52) + + self.assertEqual(len(regex.findall(r'(?u:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=0)), 117) + self.assertEqual(len(regex.findall(r'(?u:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=regex.A)), 117) + self.assertEqual(len(regex.findall(r'(?u:\p{L})', ''.join(chr(c) for c in range(0x100)), flags=regex.U)), 117) + + # Git issue 580: Regression in v2025.7.31: \P{L} no longer matches in simple patterns + self.assertEqual(bool(regex.match(r"\A\P{L}?\p{L}", "hello,")), True) + self.assertEqual(bool(regex.fullmatch(r"\A\P{L}*(?P\p{L}+)\P{L}*\Z", "hello,")), True) + + # Git issue 584: AttributeError: 'AnyAll' object has no attribute 'positive' + self.assertEqual(bool(regex.compile('(\\s|\\S)')), True) + + # Git PR 585: Fix AttributeError: 'AnyAll' object has no attribute '_key' + self.assertEqual(bool(regex.compile('(?:[\\S\\s]|[A-D][M-Z])')), True) + + def test_fuzzy_ext(self): + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:a){e<=1:[a-z]}', 'e')), + True) + self.assertEqual(bool(regex.fullmatch(r'(?:a){e<=1:[a-z]}', 'e')), + True) + self.assertEqual(bool(regex.fullmatch(r'(?:a){e<=1:[a-z]}', '-')), + False) + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:a){e<=1:[a-z]}', '-')), + False) + + self.assertEqual(bool(regex.fullmatch(r'(?:a){e<=1:[a-z]}', 'ae')), + True) + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:a){e<=1:[a-z]}', + 'ae')), True) + self.assertEqual(bool(regex.fullmatch(r'(?:a){e<=1:[a-z]}', 'a-')), + False) + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:a){e<=1:[a-z]}', + 'a-')), False) + + self.assertEqual(bool(regex.fullmatch(r'(?:ab){e<=1:[a-z]}', 'ae')), + True) + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:ab){e<=1:[a-z]}', + 'ae')), True) + self.assertEqual(bool(regex.fullmatch(r'(?:ab){e<=1:[a-z]}', 'a-')), + False) + self.assertEqual(bool(regex.fullmatch(r'(?r)(?:ab){e<=1:[a-z]}', + 'a-')), False) 
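+        # The set after ':' restricts which characters the errors may touch;
+        # e.g. {e<=1:[a-z]} permits one error only on a lowercase letter
+        # (illustrative, not part of the original suite):
+        self.assertEqual(bool(regex.fullmatch(r'(?:cat){e<=1:[a-z]}',
+          'cbt')), True)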
+ + self.assertEqual(bool(regex.fullmatch(r'(a)\1{e<=1:[a-z]}', 'ae')), + True) + self.assertEqual(bool(regex.fullmatch(r'(?r)\1{e<=1:[a-z]}(a)', + 'ea')), True) + self.assertEqual(bool(regex.fullmatch(r'(a)\1{e<=1:[a-z]}', 'a-')), + False) + self.assertEqual(bool(regex.fullmatch(r'(?r)\1{e<=1:[a-z]}(a)', + '-a')), False) + + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 'ts')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 'st')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 'st')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 'ts')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + '-s')), False) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 's-')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + 's-')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(?:\N{LATIN SMALL LETTER SHARP S}){e<=1:[a-z]}', + '-s')), False) + + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + 'ssst')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + 'ssts')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)\1{e<=1:[a-z]}(\N{LATIN SMALL LETTER SHARP S})', + 'stss')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)\1{e<=1:[a-z]}(\N{LATIN SMALL LETTER SHARP S})', + 'tsss')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + 'ss-s')), False) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + 'sss-')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + '-s')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(\N{LATIN SMALL LETTER SHARP S})\1{e<=1:[a-z]}', + 's-')), False) + + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(ss)\1{e<=1:[a-z]}', + '\N{LATIN SMALL LETTER SHARP S}ts')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(ss)\1{e<=1:[a-z]}', + '\N{LATIN SMALL LETTER SHARP S}st')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)\1{e<=1:[a-z]}(ss)', + 'st\N{LATIN SMALL LETTER SHARP S}')), True) + self.assertEqual(bool(regex.fullmatch(r'(?firu)\1{e<=1:[a-z]}(ss)', + 'ts\N{LATIN SMALL LETTER SHARP S}')), True) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(ss)\1{e<=1:[a-z]}', + '\N{LATIN SMALL LETTER SHARP S}-s')), False) + self.assertEqual(bool(regex.fullmatch(r'(?fiu)(ss)\1{e<=1:[a-z]}', + '\N{LATIN SMALL LETTER SHARP S}s-')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(ss)\1{e<=1:[a-z]}', + 's-\N{LATIN SMALL LETTER SHARP S}')), False) + self.assertEqual(bool(regex.fullmatch(r'(?firu)(ss)\1{e<=1:[a-z]}', + '-s\N{LATIN SMALL LETTER SHARP S}')), False) + + def test_subscripted_captures(self): + self.assertEqual(regex.match(r'(?P.)+', + 'abc').expandf('{0} {0[0]} {0[-1]}'), 'abc abc abc') + self.assertEqual(regex.match(r'(?P.)+', + 'abc').expandf('{1} {1[0]} {1[1]} {1[2]} {1[-1]} {1[-2]} {1[-3]}'), + 'c a b c c b a') + self.assertEqual(regex.match(r'(?P.)+', + 'abc').expandf('{x} {x[0]} {x[1]} {x[2]} {x[-1]} {x[-2]} {x[-3]}'), + 'c a b c c b a') + + self.assertEqual(regex.subf(r'(?P.)+', r'{0} {0[0]} 
{0[-1]}', + 'abc'), 'abc abc abc') + self.assertEqual(regex.subf(r'(?P.)+', + '{1} {1[0]} {1[1]} {1[2]} {1[-1]} {1[-2]} {1[-3]}', 'abc'), + 'c a b c c b a') + self.assertEqual(regex.subf(r'(?P.)+', + '{x} {x[0]} {x[1]} {x[2]} {x[-1]} {x[-2]} {x[-3]}', 'abc'), + 'c a b c c b a') + + def test_more_zerowidth(self): + if sys.version_info >= (3, 7, 0): + self.assertEqual(regex.split(r'\b|:+', 'a::bc'), ['', 'a', '', '', + 'bc', '']) + self.assertEqual(regex.sub(r'\b|:+', '-', 'a::bc'), '-a---bc-') + self.assertEqual(regex.findall(r'\b|:+', 'a::bc'), ['', '', '::', + '', '']) + self.assertEqual([m.span() for m in regex.finditer(r'\b|:+', + 'a::bc')], [(0, 0), (1, 1), (1, 3), (3, 3), (5, 5)]) + self.assertEqual([m.span() for m in regex.finditer(r'(?m)^\s*?$', + 'foo\n\n\nbar')], [(4, 4), (4, 5), (5, 5)]) + + def test_line_ending(self): + self.assertEqual(regex.findall(r'\R', '\r\n\n\x0B\f\r\x85\u2028\u2029'), + ['\r\n', '\n', '\x0B', '\f', '\r', '\x85', '\u2028', '\u2029']) + self.assertEqual(regex.findall(br'\R', b'\r\n\n\x0B\f\r\x85'), [b'\r\n', + b'\n', b'\x0B', b'\f', b'\r']) + +def test_main(): + unittest.main(verbosity=2) + +if __name__ == "__main__": + test_main() diff --git a/venv/lib/python3.13/site-packages/tokenizers/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50525ca5be47f17f4efa5aaa24ca04016c73b0ee Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12ada5dbda08f28a2ffc863cb502bd3f25455ded --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.py @@ -0,0 +1,15 @@ +from .. import decoders + + +Decoder = decoders.Decoder +ByteLevel = decoders.ByteLevel +Replace = decoders.Replace +WordPiece = decoders.WordPiece +ByteFallback = decoders.ByteFallback +Fuse = decoders.Fuse +Strip = decoders.Strip +Metaspace = decoders.Metaspace +BPEDecoder = decoders.BPEDecoder +CTC = decoders.CTC +Sequence = decoders.Sequence +DecodeStream = decoders.DecodeStream diff --git a/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c9876092eb3c41d5a9fc1a9a1accb401c331c841 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/decoders/__init__.pyi @@ -0,0 +1,279 @@ +# Generated content DO NOT EDIT +class DecodeStream: + """ + Class needed for streaming decode + + """ + def __init__(self, ids=None, skip_special_tokens=False): + pass + +class Decoder: + """ + Base class for all decoders + + This class is not supposed to be instantiated directly. Instead, any implementation of + a Decoder will return an instance of this class when instantiated. + """ + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class BPEDecoder(Decoder): + """ + BPEDecoder Decoder + + Args: + suffix (:obj:`str`, `optional`, defaults to :obj:``): + The suffix that was used to characterize an end-of-word. 
This suffix will + be replaced by whitespaces during the decoding + """ + def __init__(self, suffix=""): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class ByteFallback(Decoder): + """ + ByteFallback Decoder + ByteFallback is a simple trick which converts tokens looking like `<0x61>` + to pure bytes, and attempts to make them into a string. If the tokens + cannot be decoded you will get � instead for each inconvertible byte token + + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class ByteLevel(Decoder): + """ + ByteLevel Decoder + + This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.ByteLevel` + :class:`~tokenizers.pre_tokenizers.PreTokenizer`. + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class CTC(Decoder): + """ + CTC Decoder + + Args: + pad_token (:obj:`str`, `optional`, defaults to :obj:``): + The pad token used by CTC to delimit a new token. + word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`): + The word delimiter token. It will be replaced by a + cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to cleanup some tokenization artifacts. + Mainly spaces before punctuation, and some abbreviated english forms. + """ + def __init__(self, pad_token="", word_delimiter_token="|", cleanup=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Fuse(Decoder): + """ + Fuse Decoder + Fuse simply fuses every token into a single string. + This is the last step of decoding, this decoder exists only if + there is need to add other decoders *after* the fusion + """ + def __init__(self): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Metaspace(Decoder): + """ + Metaspace Decoder + + Args: + replacement (:obj:`str`, `optional`, defaults to :obj:`▁`): + The replacement character. Must be exactly one character. By default we + use the `▁` (U+2581) meta symbol (Same as in SentencePiece). + + prepend_scheme (:obj:`str`, `optional`, defaults to :obj:`"always"`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + Choices: "always", "never", "first". First means the space is only added on the first + token (relevant when special tokens are used or other pre_tokenizer are used). 
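+
+    Illustrative sketch (editorial, not from the upstream docstring), assuming a
+    SentencePiece-style tokenizer instance named ``tok``::
+
+        from tokenizers import decoders
+
+        tok.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always")
+        # pieces such as "▁Hello" / "▁world" are decoded back to "Hello world"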
+ """ + def __init__(self, replacement="▁", prepend_scheme="always", split=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Replace(Decoder): + """ + Replace Decoder + + This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace` + :class:`~tokenizers.pre_tokenizers.PreTokenizer`. + """ + def __init__(self, pattern, content): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Sequence(Decoder): + """ + Sequence Decoder + + Args: + decoders (:obj:`List[Decoder]`) + The decoders that need to be chained + """ + def __init__(self, decoders): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class Strip(Decoder): + """ + Strip normalizer + Strips n left characters of each token, or n right characters of each token + """ + def __init__(self, content, left=0, right=0): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + +class WordPiece(Decoder): + """ + WordPiece Decoder + + Args: + prefix (:obj:`str`, `optional`, defaults to :obj:`##`): + The prefix to use for subwords that are not a beginning-of-word + + cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to cleanup some tokenization artifacts. Mainly spaces before punctuation, + and some abbreviated english forms. 
+ """ + def __init__(self, prefix="##", cleanup=True): + pass + + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6392580956464d0ee84d65f9dbed30a228c16991 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/decoders/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e775892d04a91d645653ea9015954b7985d3147 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/__init__.py @@ -0,0 +1,6 @@ +from .base_tokenizer import BaseTokenizer +from .bert_wordpiece import BertWordPieceTokenizer +from .byte_level_bpe import ByteLevelBPETokenizer +from .char_level_bpe import CharBPETokenizer +from .sentencepiece_bpe import SentencePieceBPETokenizer +from .sentencepiece_unigram import SentencePieceUnigramTokenizer diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac101e92bdb86274af50d14d4ad91dbd6fcd87f2 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0ccb4a4bbb3376f9032904947c9edfc991380e6 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/base_tokenizer.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae56fff290ec58022e097f4e7755a43976d61143 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/bert_wordpiece.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d8bdefce50521156c9a702099032ca19a6b8b7 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/byte_level_bpe.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..33c5b975b6dc30ef5873ab0c4e3fe384741ca6c5 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/char_level_bpe.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45dfc2b30ca6c10db158176cc6a439ea13c0b2d Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_bpe.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..450d9b79df77a05b27d56e958626efc59ef28195 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/implementations/__pycache__/sentencepiece_unigram.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/base_tokenizer.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..432a4754cc86853c27c6929d36c032889780e194 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/base_tokenizer.py @@ -0,0 +1,459 @@ +from typing import Dict, List, Optional, Tuple, Union + +from tokenizers import AddedToken, EncodeInput, Encoding, InputSequence, Tokenizer +from tokenizers.decoders import Decoder +from tokenizers.models import Model +from tokenizers.normalizers import Normalizer +from tokenizers.pre_tokenizers import PreTokenizer +from tokenizers.processors import PostProcessor + + +Offsets = Tuple[int, int] + + +class BaseTokenizer: + def __init__(self, tokenizer: Tokenizer, parameters=None): + self._tokenizer = tokenizer + self._parameters = parameters if parameters is not None else {} + + def __repr__(self): + return "Tokenizer(vocabulary_size={}, {})".format( + self._tokenizer.get_vocab_size(), + ", ".join(k + "=" + str(v) for k, v in self._parameters.items()), + ) + + def num_special_tokens_to_add(self, is_pair: bool) -> int: + """ + Return the number of special tokens that would be added for single/pair sentences. + :param is_pair: Boolean indicating if the input would be a single sentence or a pair + :return: + """ + return self._tokenizer.num_special_tokens_to_add(is_pair) + + def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]: + """Returns the vocabulary + + Args: + with_added_tokens: boolean: + Whether to include the added tokens in the vocabulary + + Returns: + The vocabulary + """ + return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens) + + def get_added_tokens_decoder(self) -> Dict[int, AddedToken]: + """Returns the added reverse vocabulary + + Returns: + The added vocabulary mapping ints to AddedTokens + """ + return self._tokenizer.get_added_tokens_decoder() + + def get_vocab_size(self, with_added_tokens: bool = True) -> int: + """Return the size of vocabulary, with or without added tokens. 
+ + Args: + with_added_tokens: (`optional`) bool: + Whether to count in added special tokens or not + + Returns: + Size of vocabulary + """ + return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens) + + def enable_padding( + self, + direction: Optional[str] = "right", + pad_to_multiple_of: Optional[int] = None, + pad_id: Optional[int] = 0, + pad_type_id: Optional[int] = 0, + pad_token: Optional[str] = "[PAD]", + length: Optional[int] = None, + ): + """Change the padding strategy + + Args: + direction: (`optional`) str: + Can be one of: `right` or `left` + + pad_to_multiple_of: (`optional`) unsigned int: + If specified, the padding length should always snap to the next multiple of + the given value. For example if we were going to pad with a length of 250 but + `pad_to_multiple_of=8` then we will pad to 256. + + pad_id: (`optional`) unsigned int: + The indice to be used when padding + + pad_type_id: (`optional`) unsigned int: + The type indice to be used when padding + + pad_token: (`optional`) str: + The pad token to be used when padding + + length: (`optional`) unsigned int: + If specified, the length at which to pad. If not specified + we pad using the size of the longest sequence in a batch + """ + return self._tokenizer.enable_padding( + direction=direction, + pad_to_multiple_of=pad_to_multiple_of, + pad_id=pad_id, + pad_type_id=pad_type_id, + pad_token=pad_token, + length=length, + ) + + def no_padding(self): + """Disable padding""" + return self._tokenizer.no_padding() + + @property + def padding(self) -> Optional[dict]: + """Get the current padding parameters + + Returns: + None if padding is disabled, a dict with the currently set parameters + if the padding is enabled. + """ + return self._tokenizer.padding + + def enable_truncation(self, max_length: int, stride: Optional[int] = 0, strategy: Optional[str] = "longest_first"): + """Change the truncation options + + Args: + max_length: unsigned int: + The maximum length at which to truncate + + stride: (`optional`) unsigned int: + The length of the previous first sequence to be included + in the overflowing sequence + + strategy: (`optional`) str: + Can be one of `longest_first`, `only_first` or `only_second` + """ + return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy) + + def no_truncation(self): + """Disable truncation""" + return self._tokenizer.no_truncation() + + @property + def truncation(self) -> Optional[dict]: + """Get the current truncation parameters + + Returns: + None if truncation is disabled, a dict with the current truncation parameters if + truncation is enabled + """ + return self._tokenizer.truncation + + def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int: + """Add the given tokens to the vocabulary + + Args: + tokens: List[Union[str, AddedToken]]: + A list of tokens to add to the vocabulary. Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_tokens(tokens) + + def add_special_tokens(self, special_tokens: List[Union[str, AddedToken]]) -> int: + """Add the given special tokens to the vocabulary, and treat them as special tokens. + + The special tokens will never be processed by the model, and will be + removed while decoding. + + Args: + tokens: List[Union[str, AddedToken]]: + A list of special tokens to add to the vocabulary. 
Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_special_tokens(special_tokens) + + def normalize(self, sequence: str) -> str: + """Normalize the given sequence + + Args: + sequence: str: + The sequence to normalize + + Returns: + The normalized string + """ + return self._tokenizer.normalize(sequence) + + def encode( + self, + sequence: InputSequence, + pair: Optional[InputSequence] = None, + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> Encoding: + """Encode the given sequence and pair. This method can process raw text sequences as well + as already pre-tokenized sequences. + + Args: + sequence: InputSequence: + The sequence we want to encode. This sequence can be either raw text or + pre-tokenized, according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + An Encoding + """ + if sequence is None: + raise ValueError("encode: `sequence` can't be `None`") + + return self._tokenizer.encode(sequence, pair, is_pretokenized, add_special_tokens) + + def encode_batch( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Encode the given inputs. This method accept both raw text sequences as well as already + pre-tokenized sequences. + + Args: + inputs: List[EncodeInput]: + A list of single sequences or pair sequences to encode. Each `EncodeInput` is + expected to be of the following form: + `Union[InputSequence, Tuple[InputSequence, InputSequence]]` + + Each `InputSequence` can either be raw text or pre-tokenized, + according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + A list of Encoding + """ + + if inputs is None: + raise ValueError("encode_batch: `inputs` can't be `None`") + + return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens) + + async def async_encode_batch( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Asynchronously encode a batch (tracks character offsets). + + Args: + inputs: A list of single or pair sequences to encode. + is_pretokenized: Whether inputs are already pre-tokenized. + add_special_tokens: Whether to add special tokens. + + Returns: + A list of Encoding. + """ + if inputs is None: + raise ValueError("async_encode_batch: `inputs` can't be `None`") + # Exposed by the Rust bindings via pyo3_async_runtimes::tokio::future_into_py + return await self._tokenizer.async_encode_batch(inputs, is_pretokenized, add_special_tokens) + + async def async_encode_batch_fast( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Asynchronously encode a batch (no character offsets, faster). 
+ + Args: + inputs: A list of single or pair sequences to encode. + is_pretokenized: Whether inputs are already pre-tokenized. + add_special_tokens: Whether to add special tokens. + + Returns: + A list of Encoding. + """ + if inputs is None: + raise ValueError("async_encode_batch_fast: `inputs` can't be `None`") + return await self._tokenizer.async_encode_batch_fast(inputs, is_pretokenized, add_special_tokens) + + def decode(self, ids: List[int], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the given list of ids to a string sequence + + Args: + ids: List[unsigned int]: + A list of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output string + + Returns: + The decoded string + """ + if ids is None: + raise ValueError("None input is not valid. Should be a list of integers.") + + return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) + + def decode_batch(self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the list of sequences to a list of string sequences + + Args: + sequences: List[List[unsigned int]]: + A list of sequence of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output strings + + Returns: + A list of decoded strings + """ + if sequences is None: + raise ValueError("None input is not valid. Should be list of list of integers.") + + return self._tokenizer.decode_batch(sequences, skip_special_tokens=skip_special_tokens) + + def token_to_id(self, token: str) -> Optional[int]: + """Convert the given token to its corresponding id + + Args: + token: str: + The token to convert + + Returns: + The corresponding id if it exists, None otherwise + """ + return self._tokenizer.token_to_id(token) + + def id_to_token(self, id: int) -> Optional[str]: + """Convert the given token id to its corresponding string + + Args: + token: id: + The token id to convert + + Returns: + The corresponding string if it exists, None otherwise + """ + return self._tokenizer.id_to_token(id) + + def save_model(self, directory: str, prefix: Optional[str] = None): + """Save the current model to the given directory + + Args: + directory: str: + A path to the destination directory + + prefix: (Optional) str: + An optional prefix, used to prefix each file name + """ + return self._tokenizer.model.save(directory, prefix=prefix) + + def save(self, path: str, pretty: bool = True): + """Save the current Tokenizer at the given path + + Args: + path: str: + A path to the destination Tokenizer file + """ + return self._tokenizer.save(path, pretty) + + def to_str(self, pretty: bool = False): + """Get a serialized JSON version of the Tokenizer as a str + + Args: + pretty: bool: + Whether the JSON string should be prettified + + Returns: + str + """ + return self._tokenizer.to_str(pretty) + + def post_process( + self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True + ) -> Encoding: + """Apply all the post-processing steps to the given encodings. + + The various steps are: + 1. Truncate according to global params (provided to `enable_truncation`) + 2. Apply the PostProcessor + 3. Pad according to global params. 
(provided to `enable_padding`) + + Args: + encoding: Encoding: + The main Encoding to post process + + pair: Optional[Encoding]: + An optional pair Encoding + + add_special_tokens: bool: + Whether to add special tokens + + Returns: + The resulting Encoding + """ + return self._tokenizer.post_process(encoding, pair, add_special_tokens) + + @property + def model(self) -> Model: + return self._tokenizer.model + + @model.setter + def model(self, model: Model): + self._tokenizer.model = model + + @property + def normalizer(self) -> Normalizer: + return self._tokenizer.normalizer + + @normalizer.setter + def normalizer(self, normalizer: Normalizer): + self._tokenizer.normalizer = normalizer + + @property + def pre_tokenizer(self) -> PreTokenizer: + return self._tokenizer.pre_tokenizer + + @pre_tokenizer.setter + def pre_tokenizer(self, pre_tokenizer: PreTokenizer): + self._tokenizer.pre_tokenizer = pre_tokenizer + + @property + def post_processor(self) -> PostProcessor: + return self._tokenizer.post_processor + + @post_processor.setter + def post_processor(self, post_processor: PostProcessor): + self._tokenizer.post_processor = post_processor + + @property + def decoder(self) -> Decoder: + return self._tokenizer.decoder + + @decoder.setter + def decoder(self, decoder: Decoder): + self._tokenizer.decoder = decoder diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/bert_wordpiece.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/bert_wordpiece.py new file mode 100644 index 0000000000000000000000000000000000000000..1f34e3ca8a4f8b3ed454e09d828918881232ef90 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/bert_wordpiece.py @@ -0,0 +1,151 @@ +from typing import Dict, Iterator, List, Optional, Union + +from tokenizers import AddedToken, Tokenizer, decoders, trainers +from tokenizers.models import WordPiece +from tokenizers.normalizers import BertNormalizer +from tokenizers.pre_tokenizers import BertPreTokenizer +from tokenizers.processors import BertProcessing + +from .base_tokenizer import BaseTokenizer + + +class BertWordPieceTokenizer(BaseTokenizer): + """Bert WordPiece Tokenizer""" + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + unk_token: Union[str, AddedToken] = "[UNK]", + sep_token: Union[str, AddedToken] = "[SEP]", + cls_token: Union[str, AddedToken] = "[CLS]", + pad_token: Union[str, AddedToken] = "[PAD]", + mask_token: Union[str, AddedToken] = "[MASK]", + clean_text: bool = True, + handle_chinese_chars: bool = True, + strip_accents: Optional[bool] = None, + lowercase: bool = True, + wordpieces_prefix: str = "##", + ): + if vocab is not None: + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token))) + else: + tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) + + # Let the tokenizer know about special tokens if they are part of the vocab + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + if tokenizer.token_to_id(str(sep_token)) is not None: + tokenizer.add_special_tokens([str(sep_token)]) + if tokenizer.token_to_id(str(cls_token)) is not None: + tokenizer.add_special_tokens([str(cls_token)]) + if tokenizer.token_to_id(str(pad_token)) is not None: + tokenizer.add_special_tokens([str(pad_token)]) + if tokenizer.token_to_id(str(mask_token)) is not None: + tokenizer.add_special_tokens([str(mask_token)]) + + tokenizer.normalizer = BertNormalizer( + clean_text=clean_text, + handle_chinese_chars=handle_chinese_chars, + 
strip_accents=strip_accents, + lowercase=lowercase, + ) + tokenizer.pre_tokenizer = BertPreTokenizer() + + if vocab is not None: + sep_token_id = tokenizer.token_to_id(str(sep_token)) + if sep_token_id is None: + raise TypeError("sep_token not found in the vocabulary") + cls_token_id = tokenizer.token_to_id(str(cls_token)) + if cls_token_id is None: + raise TypeError("cls_token not found in the vocabulary") + + tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id)) + tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) + + parameters = { + "model": "BertWordPiece", + "unk_token": unk_token, + "sep_token": sep_token, + "cls_token": cls_token, + "pad_token": pad_token, + "mask_token": mask_token, + "clean_text": clean_text, + "handle_chinese_chars": handle_chinese_chars, + "strip_accents": strip_accents, + "lowercase": lowercase, + "wordpieces_prefix": wordpieces_prefix, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab: str, **kwargs): + vocab = WordPiece.read_file(vocab) + return BertWordPieceTokenizer(vocab, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + ): + """Train the model using the given files""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/byte_level_bpe.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/byte_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..f65f05e1ddd4c8ec6b3791aa3045762cc06523e3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/byte_level_bpe.py @@ -0,0 +1,122 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, processors, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str + +from .base_tokenizer import BaseTokenizer + + +class ByteLevelBPETokenizer(BaseTokenizer): + 
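+    # Editorial sketch (not upstream code): a typical round-trip with this class,
+    # assuming GPT-2 style "vocab.json" / "merges.txt" files are available locally:
+    #
+    #   tok = ByteLevelBPETokenizer.from_file("vocab.json", "merges.txt")
+    #   enc = tok.encode("Hello world")
+    #   tok.decode(enc.ids)  # -> "Hello world" (byte-level BPE round-trips text)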
"""ByteLevelBPETokenizer + + Represents a Byte-level BPE as introduced by OpenAI with their GPT-2 model + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + add_prefix_space: bool = False, + lowercase: bool = False, + dropout: Optional[float] = None, + unicode_normalizer: Optional[str] = None, + continuing_subword_prefix: Optional[str] = None, + end_of_word_suffix: Optional[str] = None, + trim_offsets: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + continuing_subword_prefix=continuing_subword_prefix or "", + end_of_word_suffix=end_of_word_suffix or "", + ) + ) + else: + tokenizer = Tokenizer(BPE()) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=trim_offsets) + + parameters = { + "model": "ByteLevelBPE", + "add_prefix_space": add_prefix_space, + "lowercase": lowercase, + "dropout": dropout, + "unicode_normalizer": unicode_normalizer, + "continuing_subword_prefix": continuing_subword_prefix, + "end_of_word_suffix": end_of_word_suffix, + "trim_offsets": trim_offsets, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return ByteLevelBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/char_level_bpe.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/char_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..62b5bcdf06b4026ce48620ee4d681f0c7399b520 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/char_level_bpe.py @@ 
-0,0 +1,150 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from .. import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from ..models import BPE +from ..normalizers import BertNormalizer, Lowercase, Sequence, unicode_normalizer_from_str +from .base_tokenizer import BaseTokenizer + + +class CharBPETokenizer(BaseTokenizer): + """Original BPE Tokenizer + + Represents the BPE algorithm, as introduced by Rico Sennrich + (https://arxiv.org/abs/1508.07909) + + The defaults settings corresponds to OpenAI GPT BPE tokenizers and differs from the original + Sennrich subword-nmt implementation by the following options that you can deactivate: + - adding a normalizer to clean up the text (deactivate with `bert_normalizer=False`) by: + * removing any control characters and replacing all whitespaces by the classic one. + * handle chinese chars by putting spaces around them. + * strip all accents. + - spitting on punctuation in addition to whitespaces (deactivate it with + `split_on_whitespace_only=True`) + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + unk_token: Union[str, AddedToken] = "", + suffix: str = "", + dropout: Optional[float] = None, + lowercase: bool = False, + unicode_normalizer: Optional[str] = None, + bert_normalizer: bool = True, + split_on_whitespace_only: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + unk_token=str(unk_token), + end_of_word_suffix=suffix, + ) + ) + else: + tokenizer = Tokenizer(BPE(unk_token=str(unk_token), dropout=dropout, end_of_word_suffix=suffix)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if bert_normalizer: + normalizers += [BertNormalizer(lowercase=False)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + if split_on_whitespace_only: + tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit() + else: + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + tokenizer.decoder = decoders.BPEDecoder(suffix=suffix) + + parameters = { + "model": "BPE", + "unk_token": unk_token, + "suffix": suffix, + "dropout": dropout, + "lowercase": lowercase, + "unicode_normalizer": unicode_normalizer, + "bert_normalizer": bert_normalizer, + "split_on_whitespace_only": split_on_whitespace_only, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return CharBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + 
limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_bpe.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..26200489a60dfc6420b43f5dda21ad18ebfe7484 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_bpe.py @@ -0,0 +1,103 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import NFKC + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceBPETokenizer(BaseTokenizer): + """SentencePiece BPE Tokenizer + + Represents the BPE algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, List[Tuple[str, str]]]] = None, + unk_token: Union[str, AddedToken] = "", + replacement: str = "▁", + add_prefix_space: bool = True, + dropout: Optional[float] = None, + fuse_unk: Optional[bool] = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer(BPE(vocab, merges, dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + else: + tokenizer = Tokenizer(BPE(dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + tokenizer.normalizer = NFKC() + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceBPE", + "unk_token": unk_token, + "replacement": replacement, + "add_prefix_space": add_prefix_space, + "dropout": dropout, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return SentencePieceBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + ): + """Train the model using the given files""" + 
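+        # Editorial sketch (not upstream code): typical use of this method,
+        # assuming a plain-text corpus at the hypothetical path "corpus.txt":
+        #
+        #   tok = SentencePieceBPETokenizer()
+        #   tok.train("corpus.txt", vocab_size=30000, min_frequency=2)
+        #   tok.save("sp_bpe.json")  # serialize the trained tokenizer (BaseTokenizer.save)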
+ trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_unigram.py b/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_unigram.py new file mode 100644 index 0000000000000000000000000000000000000000..1237e85eb688c02f480e9aa968f476a7401f6067 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/implementations/sentencepiece_unigram.py @@ -0,0 +1,196 @@ +import json +import os +from typing import Iterator, List, Optional, Union, Tuple + +from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers +from tokenizers.models import Unigram + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceUnigramTokenizer(BaseTokenizer): + """SentencePiece Unigram Tokenizer + + Represents the Unigram algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[List[Tuple[str, float]]] = None, + replacement: str = "▁", + add_prefix_space: bool = True, + ): + if vocab is not None: + # Let Unigram(..) fail if only one of them is None + tokenizer = Tokenizer(Unigram(vocab)) + else: + tokenizer = Tokenizer(Unigram()) + + tokenizer.normalizer = normalizers.Sequence( + [normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")] + ) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + "replacement": replacement, + "add_prefix_space": add_prefix_space, + } + + super().__init__(tokenizer, parameters) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + ): + """ + Train the model using the given files + + Args: + files (:obj:`List[str]`): + A list of path to the files that we should use for training + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. 
+ initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + length: Optional[int] = None, + ): + """ + Train the model using the given iterator + + Args: + iterator (:obj:`Union[Iterator[str], Iterator[Iterator[str]]]`): + Any iterator over strings or list of strings + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + length (:obj:`int`, `optional`): + The total number of sequences in the iterator. This is used to + provide meaningful progress tracking + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) + + @staticmethod + def from_spm(filename: str): + try: + import sys + + sys.path.append(".") + + import sentencepiece_model_pb2 as model + except Exception: + raise Exception( + "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/src/sentencepiece/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required." 
+ ) + + m = model.ModelProto() + m.ParseFromString(open(filename, "rb").read()) + + precompiled_charsmap = m.normalizer_spec.precompiled_charsmap + vocab = [(piece.piece, piece.score) for piece in m.pieces] + unk_id = m.trainer_spec.unk_id + model_type = m.trainer_spec.model_type + byte_fallback = m.trainer_spec.byte_fallback + if model_type != 1: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + replacement = "▁" + add_prefix_space = True + + tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback)) + + if precompiled_charsmap: + tokenizer.normalizer = normalizers.Sequence( + [ + normalizers.Precompiled(precompiled_charsmap), + normalizers.Replace(Regex(" {2,}"), " "), + ] + ) + else: + tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")]) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + } + + obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters) + BaseTokenizer.__init__(obj, tokenizer, parameters) + return obj diff --git a/venv/lib/python3.13/site-packages/tokenizers/models/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68ac211aa8032249db6b929ca64f9130c358d40b --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/models/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import models + +Model = models.Model +BPE = models.BPE +Unigram = models.Unigram +WordLevel = models.WordLevel +WordPiece = models.WordPiece diff --git a/venv/lib/python3.13/site-packages/tokenizers/models/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/models/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f6862539069cb10f890770017c7a61f210916c97 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/models/__init__.pyi @@ -0,0 +1,591 @@ +# Generated content DO NOT EDIT +class Model: + """ + Base class for all models + + The model represents the actual tokenization algorithm. This is the part that + will contain and manage the learned vocabulary. + + This class cannot be constructed directly. Please use one of the concrete models. + """ + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class BPE(Model): + """ + An implementation of the BPE (Byte-Pair Encoding) algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + merges (:obj:`List[Tuple[str, str]]`, `optional`): + A list of pairs of tokens (:obj:`Tuple[str, str]`) :obj:`[("a", "b"),...]` + + cache_capacity (:obj:`int`, `optional`): + The number of words that the BPE cache can contain. The cache allows + to speed-up the process by keeping the result of the merge operations + for a number of words. + + dropout (:obj:`float`, `optional`): + A float between 0 and 1 that represents the BPE dropout to use. + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + continuing_subword_prefix (:obj:`str`, `optional`): + The prefix to attach to subword units that don't represent a beginning of word. + + end_of_word_suffix (:obj:`str`, `optional`): + The suffix to attach to subword units that represent an end of word. + + fuse_unk (:obj:`bool`, `optional`): + Whether to fuse any subsequent unknown tokens into a single one + + byte_fallback (:obj:`bool`, `optional`): + Whether to use spm byte-fallback trick (defaults to False) + + ignore_merges (:obj:`bool`, `optional`): + Whether or not to match tokens with the vocab before using merges. + """ + def __init__( + self, + vocab=None, + merges=None, + cache_capacity=None, + dropout=None, + unk_token=None, + continuing_subword_prefix=None, + end_of_word_suffix=None, + fuse_unk=None, + byte_fallback=False, + ignore_merges=False, + ): + pass + + @staticmethod + def from_file(cls, vocab, merge, **kwargs): + """ + Instantiate a BPE model from the given files. + + This method is roughly equivalent to doing:: + + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + bpe = BPE(vocab, merges) + + If you don't need to keep the :obj:`vocab, merges` values lying around, + this method is more optimized than manually calling + :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. 
+ + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(self, vocab, merges): + """ + Read a :obj:`vocab.json` and a :obj:`merges.txt` files + + This method provides a way to read and parse the content of these files, + returning the relevant data structures. If you want to instantiate some BPE models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + A :obj:`Tuple` with the vocab and the merges: + The vocabulary and merges loaded into memory + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class Unigram(Model): + """ + An implementation of the Unigram algorithm + + Args: + vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`): + A list of vocabulary items and their relative score [("am", -0.2442),...] + """ + def __init__(self, vocab, unk_id, byte_fallback): + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordLevel(Model): + """ + An implementation of the WordLevel algorithm + + Most simple tokenizer model based on mapping tokens to their corresponding id. + + Args: + vocab (:obj:`str`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + def __init__(self, vocab, unk_token): + pass + + @staticmethod + def from_file(vocab, unk_token): + """ + Instantiate a WordLevel model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordLevel.read_file(vocab_filename) + wordlevel = WordLevel(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordLevel.read_file` to + initialize a :class:`~tokenizers.models.WordLevel` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.json` + + This method provides a way to read and parse the content of a vocabulary file, + returning the relevant data structures. If you want to instantiate some WordLevel models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordPiece(Model): + """ + An implementation of the WordPiece algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + max_input_chars_per_word (:obj:`int`, `optional`): + The maximum number of characters to authorize in a single word. + """ + def __init__(self, vocab, unk_token, max_input_chars_per_word): + pass + + @staticmethod + def from_file(vocab, **kwargs): + """ + Instantiate a WordPiece model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordPiece.read_file(vocab_filename) + wordpiece = WordPiece(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordPiece.read_file` to + initialize a :class:`~tokenizers.models.WordPiece` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :class:`~tokenizers.models.WordPiece`: An instance of WordPiece loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.txt` file + + This method provides a way to read and parse the content of a standard `vocab.txt` + file as used by the WordPiece Model, returning the relevant data structures. If you + want to instantiate some WordPiece models from memory, this method gives you the + expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. 
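+
+        A minimal sketch of a save/reload round trip (hypothetical paths; WordPiece
+        saves a single vocabulary file)::
+
+            files = model.save("./out", "wordpiece")
+            vocab = WordPiece.read_file(files[0])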
+ + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/models/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbbab721d57fe910520e2d057ca6a0d2ab50be0b Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86d233bd216821d77f5ccf88f874b6f530cedbf5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.py @@ -0,0 +1,29 @@ +from .. import normalizers + + +Normalizer = normalizers.Normalizer +BertNormalizer = normalizers.BertNormalizer +NFD = normalizers.NFD +NFKD = normalizers.NFKD +NFC = normalizers.NFC +NFKC = normalizers.NFKC +Sequence = normalizers.Sequence +Lowercase = normalizers.Lowercase +Prepend = normalizers.Prepend +Strip = normalizers.Strip +StripAccents = normalizers.StripAccents +Nmt = normalizers.Nmt +Precompiled = normalizers.Precompiled +Replace = normalizers.Replace +ByteLevel = normalizers.ByteLevel + +NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD} + + +def unicode_normalizer_from_str(normalizer: str) -> Normalizer: + if normalizer not in NORMALIZERS: + raise ValueError( + "{} is not a known unicode normalizer. Available are {}".format(normalizer, NORMALIZERS.keys()) + ) + + return NORMALIZERS[normalizer]() diff --git a/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1f5555104abcd14dd39e2927a78fa05b89ae101b --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__init__.pyi @@ -0,0 +1,636 @@ +# Generated content DO NOT EDIT +class Normalizer: + """ + Base class for all normalizers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + Normalizer will return an instance of this class when instantiated. + """ + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
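+
+        A minimal sketch of such an in-place use (any concrete normalizer works the
+        same way)::
+
+            from tokenizers import NormalizedString
+            from tokenizers.normalizers import Lowercase
+
+            n = NormalizedString("HELLO")
+            Lowercase().normalize(n)  # n now holds "hello" plus alignment data
+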
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class BertNormalizer(Normalizer): + """ + BertNormalizer + + Takes care of normalizing raw text before giving it to a Bert model. + This includes cleaning the text, handling accents, chinese chars and lowercasing + + Args: + clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to clean the text, by removing any control characters + and replacing all whitespaces by the classic one. + + handle_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to handle chinese chars by putting spaces around them. + + strip_accents (:obj:`bool`, `optional`): + Whether to strip all accents. If this option is not specified (ie == None), + then it will be determined by the value for `lowercase` (as in the original Bert). + + lowercase (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to lowercase. + """ + def __init__(self, clean_text=True, handle_chinese_chars=True, strip_accents=None, lowercase=True): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class ByteLevel(Normalizer): + """ + Bytelevel Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
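+
+        A minimal sketch of the byte-to-visible-character mapping this normalizer
+        applies (output shown for illustration)::
+
+            from tokenizers.normalizers import ByteLevel
+
+            ByteLevel().normalize_str("hello world")  # 'helloĠworld'
+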
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Lowercase(Normalizer): + """ + Lowercase Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFC(Normalizer): + """ + NFC Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFD(Normalizer): + """ + NFD Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
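+
+        For instance, NFD decomposes precomposed characters (a minimal sketch)::
+
+            from tokenizers.normalizers import NFD
+
+            len(NFD().normalize_str("é"))  # 2: 'e' followed by a combining accent
+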
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFKC(Normalizer): + """ + NFKC Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class NFKD(Normalizer): + """ + NFKD Unicode Normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Nmt(Normalizer): + """ + Nmt normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
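+
+        A minimal sketch (the Nmt rules remove some control characters and map
+        various whitespace-like characters to a plain space)::
+
+            from tokenizers.normalizers import Nmt
+
+            Nmt().normalize_str("hello\u200bworld")
+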
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Precompiled(Normalizer): + """ + Precompiled normalizer + Don't use manually it is used for compatibility for SentencePiece. + """ + def __init__(self, precompiled_charsmap): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Prepend(Normalizer): + """ + Prepend normalizer + """ + def __init__(self, prepend): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Replace(Normalizer): + """ + Replace normalizer + """ + def __init__(self, pattern, content): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
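+
+        A minimal sketch (a plain string pattern; wrap the pattern in
+        ``tokenizers.Regex`` if you need a regular expression)::
+
+            from tokenizers.normalizers import Replace
+
+            Replace("``", '"').normalize_str("``hello``")  # '"hello"'
+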
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Sequence(Normalizer): + """ + Allows concatenating multiple other Normalizer as a Sequence. + All the normalizers run in sequence in the given order + + Args: + normalizers (:obj:`List[Normalizer]`): + A list of Normalizer to be run as a sequence + """ + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class Strip(Normalizer): + """ + Strip normalizer + """ + def __init__(self, left=True, right=True): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + +class StripAccents(Normalizer): + """ + StripAccents normalizer + """ + def __init__(self): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
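+
+        StripAccents removes combining marks, so it is usually preceded by
+        :class:`~tokenizers.normalizers.NFD` (a minimal sketch)::
+
+            from tokenizers.normalizers import NFD, Sequence, StripAccents
+
+            Sequence([NFD(), StripAccents()]).normalize_str("déjà vu")  # 'deja vu'
+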
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea65fb7d61e67196daa0906def7447cbf55cc414 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/normalizers/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db8ddc20805b1c525be405134f8fa722ace89667 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.py @@ -0,0 +1,16 @@ +# Generated content DO NOT EDIT +from .. import pre_tokenizers + +PreTokenizer = pre_tokenizers.PreTokenizer +BertPreTokenizer = pre_tokenizers.BertPreTokenizer +ByteLevel = pre_tokenizers.ByteLevel +CharDelimiterSplit = pre_tokenizers.CharDelimiterSplit +Digits = pre_tokenizers.Digits +FixedLength = pre_tokenizers.FixedLength +Metaspace = pre_tokenizers.Metaspace +Punctuation = pre_tokenizers.Punctuation +Sequence = pre_tokenizers.Sequence +Split = pre_tokenizers.Split +UnicodeScripts = pre_tokenizers.UnicodeScripts +Whitespace = pre_tokenizers.Whitespace +WhitespaceSplit = pre_tokenizers.WhitespaceSplit diff --git a/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a5801cca363ef14d19273fe1135c0b864e118b3e --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__init__.pyi @@ -0,0 +1,689 @@ +# Generated content DO NOT EDIT +class PreTokenizer: + """ + Base class for all pre-tokenizers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + PreTokenizer will return an instance of this class when instantiated. + """ + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
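+
+        A minimal sketch of in-place use (any concrete pre-tokenizer works the same
+        way)::
+
+            from tokenizers import PreTokenizedString
+            from tokenizers.pre_tokenizers import Whitespace
+
+            pretok = PreTokenizedString("Hello world")
+            Whitespace().pre_tokenize(pretok)
+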
If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class BertPreTokenizer(PreTokenizer): + """ + BertPreTokenizer + + This pre-tokenizer splits tokens on spaces, and also on punctuation. + Each occurrence of a punctuation character will be treated separately. + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class ByteLevel(PreTokenizer): + """ + ByteLevel PreTokenizer + + This pre-tokenizer takes care of replacing all bytes of the given string + with a corresponding representation, as well as splitting into words. + + Args: + add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + use_regex (:obj:`bool`, `optional`, defaults to :obj:`True`): + Set this to :obj:`False` to prevent this `pre_tokenizer` from using + the GPT2 specific regexp for spliting on whitespace. + """ + def __init__(self, add_prefix_space=True, use_regex=True): + pass + + @staticmethod + def alphabet(): + """ + Returns the alphabet used by this PreTokenizer. + + Since the ByteLevel works as its name suggests, at the byte level, it + encodes each byte value to a unique visible character. This means that there is a + total of 256 different characters composing this alphabet. 
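+
+        For example (a quick check of the alphabet size)::
+
+            from tokenizers.pre_tokenizers import ByteLevel
+
+            assert len(ByteLevel.alphabet()) == 256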
+ + Returns: + :obj:`List[str]`: A list of characters that compose the alphabet + """ + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class CharDelimiterSplit(PreTokenizer): + """ + This pre-tokenizer simply splits on the provided char. Works like `.split(delimiter)` + + Args: + delimiter: str: + The delimiter char that will be used to split input + """ + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
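+
+        A minimal sketch (splitting on ``-``; the delimiter itself is dropped)::
+
+            from tokenizers.pre_tokenizers import CharDelimiterSplit
+
+            CharDelimiterSplit("-").pre_tokenize_str("pre-tokenizer")
+            # [('pre', (0, 3)), ('tokenizer', (4, 13))]
+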
+        If you need some of these, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize`
+
+        Args:
+            sequence (:obj:`str`):
+                A string to pre-tokenize
+
+        Returns:
+            :obj:`List[Tuple[str, Offsets]]`:
+                A list of tuples with the pre-tokenized parts and their offsets
+        """
+        pass
+
+class Digits(PreTokenizer):
+    """
+    This pre-tokenizer splits on digits, putting them in separate tokens
+
+    Args:
+        individual_digits (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If set to True, digits will each be separated as follows::
+
+                "Call 123 please" -> "Call ", "1", "2", "3", " please"
+
+            If set to False, digits will be grouped as follows::
+
+                "Call 123 please" -> "Call ", "123", " please"
+    """
+    def __init__(self, individual_digits=False):
+        pass
+
+    def pre_tokenize(self, pretok):
+        """
+        Pre-tokenize a :class:`~tokenizers.PreTokenizedString` in-place
+
+        This method allows to modify a :class:`~tokenizers.PreTokenizedString` to
+        keep track of the pre-tokenization, and leverage the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of
+        the pre-tokenization of a raw string, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str`
+
+        Args:
+            pretok (:class:`~tokenizers.PreTokenizedString`):
+                The pre-tokenized string on which to apply this
+                :class:`~tokenizers.pre_tokenizers.PreTokenizer`
+        """
+        pass
+
+    def pre_tokenize_str(self, sequence):
+        """
+        Pre-tokenize the given string
+
+        This method provides a way to visualize the effect of a
+        :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the
+        alignment, nor does it provide all the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize`
+
+        Args:
+            sequence (:obj:`str`):
+                A string to pre-tokenize
+
+        Returns:
+            :obj:`List[Tuple[str, Offsets]]`:
+                A list of tuples with the pre-tokenized parts and their offsets
+        """
+        pass
+
+class FixedLength(PreTokenizer):
+    """
+    This pre-tokenizer splits the text into fixed length chunks as used
+    [here](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1.full)
+
+    Args:
+        length (:obj:`int`, `optional`, defaults to :obj:`5`):
+            The length of the chunks to split the text into.
+
+            Strings are split on the character level rather than the byte level to avoid
+            splitting unicode characters consisting of multiple bytes.
+    """
+    def __init__(self, length=5):
+        pass
+
+    def pre_tokenize(self, pretok):
+        """
+        Pre-tokenize a :class:`~tokenizers.PreTokenizedString` in-place
+
+        This method allows to modify a :class:`~tokenizers.PreTokenizedString` to
+        keep track of the pre-tokenization, and leverage the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of
+        the pre-tokenization of a raw string, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str`
+
+        Args:
+            pretok (:class:`~tokenizers.PreTokenizedString`):
+                The pre-tokenized string on which to apply this
+                :class:`~tokenizers.pre_tokenizers.PreTokenizer`
+        """
+        pass
+
+    def pre_tokenize_str(self, sequence):
+        """
+        Pre-tokenize the given string
+
+        This method provides a way to visualize the effect of a
+        :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the
+        alignment, nor does it provide all the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`.
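+
+        A minimal sketch (length 5; the final chunk may be shorter)::
+
+            from tokenizers.pre_tokenizers import FixedLength
+
+            FixedLength(length=5).pre_tokenize_str("abcdefgh")
+            # [('abcde', (0, 5)), ('fgh', (5, 8))]
+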
If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Metaspace(PreTokenizer): + """ + Metaspace pre-tokenizer + + This pre-tokenizer replaces any whitespace by the provided replacement character. + It then tries to split on these spaces. + + Args: + replacement (:obj:`str`, `optional`, defaults to :obj:`▁`): + The replacement character. Must be exactly one character. By default we + use the `▁` (U+2581) meta symbol (Same as in SentencePiece). + + prepend_scheme (:obj:`str`, `optional`, defaults to :obj:`"always"`): + Whether to add a space to the first word if there isn't already one. This + lets us treat `hello` exactly like `say hello`. + Choices: "always", "never", "first". First means the space is only added on the first + token (relevant when special tokens are used or other pre_tokenizer are used). + + """ + def __init__(self, replacement="_", prepend_scheme="always", split=True): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Punctuation(PreTokenizer): + """ + This pre-tokenizer simply splits on punctuation as individual characters. + + Args: + behavior (:class:`~tokenizers.SplitDelimiterBehavior`): + The behavior to use when splitting. + Choices: "removed", "isolated" (default), "merged_with_previous", "merged_with_next", + "contiguous" + """ + def __init__(self, behavior="isolated"): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
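+
+        A minimal sketch of the default ``isolated`` behavior (via the string
+        variant of this method)::
+
+            from tokenizers.pre_tokenizers import Punctuation
+
+            Punctuation().pre_tokenize_str("Hi!")  # [('Hi', (0, 2)), ('!', (2, 3))]
+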
If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Sequence(PreTokenizer): + """ + This pre-tokenizer composes other pre_tokenizers and applies them in sequence + """ + def __init__(self, pretokenizers): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Split(PreTokenizer): + """ + Split PreTokenizer + + This versatile pre-tokenizer splits using the provided pattern and + according to the provided behavior. The pattern can be inverted by + making use of the invert flag. + + Args: + pattern (:obj:`str` or :class:`~tokenizers.Regex`): + A pattern used to split the string. Usually a string or a regex built with `tokenizers.Regex`. + If you want to use a regex pattern, it has to be wrapped around a `tokenizers.Regex`, + otherwise we consider is as a string pattern. For example `pattern="|"` + means you want to split on `|` (imagine a csv file for example), while + `pattern=tokenizers.Regex("1|2")` means you split on either '1' or '2'. + behavior (:class:`~tokenizers.SplitDelimiterBehavior`): + The behavior to use when splitting. + Choices: "removed", "isolated", "merged_with_previous", "merged_with_next", + "contiguous" + + invert (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to invert the pattern. 
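+
+    Example:
+        A minimal sketch (splitting on a literal ``|``, with the ``removed``
+        behavior)::
+
+            from tokenizers.pre_tokenizers import Split
+
+            Split("|", "removed").pre_tokenize_str("a|b")
+            # [('a', (0, 1)), ('b', (2, 3))]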
+ """ + def __init__(self, pattern, behavior, invert=False): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class UnicodeScripts(PreTokenizer): + """ + This pre-tokenizer splits on characters that belong to different language family + It roughly follows https://github.com/google/sentencepiece/blob/master/data/Scripts.txt + Actually Hiragana and Katakana are fused with Han, and 0x30FC is Han too. + This mimicks SentencePiece Unigram implementation. + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class Whitespace(PreTokenizer): + """ + This pre-tokenizer splits on word boundaries according to the `\w+|[^\w\s]+` + regex pattern. It splits on word characters or characters that aren't words or + whitespaces (punctuation such as hyphens, apostrophes, commas, etc.). + + Example: + Use the `Whitespace` function as shown below:: + + ```python + from tokenizers.pre_tokenizers import Whitespace + + pre_tokenizer = Whitespace() + text = "Hello, world! Let's try the Whitespace pre-tokenizer." 
+ pre_tokenizer.pre_tokenize_str(text) + [('Hello', (0, 5)), + (',', (5, 6)), + ('world', (7, 12)), + ('!', (12, 13)), + ('Let', (14, 17)), + ("'", (17, 18)), + ('s', (18, 19)), + ('try', (20, 23)), + ('the', (24, 27)), + ('Whitespace', (28, 38)), + ('pre', (39, 42)), + ('-', (42, 43)), + ('tokenizer', (43, 52)), + ('.', (52, 53))] + ``` + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass + +class WhitespaceSplit(PreTokenizer): + """ + This pre-tokenizer simply splits on the whitespace. Works like `.split()` + """ + def __init__(self): + pass + + def pre_tokenize(self, pretok): + """ + Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place + + This method allows to modify a :class:`~tokenizers.PreTokenizedString` to + keep track of the pre-tokenization, and leverage the capabilities of the + :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of + the pre-tokenization of a raw string, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str` + + Args: + pretok (:class:`~tokenizers.PreTokenizedString): + The pre-tokenized string on which to apply this + :class:`~tokenizers.pre_tokenizers.PreTokenizer` + """ + pass + + def pre_tokenize_str(self, sequence): + """ + Pre tokenize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the + alignment, nor does it provide all the capabilities of the + :class:`~tokenizers.PreTokenizedString`. 
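+
+        A minimal sketch (unlike :class:`~tokenizers.pre_tokenizers.Whitespace`,
+        punctuation stays attached to the words)::
+
+            from tokenizers.pre_tokenizers import WhitespaceSplit
+
+            WhitespaceSplit().pre_tokenize_str("Hello, world!")
+            # [('Hello,', (0, 6)), ('world!', (7, 13))]
+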
If you need some of these, you can use + :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize` + + Args: + sequence (:obj:`str`): + A string to pre-tokeize + + Returns: + :obj:`List[Tuple[str, Offsets]]`: + A list of tuple with the pre-tokenized parts and their offsets + """ + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..990278e7b3116854605b1b808def001dab01c658 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/pre_tokenizers/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06d124037b6d932615fa0d31b02f8ac82ac0b5fc --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.py @@ -0,0 +1,9 @@ +# Generated content DO NOT EDIT +from .. import processors + +PostProcessor = processors.PostProcessor +BertProcessing = processors.BertProcessing +ByteLevel = processors.ByteLevel +RobertaProcessing = processors.RobertaProcessing +Sequence = processors.Sequence +TemplateProcessing = processors.TemplateProcessing diff --git a/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5136d02bbc4d391eba1b2feb4882c1f563db92f3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/processors/__init__.pyi @@ -0,0 +1,342 @@ +# Generated content DO NOT EDIT +class PostProcessor: + """ + Base class for all post-processors + + This class is not supposed to be instantiated directly. Instead, any implementation of + a PostProcessor will return an instance of this class when instantiated. + """ + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class BertProcessing(PostProcessor): + """ + This post-processor takes care of adding the special tokens needed by + a Bert model: + + - a SEP token + - a CLS token + + Args: + sep (:obj:`Tuple[str, int]`): + A tuple with the string representation of the SEP token, and its id + + cls (:obj:`Tuple[str, int]`): + A tuple with the string representation of the CLS token, and its id + """ + def __init__(self, sep, cls): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. 
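+
+        For example (a minimal sketch; the ids 102/101 are the usual BERT ids and
+        are only illustrative)::
+
+            from tokenizers.processors import BertProcessing
+
+            processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))
+            processor.num_special_tokens_to_add(is_pair=False)  # 2
+            processor.num_special_tokens_to_add(is_pair=True)   # 3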
+ + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class ByteLevel(PostProcessor): + """ + This post-processor takes care of trimming the offsets. + + By default, the ByteLevel BPE might include whitespaces in the produced tokens. If you don't + want the offsets to include these whitespaces, then this PostProcessor must be used. + + Args: + trim_offsets (:obj:`bool`): + Whether to trim the whitespaces from the produced offsets. + """ + def __init__(self, trim_offsets=True): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class RobertaProcessing(PostProcessor): + """ + This post-processor takes care of adding the special tokens needed by + a Roberta model: + + - a SEP token + - a CLS token + + It also takes care of trimming the offsets. + By default, the ByteLevel BPE might include whitespaces in the produced tokens. If you don't + want the offsets to include these whitespaces, then this PostProcessor should be initialized + with :obj:`trim_offsets=True` + + Args: + sep (:obj:`Tuple[str, int]`): + A tuple with the string representation of the SEP token, and its id + + cls (:obj:`Tuple[str, int]`): + A tuple with the string representation of the CLS token, and its id + + trim_offsets (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to trim the whitespaces from the produced offsets. + + add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether the add_prefix_space option was enabled during pre-tokenization. This + is relevant because it defines the way the offsets are trimmed out. + """ + def __init__(self, sep, cls, trim_offsets=True, add_prefix_space=True): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. 
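+
+        For example (a minimal sketch; RoBERTa doubles the separator between a
+        pair, hence 4 special tokens instead of 3)::
+
+            from tokenizers.processors import RobertaProcessing
+
+            processor = RobertaProcessing(("</s>", 2), ("<s>", 0))
+            processor.num_special_tokens_to_add(is_pair=True)  # 4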
+ + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class Sequence(PostProcessor): + """ + Sequence Processor + + Args: + processors (:obj:`List[PostProcessor]`): + The processors that need to be chained + """ + def __init__(self, processors): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass + +class TemplateProcessing(PostProcessor): + """ + Provides a way to specify templates in order to add the special tokens to each + input sequence as relevant. + + Let's take the :obj:`BERT` tokenizer as an example. It uses two special tokens, used to + delimit each sequence. :obj:`[CLS]` is always used at the beginning of the first + sequence, and :obj:`[SEP]` is added at the end of both the first, and the pair + sequences. The final result looks like this: + + - Single sequence: :obj:`[CLS] Hello there [SEP]` + - Pair sequences: :obj:`[CLS] My name is Anthony [SEP] What is my name? [SEP]` + + With the type ids as follows:: + + [CLS] ... [SEP] ... [SEP] + 0 0 0 1 1 + + You can achieve such behavior using a TemplateProcessing:: + + TemplateProcessing( + single="[CLS] $0 [SEP]", + pair="[CLS] $A [SEP] $B:1 [SEP]:1", + special_tokens=[("[CLS]", 1), ("[SEP]", 0)], + ) + + In this example, each input sequence is identified using a ``$`` construct. This identifier + lets us specify each input sequence, and the type_id to use. When nothing is specified, + it uses the default values. Here are the different ways to specify it: + + - Specifying the sequence, with default ``type_id == 0``: ``$A`` or ``$B`` + - Specifying the `type_id` with default ``sequence == A``: ``$0``, ``$1``, ``$2``, ... + - Specifying both: ``$A:0``, ``$B:1``, ... + + The same construct is used for special tokens: ``<identifier>(:<type_id>)?``. + + **Warning**: You must ensure that you are giving the correct tokens/ids as these + will be added to the Encoding without any further check. If the given ids correspond + to something totally different in a `Tokenizer` using this `PostProcessor`, it + might lead to unexpected results.
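The template syntax explained above can be exercised directly; a minimal sketch reusing the docstring's own BERT-style template:

```python
from tokenizers.processors import TemplateProcessing

tp = TemplateProcessing(
    single="[CLS] $0 [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
print(tp.num_special_tokens_to_add(is_pair=False))  # 2: [CLS] and one [SEP]
print(tp.num_special_tokens_to_add(is_pair=True))   # 3: [CLS] and two [SEP]
```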
+ + Args: + single (:obj:`Template`): + The template used for single sequences + + pair (:obj:`Template`): + The template used when both sequences are specified + + special_tokens (:obj:`Tokens`): + The list of special tokens used in each sequences + + Types: + + Template (:obj:`str` or :obj:`List`): + - If a :obj:`str` is provided, the whitespace is used as delimiter between tokens + - If a :obj:`List[str]` is provided, a list of tokens + + Tokens (:obj:`List[Union[Tuple[int, str], Tuple[str, int], dict]]`): + - A :obj:`Tuple` with both a token and its associated ID, in any order + - A :obj:`dict` with the following keys: + - "id": :obj:`str` => The special token id, as specified in the Template + - "ids": :obj:`List[int]` => The associated IDs + - "tokens": :obj:`List[str]` => The associated tokens + + The given dict expects the provided :obj:`ids` and :obj:`tokens` lists to have + the same length. + """ + def __init__(self, single, pair, special_tokens): + pass + + def num_special_tokens_to_add(self, is_pair): + """ + Return the number of special tokens that would be added for single/pair sentences. + + Args: + is_pair (:obj:`bool`): + Whether the input would be a pair of sequences + + Returns: + :obj:`int`: The number of tokens to add + """ + pass + + def process(self, encoding, pair=None, add_special_tokens=True): + """ + Post-process the given encodings, generating the final one + + Args: + encoding (:class:`~tokenizers.Encoding`): + The encoding for the first sequence + + pair (:class:`~tokenizers.Encoding`, `optional`): + The encoding for the pair sequence + + add_special_tokens (:obj:`bool`): + Whether to add the special tokens + + Return: + :class:`~tokenizers.Encoding`: The final encoding + """ + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/processors/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/processors/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b902e3b216eafe4ec5411d0683b2506778a4ade Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/processors/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/tools/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f941e2ed39c7d69fa14abff7dcf973d93843ea06 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/tools/__init__.py @@ -0,0 +1 @@ +from .visualizer import Annotation, EncodingVisualizer diff --git a/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f10b04c91bdf4fbbbed2b12bf0d544f60c7d68f9 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9c73e9edaeb0ef04773738f6c49be0321d3c16b Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/tools/__pycache__/visualizer.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer-styles.css 
b/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer-styles.css new file mode 100644 index 0000000000000000000000000000000000000000..f54fde45ada66c902c0b41969d0f40d51c9717da --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer-styles.css @@ -0,0 +1,170 @@ +.tokenized-text { + width:100%; + padding:2rem; + max-height: 400px; + overflow-y: auto; + box-sizing:border-box; + line-height:4rem; /* Lots of space between lines */ + font-family: "Roboto Light", "Ubuntu Light", "Ubuntu", monospace; + box-shadow: 2px 2px 2px rgba(0,0,0,0.2); + background-color: rgba(0,0,0,0.01); + letter-spacing:2px; /* Give some extra separation between chars */ +} +.non-token{ + /* White space and other things the tokenizer ignores*/ + white-space: pre; + letter-spacing:4px; + border-top:1px solid #A0A0A0; /* A gentle border on top and bottom makes tabs more ovious*/ + border-bottom:1px solid #A0A0A0; + line-height: 1rem; + height: calc(100% - 2px); +} + +.token { + white-space: pre; + position:relative; + color:black; + letter-spacing:2px; +} + +.annotation{ + white-space:nowrap; /* Important - ensures that annotations appears even if the annotated text wraps a line */ + border-radius:4px; + position:relative; + width:fit-content; +} +.annotation:before { + /*The before holds the text and the after holds the background*/ + z-index:1000; /* Make sure this is above the background */ + content:attr(data-label); /* The annotations label is on a data attribute */ + color:white; + position:absolute; + font-size:1rem; + text-align:center; + font-weight:bold; + + top:1.75rem; + line-height:0; + left:0; + width:100%; + padding:0.5rem 0; + /* These make it so an annotation doesn't stretch beyond the annotated text if the label is longer*/ + overflow: hidden; + white-space: nowrap; + text-overflow:ellipsis; +} + +.annotation:after { + content:attr(data-label); /* The content defines the width of the annotation*/ + position:absolute; + font-size:0.75rem; + text-align:center; + font-weight:bold; + text-overflow:ellipsis; + top:1.75rem; + line-height:0; + overflow: hidden; + white-space: nowrap; + + left:0; + width:100%; /* 100% of the parent, which is the annotation whose width is the tokens inside it*/ + + padding:0.5rem 0; + /* Nast hack below: + We set the annotations color in code because we don't know the colors at css time. + But you can't pass a color as a data attribute to get it into the pseudo element (this thing) + So to get around that, annotations have the color set on them with a style attribute and then we + can get the color with currentColor. 
+ Annotations wrap tokens and tokens set the color back to black + */ + background-color: currentColor; +} +.annotation:hover::after, .annotation:hover::before{ + /* When the user hovers over an annotation expand the label to display in full + */ + min-width: fit-content; +} + +.annotation:hover{ + /* Emphasize the annotation start end with a border on hover*/ + border-color: currentColor; + border: 2px solid; +} +.special-token:not(:empty){ + /* + A none empty special token is like UNK (as opposed to CLS which has no representation in the text ) + */ + position:relative; +} +.special-token:empty::before{ + /* Special tokens that don't have text are displayed as pseudo elements so we dont select them with the mouse*/ + content:attr(data-stok); + background:#202020; + font-size:0.75rem; + color:white; + margin: 0 0.25rem; + padding: 0.25rem; + border-radius:4px +} + +.special-token:not(:empty):before { + /* Special tokens that have text (UNK) are displayed above the actual text*/ + content:attr(data-stok); + position:absolute; + bottom:1.75rem; + min-width:100%; + width:100%; + height:1rem; + line-height:1rem; + font-size:1rem; + text-align:center; + color:white; + font-weight:bold; + background:#202020; + border-radius:10%; +} +/* +We want to alternate the color of tokens, but we can't use nth child because tokens might be broken up by annotations +instead we apply even and odd class at generation time and color them that way + */ +.even-token{ + background:#DCDCDC ; + border: 1px solid #DCDCDC; +} +.odd-token{ + background:#A0A0A0; + border: 1px solid #A0A0A0; +} +.even-token.multi-token,.odd-token.multi-token{ + background: repeating-linear-gradient( + 45deg, + transparent, + transparent 1px, + #ccc 1px, + #ccc 1px + ), + /* on "bottom" */ + linear-gradient( + to bottom, + #FFB6C1, + #999 + ); +} + +.multi-token:hover::after { + content:"This char has more than 1 token"; /* The content defines the width of the annotation*/ + color:white; + background-color: black; + position:absolute; + font-size:0.75rem; + text-align:center; + font-weight:bold; + text-overflow:ellipsis; + top:1.75rem; + line-height:0; + overflow: hidden; + white-space: nowrap; + left:0; + width:fit-content; /* 100% of the parent, which is the annotation whose width is the tokens inside it*/ + padding:0.5rem 0; +} diff --git a/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer.py b/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b6158c7f0b034c7fe3cdd7c00e5623fe8b5d50 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/tools/visualizer.py @@ -0,0 +1,403 @@ +import itertools +import os +import re +from string import Template +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple + +from tokenizers import Encoding, Tokenizer + + +dirname = os.path.dirname(__file__) +css_filename = os.path.join(dirname, "visualizer-styles.css") +with open(css_filename) as f: + css = f.read() + + +class Annotation: + start: int + end: int + label: int + + def __init__(self, start: int, end: int, label: str): + self.start = start + self.end = end + self.label = label + + +AnnotationList = List[Annotation] +PartialIntList = List[Optional[int]] + + +class CharStateKey(NamedTuple): + token_ix: Optional[int] + anno_ix: Optional[int] + + +class CharState: + char_ix: Optional[int] + + def __init__(self, char_ix): + self.char_ix = char_ix + + self.anno_ix: Optional[int] = None + self.tokens: List[int] = [] + + @property 
+ def token_ix(self): + return self.tokens[0] if len(self.tokens) > 0 else None + + @property + def is_multitoken(self): + """ + BPE tokenizers can output more than one token for a char + """ + return len(self.tokens) > 1 + + def partition_key(self) -> CharStateKey: + return CharStateKey( + token_ix=self.token_ix, + anno_ix=self.anno_ix, + ) + + +class Aligned: + pass + + +class EncodingVisualizer: + """ + Build an EncodingVisualizer + + Args: + + tokenizer (:class:`~tokenizers.Tokenizer`): + A tokenizer instance + + default_to_notebook (:obj:`bool`): + Whether to render html output in a notebook by default + + annotation_converter (:obj:`Callable`, `optional`): + An optional (lambda) function that takes an annotation in any format and returns + an Annotation object + """ + + unk_token_regex = re.compile("(.{1}\b)?(unk|oov)(\b.{1})?", flags=re.IGNORECASE) + + def __init__( + self, + tokenizer: Tokenizer, + default_to_notebook: bool = True, + annotation_converter: Optional[Callable[[Any], Annotation]] = None, + ): + if default_to_notebook: + try: + from IPython.core.display import HTML, display + except ImportError: + raise Exception( + """We couldn't import IPython utils for html display. + Are you running in a notebook? + You can also pass `default_to_notebook=False` to get back raw HTML + """ + ) + + self.tokenizer = tokenizer + self.default_to_notebook = default_to_notebook + self.annotation_coverter = annotation_converter + pass + + def __call__( + self, + text: str, + annotations: AnnotationList = [], + default_to_notebook: Optional[bool] = None, + ) -> Optional[str]: + """ + Build a visualization of the given text + + Args: + text (:obj:`str`): + The text to tokenize + + annotations (:obj:`List[Annotation]`, `optional`): + An optional list of annotations of the text. The can either be an annotation class + or anything else if you instantiated the visualizer with a converter function + + default_to_notebook (:obj:`bool`, `optional`, defaults to `False`): + If True, will render the html in a notebook. Otherwise returns an html string. + + Returns: + The HTML string if default_to_notebook is False, otherwise (default) returns None and + renders the HTML in the notebook + + """ + final_default_to_notebook = self.default_to_notebook + if default_to_notebook is not None: + final_default_to_notebook = default_to_notebook + if final_default_to_notebook: + try: + from IPython.core.display import HTML, display + except ImportError: + raise Exception( + """We couldn't import IPython utils for html display. 
+ Are you running in a notebook?""" + ) + if self.annotation_coverter is not None: + annotations = list(map(self.annotation_coverter, annotations)) + encoding = self.tokenizer.encode(text) + html = EncodingVisualizer.__make_html(text, encoding, annotations) + if final_default_to_notebook: + display(HTML(html)) + else: + return html + + @staticmethod + def calculate_label_colors(annotations: AnnotationList) -> Dict[str, str]: + """ + Generates a color palette for all the labels in a given set of annotations + + Args: + annotations (:obj:`Annotation`): + A list of annotations + + Returns: + :obj:`dict`: A dictionary mapping labels to colors in HSL format + """ + if len(annotations) == 0: + return {} + labels = set(map(lambda x: x.label, annotations)) + num_labels = len(labels) + h_step = int(255 / num_labels) + if h_step < 20: + h_step = 20 + s = 32 + l = 64  # noqa: E741 + h = 10 + colors = {} + + for label in sorted(labels):  # sort so we always get the same colors for a given set of labels + colors[label] = f"hsl({h},{s}%,{l}%)" + h += h_step + return colors + + @staticmethod + def consecutive_chars_to_html( + consecutive_chars_list: List[CharState], + text: str, + encoding: Encoding, + ): + """ + Converts a list of "consecutive chars" into a single HTML element. + Chars are consecutive if they fall under the same word, token and annotation. + The CharState class is a named tuple with a "partition_key" method that makes it easy to + compare if two chars are consecutive. + + Args: + consecutive_chars_list (:obj:`List[CharState]`): + A list of CharStates that have been grouped together + + text (:obj:`str`): + The original text being processed + + encoding (:class:`~tokenizers.Encoding`): + The encoding returned from the tokenizer + + Returns: + :obj:`str`: The HTML span for a set of consecutive chars + """ + first = consecutive_chars_list[0] + if first.char_ix is None: + # it's a special token + stoken = encoding.tokens[first.token_ix] + # special tokens are represented as empty spans. We use the data attribute and css + # magic to display it + return f'<span class="special-token" data-stok="{stoken}"></span>' + # We're not in a special token so this group has a start and end. + last = consecutive_chars_list[-1] + start = first.char_ix + end = last.char_ix + 1 + span_text = text[start:end] + css_classes = []  # What css classes will we apply on the resulting span + data_items = {}  # What data attributes will we apply on the result span + if first.token_ix is not None: + # We can either be in a token or not (e.g. in white space) + css_classes.append("token") + if first.is_multitoken: + css_classes.append("multi-token") + if first.token_ix % 2: + # We use this to color alternating tokens. + # A token might be split by an annotation that ends in the middle of it, so this + # lets us visually indicate a consecutive token despite its possible splitting in + # the html markup + css_classes.append("odd-token") + else: + # Like above, but a different color so we can see the tokens alternate + css_classes.append("even-token") + if EncodingVisualizer.unk_token_regex.search(encoding.tokens[first.token_ix]) is not None: + # This is a special token that is in the text. probably UNK + css_classes.append("special-token") + # TODO is this the right name for the data attribute ? + data_items["stok"] = encoding.tokens[first.token_ix] + else: + # In this case we are looking at a group/single char that is not tokenized. + # e.g.
white space + css_classes.append("non-token") + css = f'''class="{" ".join(css_classes)}"''' + data = "" + for key, val in data_items.items(): + data += f' data-{key}="{val}"' + return f"<span {css}{data}>{span_text}</span>" + + @staticmethod + def __make_html(text: str, encoding: Encoding, annotations: AnnotationList) -> str: + char_states = EncodingVisualizer.__make_char_states(text, encoding, annotations) + current_consecutive_chars = [char_states[0]] + prev_anno_ix = char_states[0].anno_ix + spans = [] + label_colors_dict = EncodingVisualizer.calculate_label_colors(annotations) + cur_anno_ix = char_states[0].anno_ix + if cur_anno_ix is not None: + # If we started in an annotation make a span for it + anno = annotations[cur_anno_ix] + label = anno.label + color = label_colors_dict[label] + spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">') + + for cs in char_states[1:]: + cur_anno_ix = cs.anno_ix + if cur_anno_ix != prev_anno_ix: + # If we've transitioned in or out of an annotation + spans.append( + # Create a span from the current consecutive characters + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + current_consecutive_chars = [cs] + + if prev_anno_ix is not None: + # if we transitioned out of an annotation close its span + spans.append("</span>") + if cur_anno_ix is not None: + # If we entered a new annotation make a span for it + anno = annotations[cur_anno_ix] + label = anno.label + color = label_colors_dict[label] + spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">') + prev_anno_ix = cur_anno_ix + + if cs.partition_key() == current_consecutive_chars[0].partition_key(): + # If the current character is in the same "group" as the previous one + current_consecutive_chars.append(cs) + else: + # Otherwise we make a span for the previous group + spans.append( + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + # And reset the consecutive_char_list to form a new group + current_consecutive_chars = [cs] + # All that's left is to fill out the final span + # TODO I think there is an edge case here where an annotation's span might not close + spans.append( + EncodingVisualizer.consecutive_chars_to_html( + current_consecutive_chars, + text=text, + encoding=encoding, + ) + ) + res = HTMLBody(spans)  # Send the list of spans to the body of our html + return res + + @staticmethod + def __make_anno_map(text: str, annotations: AnnotationList) -> PartialIntList: + """ + Args: + text (:obj:`str`): + The raw text we want to align to + + annotations (:obj:`AnnotationList`): + A (possibly empty) list of annotations + + Returns: + A list of length len(text) whose entry at index i is None if there is no annotation on + character i or k, the index of the annotation that covers index i where k is with + respect to the list of annotations + """ + annotation_map = [None] * len(text) + for anno_ix, a in enumerate(annotations): + for i in range(a.start, a.end): + annotation_map[i] = anno_ix + return annotation_map + + @staticmethod + def __make_char_states(text: str, encoding: Encoding, annotations: AnnotationList) -> List[CharState]: + """ + For each character in the original text, we emit a tuple representing its "state": + + * which token_ix it corresponds to + * which word_ix it corresponds to + * which annotation_ix it corresponds to + + Args: + text (:obj:`str`): + The raw text we want to align to + + annotations (:obj:`List[Annotation]`): + A (possibly empty) list of annotations + + encoding: (:class:`~tokenizers.Encoding`): + The encoding returned from the
tokenizer + + Returns: + :obj:`List[CharState]`: A list of CharStates, indicating for each char in the text what + its state is + """ + annotation_map = EncodingVisualizer.__make_anno_map(text, annotations) + # Todo make this a dataclass or named tuple + char_states: List[CharState] = [CharState(char_ix) for char_ix in range(len(text))] + for token_ix, token in enumerate(encoding.tokens): + offsets = encoding.token_to_chars(token_ix) + if offsets is not None: + start, end = offsets + for i in range(start, end): + char_states[i].tokens.append(token_ix) + for char_ix, anno_ix in enumerate(annotation_map): + char_states[char_ix].anno_ix = anno_ix + + return char_states + + +def HTMLBody(children: List[str], css_styles=css) -> str: + """ + Generates the full html with css from a list of html spans + + Args: + children (:obj:`List[str]`): + A list of strings, assumed to be html elements + + css_styles (:obj:`str`, `optional`): + Optional alternative implementation of the css + + Returns: + :obj:`str`: An HTML string with style markup + """ + children_text = "".join(children) + return f""" + <html> + <head> + <style> + {css_styles} + </style> + </head> + <body>
+ <div class="tokenized-text" dir=auto> + {children_text} + </div>
+ </body> + </html> + """ diff --git a/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.py b/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22f94c50b7cf63f0b38231ab1ecec88141a678fd --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import trainers + +Trainer = trainers.Trainer +BpeTrainer = trainers.BpeTrainer +UnigramTrainer = trainers.UnigramTrainer +WordLevelTrainer = trainers.WordLevelTrainer +WordPieceTrainer = trainers.WordPieceTrainer diff --git a/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.pyi b/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..58174ab44f65e9e8c6b182cf519fee7c8588fcb9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/tokenizers/trainers/__init__.pyi @@ -0,0 +1,173 @@ +# Generated content DO NOT EDIT +class Trainer: + """ + Base class for all trainers + + This class is not supposed to be instantiated directly. Instead, any implementation of a + Trainer will return an instance of this class when instantiated. + """ + +class BpeTrainer(Trainer): + """ + Trainer capable of training a BPE model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a pair should have in order to be merged. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + + limit_alphabet (:obj:`int`, `optional`): + The maximum different characters to keep in the alphabet. + + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + + continuing_subword_prefix (:obj:`str`, `optional`): + A prefix to be used for every subword that is not a beginning-of-word. + + end_of_word_suffix (:obj:`str`, `optional`): + A suffix to be used for every subword that is an end-of-word. + + max_token_length (:obj:`int`, `optional`): + Prevents creating tokens longer than the specified size. + This can help avoid polluting your vocabulary with + highly repetitive tokens like `======` from Wikipedia. + + """ + def __init__( + self, + vocab_size=30000, + min_frequency=0, + show_progress=True, + special_tokens=[], + limit_alphabet=None, + initial_alphabet=[], + continuing_subword_prefix=None, + end_of_word_suffix=None, + max_token_length=None, + words={}, + ): + pass + +class UnigramTrainer(Trainer): + """ + Trainer capable of training a Unigram model + + Args: + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + + show_progress (:obj:`bool`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`): + A list of special tokens the model should know of. + + initial_alphabet (:obj:`List[str]`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept.
+ + shrinking_factor (:obj:`float`): + The shrinking factor used at each step of the training to prune the + vocabulary. + + unk_token (:obj:`str`): + The token used for out-of-vocabulary tokens. + + max_piece_length (:obj:`int`): + The maximum length of a given token. + + n_sub_iterations (:obj:`int`): + The number of iterations of the EM algorithm to perform before + pruning the vocabulary. + """ + def __init__( + self, + vocab_size=8000, + show_progress=True, + special_tokens=[], + initial_alphabet=[], + shrinking_factor=0.75, + unk_token=None, + max_piece_length=16, + n_sub_iterations=2, + ): + pass + +class WordLevelTrainer(Trainer): + """ + Trainer capable of training a WordLevel model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a word should have in order to be kept in the vocabulary. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`): + A list of special tokens the model should know of. + """ + def __init__(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[]): + pass + +class WordPieceTrainer(Trainer): + """ + Trainer capable of training a WordPiece model + + Args: + vocab_size (:obj:`int`, `optional`): + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency (:obj:`int`, `optional`): + The minimum frequency a pair should have in order to be merged. + + show_progress (:obj:`bool`, `optional`): + Whether to show progress bars while training. + + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + + limit_alphabet (:obj:`int`, `optional`): + The maximum different characters to keep in the alphabet. + + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + + continuing_subword_prefix (:obj:`str`, `optional`): + A prefix to be used for every subword that is not a beginning-of-word. + + end_of_word_suffix (:obj:`str`, `optional`): + A suffix to be used for every subword that is an end-of-word.
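To tie the trainer stubs above together, here is a minimal end-to-end sketch of training a BPE tokenizer with `BpeTrainer` (the tiny in-memory corpus is purely illustrative):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=1000,
    min_frequency=2,
    special_tokens=["[UNK]", "[CLS]", "[SEP]"],
)

# train_from_iterator accepts any iterator of str (or list of str)
corpus = ["a tiny in-memory corpus", "a tiny corpus is enough for a sketch"]
tokenizer.train_from_iterator(corpus, trainer=trainer)
print(tokenizer.get_vocab_size())
```

The other trainers (`UnigramTrainer`, `WordLevelTrainer`, `WordPieceTrainer`) are used the same way, each paired with the matching model class.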
+ """ + def __init__( + self, + vocab_size=30000, + min_frequency=0, + show_progress=True, + special_tokens=[], + limit_alphabet=None, + initial_alphabet=[], + continuing_subword_prefix="##", + end_of_word_suffix=None, + ): + pass diff --git a/venv/lib/python3.13/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc052fac8ffc71272848296a4ecccec3f8c54b78 Binary files /dev/null and b/venv/lib/python3.13/site-packages/tokenizers/trainers/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/transformers/models/__pycache__/__init__.cpython-313.pyc b/venv/lib/python3.13/site-packages/transformers/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faa2560f95fa6aa812a3885450d80a8c837e08b4 Binary files /dev/null and b/venv/lib/python3.13/site-packages/transformers/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/venv/lib/python3.13/site-packages/transformers/models/aimv2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/aimv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c013363996040eb995d7924b20e8f86d6b318a83 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/aimv2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_aimv2 import * + from .modeling_aimv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/aimv2/configuration_aimv2.py b/venv/lib/python3.13/site-packages/transformers/models/aimv2/configuration_aimv2.py new file mode 100644 index 0000000000000000000000000000000000000000..adab18c7447a803117cb0d89abdf38b4dc96167f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/aimv2/configuration_aimv2.py @@ -0,0 +1,284 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Aimv2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate an + AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2 + [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2816): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the Linear layers or not. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal initializer for initializing all weight matrices. + use_head (`bool`, *optional*, defaults to `True`): + Whether to use an attention pooling head or not. + is_native (`bool`, *optional*, defaults to `False`): + Whether to use a checkpoint trained for native image resolution or not.
+ Example: + + ```python + >>> from transformers import Aimv2VisionConfig, Aimv2VisionModel + + >>> # Initializing an Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration + >>> configuration = Aimv2VisionConfig() + + >>> # Initializing an Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration + >>> model = Aimv2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "aimv2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + qkv_bias: bool = False, + mlp_bias: bool = False, + hidden_act: str = "silu", + initializer_range: float = 0.02, + use_head: bool = True, + is_native: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + + self.use_head = use_head + self.initializer_range = initializer_range + self.mlp_bias = mlp_bias + self.qkv_bias = qkv_bias + self.rms_norm_eps = rms_norm_eps + self.is_native = is_native + + +class Aimv2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate an + AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2 + [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Aimv2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 6): + Number of attention heads for each attention layer in the Transformer encoder. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the Linear layers or not.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal initializer for initializing all weight matrices. + """ + + model_type = "aimv2_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size: int = 49408, + hidden_size: int = 768, + intermediate_size: int = 2048, + num_hidden_layers: int = 12, + num_attention_heads: int = 6, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + qkv_bias: bool = False, + mlp_bias: bool = False, + hidden_act: str = "silu", + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = None, + eos_token_id: int = 49407, + max_position_embeddings: int = 77, + initializer_range: float = 0.02, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + + self.initializer_range = initializer_range + self.mlp_bias = mlp_bias + self.qkv_bias = qkv_bias + self.rms_norm_eps = rms_norm_eps + + +class Aimv2Config(PretrainedConfig): + r""" + [`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to + instantiate an AIMv2 model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2 + [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Aimv2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Aimv2VisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. + kwargs (*optional*): + Dictionary of keyword arguments.
+ + Example: + + ```python + >>> from transformers import Aimv2Config, Aimv2Model + + >>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration + >>> configuration = Aimv2Config() + + >>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration + >>> model = Aimv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig + >>> from transformers import Aimv2TextConfig, Aimv2VisionConfig + + >>> # Initializing a AIMv2Text and AIMv2Vision configuration + >>> config_text = Aimv2TextConfig() + >>> config_vision = Aimv2VisionConfig() + + >>> config = Aimv2Config(text_config=config_text, vision_config=config_vision) + ```""" + + model_type = "aimv2" + sub_configs = {"text_config": Aimv2TextConfig, "vision_config": Aimv2VisionConfig} + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `Aimv2TextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `Aimv2VisionConfig` with default values.") + + self.text_config = Aimv2TextConfig(**text_config) + self.vision_config = Aimv2VisionConfig(**vision_config) + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.max_logit_scale = 100.0 + + +__all__ = ["Aimv2Config", "Aimv2VisionConfig", "Aimv2TextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/aimv2/modeling_aimv2.py b/venv/lib/python3.13/site-packages/transformers/models/aimv2/modeling_aimv2.py new file mode 100644 index 0000000000000000000000000000000000000000..c29270b5687db779ecffa3bc90fb9955b453ac05 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/aimv2/modeling_aimv2.py @@ -0,0 +1,744 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig + + +@dataclass +@auto_docstring +class Aimv2Output(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Aimv2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The image embeddings obtained by applying the projection layer to the pooled output of [`Aimv2VisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`Aimv2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Aimv2VisionModel`].
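For orientation, the similarity logits above are consumed exactly like CLIP's; a stand-alone sketch with stand-in tensors (in practice they come from a forward pass of `Aimv2Model`):

```python
import torch

# Stand-in values shaped (image_batch_size, text_batch_size).
logits_per_image = torch.randn(2, 3)

probs = logits_per_image.softmax(dim=-1)  # per image: a distribution over the 3 texts
best_text = probs.argmax(dim=-1)          # best-matching text index for each image
```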
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@use_kernel_forward_from_hub("RMSNorm") +class Aimv2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Aimv2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Aimv2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Aimv2VisionEmbeddings(nn.Module): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.patch_embed = nn.Conv2d( + config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size + ) + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + num_patches = (config.image_size // config.patch_size) ** 2 + if not self.config.is_native: + self.position_embedding = nn.Embedding(num_patches, config.hidden_size) + self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) + + @staticmethod + def build_2d_sincos_position_embedding( + height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ) -> torch.Tensor: + grid_w = torch.arange(int(width), dtype=dtype, device=device) + grid_h = torch.arange(int(height), dtype=dtype, device=device) + grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy") + + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim + omega = 1.0 / (temperature**omega) + + out_h = grid_h.flatten()[..., None] @ omega[None, :] + out_w = grid_w.flatten()[..., None] @ omega[None, :] + + return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :] + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + _, _, height, width = pixel_values.size() + hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2) + hidden_states = self.rms_norm(hidden_states) + + if self.config.is_native: + pos_embed = 
self.build_2d_sincos_position_embedding( + height // self.patch_size, + width // self.patch_size, + embed_dim=self.config.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + pos_embed = self.position_embedding(self.position_ids) + + hidden_states = hidden_states + pos_embed + return hidden_states + + +class Aimv2TextEmbeddings(nn.Module): + def __init__(self, config: Aimv2TextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Aimv2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Aimv2EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.attention = Aimv2Attention(config) + self.ffn = Aimv2MLP(config) + self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + norm_hidden_states = self.rms_norm1(hidden_states) + attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs) + + hidden_states = hidden_states + attn_output + norm_hidden_states = self.rms_norm2(hidden_states) + mlp_output = self.ffn(norm_hidden_states) + + hidden_states = hidden_states + mlp_output + return hidden_states + + +class Aimv2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Aimv2EncoderLayer`]. 
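For reference, `Aimv2EncoderLayer` above is a standard pre-norm residual block; with `Aimv2RMSNorm` as defined earlier, its forward pass can be written compactly as

$$
h' = h + \mathrm{Attn}\big(\mathrm{RMSNorm}(h)\big), \qquad
h'' = h' + \mathrm{MLP}\big(\mathrm{RMSNorm}(h')\big), \qquad
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}} \odot w .
$$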
+ + Args: + config: Aimv2Config + """ + + def __init__(self, config: Aimv2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class Aimv2AttentionPoolingHead(nn.Module): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, hidden_dim = hidden_states.shape + + cls_token = self.cls_token.expand(batch_size, -1, -1) + + key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads) + value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads) + query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads) + + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + attn_output = F.scaled_dot_product_attention(query, key, value) + + attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim) + attn_output = attn_output.mean(dim=1) + + output = self.output_proj(attn_output) + return output + + +@auto_docstring +class Aimv2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. The model is only intended for inference and doesn't support finetuning. + """ + + config: Aimv2Config + base_model_prefix = "aimv2" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Aimv2EncoderLayer", + "Aimv2AttentionPoolingHead", + "Aimv2VisionEmbeddings", + "Aimv2TextEmbeddings", + ] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + + def _init_weights(self, module): + super()._init_weights(module) + if hasattr(module, "logit_scale"): + if isinstance(module.logit_scale, nn.Parameter): + module.logit_scale.data.fill_(math.log(1 / 0.07)) + elif isinstance(module, Aimv2AttentionPoolingHead): + module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) + + +@auto_docstring( + custom_intro=""" + The Vision model from AIMv2 without any head or projection on top. 
+ """ +) +class Aimv2VisionModel(Aimv2PreTrainedModel): + config: Aimv2VisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Aimv2EncoderLayer, + "attentions": Aimv2Attention, + } + + def __init__(self, config: Aimv2VisionConfig): + super().__init__(config) + self.config = config + self.embeddings = Aimv2VisionEmbeddings(config) + self.encoder = Aimv2Encoder(config) + # The only change from SiglipVisionTransformer is, layernorm -> rms_norm. + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + self.use_head = config.use_head + if self.use_head: + self.head = Aimv2AttentionPoolingHead(config) + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embed + + @deprecate_kwarg("attention_mask", version="v4.58.0") + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native") + >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + hidden_states = self.embeddings(pixel_values) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.rms_norm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + ) + + +@auto_docstring( + custom_intro=""" + The text model from AIMv2 without any head or projection on top. 
+ """ +) +class Aimv2TextModel(Aimv2PreTrainedModel): + main_input_name = "input_ids" + + _can_record_outputs = { + "hidden_states": Aimv2EncoderLayer, + "attentions": Aimv2Attention, + } + + def __init__(self, config: Aimv2TextConfig): + super().__init__(config) + self.config = config + self.embeddings = Aimv2TextEmbeddings(config) + self.encoder = Aimv2Encoder(config) + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + self.eos_token_id = config.eos_token_id + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.embeddings.token_embedding = value + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_ids, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + hidden_states = self.embeddings(input_ids) + batch_size, seq_len, _ = hidden_states.shape + + cache_position = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device) + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + if attention_mask is not None: + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.rms_norm(last_hidden_state) + + # Get pooled output + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1), + ] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor: + """ + This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make + model `executorch` exportable. 
See issue https://github.com/pytorch/executorch/issues/3566 + """ + square_tensor = torch.pow(tensor, 2) + sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True) + normed_tensor = torch.pow(sum_tensor, 0.5) + return normed_tensor + + +@auto_docstring +class Aimv2Model(Aimv2PreTrainedModel): + config: Aimv2Config + _no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"] + _supports_flash_attn = True + + def __init__(self, config: Aimv2Config): + super().__init__(config) + + self.projection_dim = config.projection_dim + self.vision_embed_dim = config.vision_config.hidden_size + self.text_embed_dim = config.text_config.hidden_size + + self.vision_model = Aimv2VisionModel._from_config(config.vision_config) + self.text_model = Aimv2TextModel._from_config(config.text_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + self.max_log_logit_scale = math.log(config.max_logit_scale) + + self.post_init() + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by + applying the projection layer to the pooled output of [`Aimv2TextModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, Aimv2Model + + >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit") + >>> tokenizer = AutoTokenizer.from_pretrained("apple/aimv2-large-patch14-224-lit") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> with torch.inference_mode(): + ... text_features = model.get_text_features(**inputs) + ```""" + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + pooled_output = text_outputs.pooler_output + text_features = self.text_projection(pooled_output) + + return text_features + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by + applying the projection layer to the pooled output of [`Aimv2VisionModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, Aimv2Model + >>> from transformers.image_utils import load_image + + >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit") + >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.inference_mode(): + ...
image_features = model.get_image_features(**inputs) + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + pooled_output = vision_outputs.pooler_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Aimv2Output: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Aimv2Model + + >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit") + >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + **kwargs, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + image_embeds = vision_outputs.pooler_output + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / _get_vector_norm(image_embeds) + text_embeds = text_embeds / _get_vector_norm(text_embeds) + + logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device) + logits_per_text = (logit_scale * text_embeds) @ image_embeds.t() + logits_per_image = logits_per_text.t() + + return Aimv2Output( + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["Aimv2VisionModel", "Aimv2Model", "Aimv2PreTrainedModel", "Aimv2TextModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/aimv2/modular_aimv2.py b/venv/lib/python3.13/site-packages/transformers/models/aimv2/modular_aimv2.py new file mode 100644 index 0000000000000000000000000000000000000000..45a19d27d041649143dc970737502c5267080d05 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/aimv2/modular_aimv2.py @@ -0,0 +1,707 @@ +# coding=utf-8 +# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch implementation of AIMv2 Model""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm +from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm +from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoder, SiglipOutput + + +class Aimv2VisionConfig(SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a + AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2 + [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2816): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the Linear layers or not. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_head (`bool`, *optional*, defaults to `True`): + Whether to use the attention pooling head or not. + is_native (`bool`, *optional*, defaults to `False`): + Whether to use a checkpoint trained for native image resolution or not. + Example: + + ```python + >>> from transformers import Aimv2VisionConfig, Aimv2VisionModel + + >>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration + >>> configuration = Aimv2VisionConfig() + + >>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration + >>> model = Aimv2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + qkv_bias: bool = False, + mlp_bias: bool = False, + hidden_act: str = "silu", + initializer_range: float = 0.02, + use_head: bool = True, + is_native: bool = False, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_act=hidden_act, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + qkv_bias=qkv_bias, + **kwargs, + ) + + self.use_head = use_head + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.qkv_bias = qkv_bias + self.rms_norm_eps = rms_norm_eps + self.is_native = is_native + + del self.layer_norm_eps + + +class Aimv2TextConfig(SiglipTextConfig): + r""" + This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a + AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2 + [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Aimv2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 6): + Number of attention heads for each attention layer in the Transformer encoder. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the Linear layers or not. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + pad_token_id (`int`, *optional*, defaults to `None`): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to `None`): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + def __init__( + self, + vocab_size: int = 49408, + hidden_size: int = 768, + intermediate_size: int = 2048, + num_hidden_layers: int = 12, + num_attention_heads: int = 6, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + qkv_bias: bool = False, + mlp_bias: bool = False, + hidden_act: str = "silu", + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = None, + eos_token_id: int = 49407, + max_position_embeddings: int = 77, + initializer_range: float = 0.02, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.qkv_bias = qkv_bias + self.rms_norm_eps = rms_norm_eps + + del self.bos_token_id + del self.pad_token_id + del self.projection_size + del self.layer_norm_eps + + +class Aimv2Config(SiglipConfig): + r""" + [`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to + instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2 + [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Aimv2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Aimv2VisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. + kwargs (*optional*): + Dictionary of keyword arguments.
+ + Example: + + ```python + >>> from transformers import Aimv2Config, Aimv2Model + + >>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration + >>> configuration = Aimv2Config() + + >>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration + >>> model = Aimv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig + >>> from transformers import Aimv2TextConfig, Aimv2VisionConfig + + >>> # Initializing a AIMv2Text and AIMv2Vision configuration + >>> config_text = Aimv2TextConfig() + >>> config_vision = Aimv2VisionConfig() + + >>> config = Aimv2Config(text_config=config_text, vision_config=config_vision) + ```""" + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + super().__init__(text_config, vision_config, **kwargs) + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.max_logit_scale = 100.0 + + del self.initializer_factor + + +class Aimv2Output(SiglipOutput): + pass + + +class Aimv2RMSNorm(LlamaRMSNorm): + pass + + +class Aimv2MLP(LlamaMLP): + pass + + +class Aimv2VisionEmbeddings(nn.Module): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.patch_embed = nn.Conv2d( + config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size + ) + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + num_patches = (config.image_size // config.patch_size) ** 2 + if not self.config.is_native: + self.position_embedding = nn.Embedding(num_patches, config.hidden_size) + self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) + + @staticmethod + def build_2d_sincos_position_embedding( + height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ) -> torch.Tensor: + grid_w = torch.arange(int(width), dtype=dtype, device=device) + grid_h = torch.arange(int(height), dtype=dtype, device=device) + grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy") + + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim + omega = 1.0 / (temperature**omega) + + out_h = grid_h.flatten()[..., None] @ omega[None, :] + out_w = grid_w.flatten()[..., None] @ omega[None, :] + + return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :] + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + _, _, height, width = pixel_values.size() + hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2) + hidden_states = self.rms_norm(hidden_states) + + if self.config.is_native: + pos_embed = self.build_2d_sincos_position_embedding( + height // self.patch_size, + width // self.patch_size, + embed_dim=self.config.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + pos_embed = self.position_embedding(self.position_ids) + + hidden_states = hidden_states + pos_embed + return hidden_states + + +class Aimv2TextEmbeddings(CLIPTextEmbeddings): + pass + + +class Aimv2Attention(SiglipAttention): + def __init__(self, config): + super().__init__(config) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.v_proj = 
nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias) + + +class Aimv2EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.attention = Aimv2Attention(config) + self.ffn = Aimv2MLP(config) + self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + norm_hidden_states = self.rms_norm1(hidden_states) + attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs) + + hidden_states = hidden_states + attn_output + norm_hidden_states = self.rms_norm2(hidden_states) + mlp_output = self.ffn(norm_hidden_states) + + hidden_states = hidden_states + mlp_output + return hidden_states + + +class Aimv2Encoder(SiglipEncoder): + pass + + +class Aimv2AttentionPoolingHead(nn.Module): + def __init__(self, config: Aimv2VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, hidden_dim = hidden_states.shape + + cls_token = self.cls_token.expand(batch_size, -1, -1) + + key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads) + value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads) + query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads) + + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + attn_output = F.scaled_dot_product_attention(query, key, value) + + attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim) + attn_output = attn_output.mean(dim=1) + + output = self.output_proj(attn_output) + return output + + +@auto_docstring +class Aimv2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. The model is only intended for inference and doesn't support finetuning. 
+ """ + + config: Aimv2Config + base_model_prefix = "aimv2" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Aimv2EncoderLayer", + "Aimv2AttentionPoolingHead", + "Aimv2VisionEmbeddings", + "Aimv2TextEmbeddings", + ] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + + def _init_weights(self, module): + super()._init_weights(module) + if hasattr(module, "logit_scale"): + if isinstance(module.logit_scale, nn.Parameter): + module.logit_scale.data.fill_(math.log(1 / 0.07)) + elif isinstance(module, Aimv2AttentionPoolingHead): + module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) + + +@auto_docstring( + custom_intro=""" + The Vision model from AIMv2 without any head or projection on top. + """ +) +class Aimv2VisionModel(Aimv2PreTrainedModel): + config: Aimv2VisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Aimv2EncoderLayer, + "attentions": Aimv2Attention, + } + + def __init__(self, config: Aimv2VisionConfig): + super().__init__(config) + self.config = config + self.embeddings = Aimv2VisionEmbeddings(config) + self.encoder = Aimv2Encoder(config) + # The only change from SiglipVisionTransformer is, layernorm -> rms_norm. + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + self.use_head = config.use_head + if self.use_head: + self.head = Aimv2AttentionPoolingHead(config) + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embed + + @deprecate_kwarg("attention_mask", version="v4.58.0") + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native") + >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + hidden_states = self.embeddings(pixel_values) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.rms_norm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + ) + + +@auto_docstring( + custom_intro=""" + The text model from AIMv2 without any head or projection on top. 
+ """ +) +class Aimv2TextModel(Aimv2PreTrainedModel): + main_input_name = "input_ids" + + _can_record_outputs = { + "hidden_states": Aimv2EncoderLayer, + "attentions": Aimv2Attention, + } + + def __init__(self, config: Aimv2TextConfig): + super().__init__(config) + self.config = config + self.embeddings = Aimv2TextEmbeddings(config) + self.encoder = Aimv2Encoder(config) + self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps) + + self.eos_token_id = config.eos_token_id + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.embeddings.token_embedding = value + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_ids, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + hidden_states = self.embeddings(input_ids) + batch_size, seq_len, _ = hidden_states.shape + + cache_position = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device) + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + if attention_mask is not None: + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.rms_norm(last_hidden_state) + + # Get pooled output + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1), + ] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +@auto_docstring +class Aimv2Model(CLIPModel): + _supports_flash_attn = True + + def __init__(self, config: Aimv2Config): + PreTrainedModel.__init__(self, config) + + self.projection_dim = config.projection_dim + self.vision_embed_dim = config.vision_config.hidden_size + self.text_embed_dim = config.text_config.hidden_size + + self.vision_model = Aimv2VisionModel._from_config(config.vision_config) + self.text_model = Aimv2TextModel._from_config(config.text_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + self.max_log_logit_scale = math.log(config.max_logit_scale) + + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Aimv2Output: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Aimv2Model + + >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit") + >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> 
inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + **kwargs, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + image_embeds = vision_outputs.pooler_output + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / _get_vector_norm(image_embeds) + text_embeds = text_embeds / _get_vector_norm(text_embeds) + + logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device) + logits_per_text = (logit_scale * text_embeds) @ image_embeds.t() + logits_per_image = logits_per_text.t() + + return Aimv2Output( + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = [ + "Aimv2Config", + "Aimv2VisionConfig", + "Aimv2TextConfig", + "Aimv2VisionModel", + "Aimv2Model", + "Aimv2PreTrainedModel", + "Aimv2TextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/align/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa64dfb6064b10e820fca01e9e632aa7ecd68c6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/align/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_align import * + from .modeling_align import * + from .processing_align import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/align/configuration_align.py b/venv/lib/python3.13/site-packages/transformers/models/align/configuration_align.py new file mode 100644 index 0000000000000000000000000000000000000000..b924d85a6ca6b4a256eedf55221fe379aff5a87a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/align/configuration_align.py @@ -0,0 +1,331 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ALIGN model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class AlignTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AlignTextModel`]. It is used to instantiate a + ALIGN text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the ALIGN + [kakaobrain/align-base](https://huggingface.co/kakaobrain/align-base) architecture. The default values here are + copied from BERT. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Align Text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`AlignTextModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`AlignTextModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. 
For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + + Example: + + ```python + >>> from transformers import AlignTextConfig, AlignTextModel + + >>> # Initializing a AlignTextConfig with kakaobrain/align-base style configuration + >>> configuration = AlignTextConfig() + + >>> # Initializing a AlignTextModel (with random weights) from the kakaobrain/align-base style configuration + >>> model = AlignTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "align_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.pad_token_id = pad_token_id + + +class AlignVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AlignVisionModel`]. It is used to instantiate a + ALIGN vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the ALIGN + [kakaobrain/align-base](https://huggingface.co/kakaobrain/align-base) architecture. The default values are copied + from EfficientNet (efficientnet-b7). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 600): + The input image size. + width_coefficient (`float`, *optional*, defaults to 2.0): + Scaling coefficient for network width at each stage. + depth_coefficient (`float`, *optional*, defaults to 3.1): + Scaling coefficient for network depth at each stage. + depth_divisor (`int`, *optional*, defaults to 8): + A unit of network width. + kernel_sizes (`list[int]`, *optional*, defaults to `[3, 3, 5, 3, 5, 5, 3]`): + List of kernel sizes to be used in each block. + in_channels (`list[int]`, *optional*, defaults to `[32, 16, 24, 40, 80, 112, 192]`): + List of input channel sizes to be used in each block for convolutional layers. + out_channels (`list[int]`, *optional*, defaults to `[16, 24, 40, 80, 112, 192, 320]`): + List of output channel sizes to be used in each block for convolutional layers. + depthwise_padding (`list[int]`, *optional*, defaults to `[]`): + List of block indices with square padding. + strides (`list[int]`, *optional*, defaults to `[1, 2, 2, 2, 1, 2, 1]`): + List of stride sizes to be used in each block for convolutional layers. + num_block_repeats (`list[int]`, *optional*, defaults to `[1, 2, 2, 3, 3, 4, 1]`): + List of the number of times each block is to be repeated. + expand_ratios (`list[int]`, *optional*, defaults to `[1, 6, 6, 6, 6, 6, 6]`): + List of scaling coefficient of each block. + squeeze_expansion_ratio (`float`, *optional*, defaults to 0.25): + Squeeze expansion ratio. + hidden_act (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`, + `"selu"`, `"gelu_new"`, `"silu"` and `"mish"` are supported. + hidden_dim (`int`, *optional*, defaults to 2560): + The hidden dimension of the layer before the classification head. + pooling_type (`str` or `function`, *optional*, defaults to `"mean"`): + Type of final pooling to be applied before the dense classification head. Available options are [`"mean"`, + `"max"`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + batch_norm_eps (`float`, *optional*, defaults to 1e-3): + The epsilon used by the batch normalization layers. + batch_norm_momentum (`float`, *optional*, defaults to 0.99): + The momentum used by the batch normalization layers. + drop_connect_rate (`float`, *optional*, defaults to 0.2): + The drop rate for skip connections.
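+ + With the defaults (`width_coefficient=2.0`, `depth_divisor=8`), the stem's 32 base channels scale to 64; a sketch (not part of the original file) of the rounding rule the modeling code's `round_filters` applies: + + ```python + >>> width, divisor = 2.0, 8 + >>> channels = 32 * width + >>> new_dim = max(divisor, int(channels + divisor / 2) // divisor * divisor) + >>> int(new_dim + divisor) if new_dim < 0.9 * channels else int(new_dim) # round down by at most 10% + 64 + ```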
+ + Example: + + ```python + >>> from transformers import AlignVisionConfig, AlignVisionModel + + >>> # Initializing a AlignVisionConfig with kakaobrain/align-base style configuration + >>> configuration = AlignVisionConfig() + + >>> # Initializing a AlignVisionModel (with random weights) from the kakaobrain/align-base style configuration + >>> model = AlignVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "align_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + num_channels: int = 3, + image_size: int = 600, + width_coefficient: float = 2.0, + depth_coefficient: float = 3.1, + depth_divisor: int = 8, + kernel_sizes: list[int] = [3, 3, 5, 3, 5, 5, 3], + in_channels: list[int] = [32, 16, 24, 40, 80, 112, 192], + out_channels: list[int] = [16, 24, 40, 80, 112, 192, 320], + depthwise_padding: list[int] = [], + strides: list[int] = [1, 2, 2, 2, 1, 2, 1], + num_block_repeats: list[int] = [1, 2, 2, 3, 3, 4, 1], + expand_ratios: list[int] = [1, 6, 6, 6, 6, 6, 6], + squeeze_expansion_ratio: float = 0.25, + hidden_act: str = "swish", + hidden_dim: int = 2560, + pooling_type: str = "mean", + initializer_range: float = 0.02, + batch_norm_eps: float = 0.001, + batch_norm_momentum: float = 0.99, + drop_connect_rate: float = 0.2, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.image_size = image_size + self.width_coefficient = width_coefficient + self.depth_coefficient = depth_coefficient + self.depth_divisor = depth_divisor + self.kernel_sizes = kernel_sizes + self.in_channels = in_channels + self.out_channels = out_channels + self.depthwise_padding = depthwise_padding + self.strides = strides + self.num_block_repeats = num_block_repeats + self.expand_ratios = expand_ratios + self.squeeze_expansion_ratio = squeeze_expansion_ratio + self.hidden_act = hidden_act + self.hidden_dim = hidden_dim + self.pooling_type = pooling_type + self.initializer_range = initializer_range + self.batch_norm_eps = batch_norm_eps + self.batch_norm_momentum = batch_norm_momentum + self.drop_connect_rate = drop_connect_rate + self.num_hidden_layers = sum(num_block_repeats) * 4 + + +class AlignConfig(PretrainedConfig): + r""" + [`AlignConfig`] is the configuration class to store the configuration of a [`AlignModel`]. It is used to + instantiate a ALIGN model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the ALIGN + [kakaobrain/align-base](https://huggingface.co/kakaobrain/align-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`AlignTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`AlignVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 640): + Dimensionality of text and vision projection layers. + temperature_init_value (`float`, *optional*, defaults to 1.0): + The initial value of the *temperature* parameter. Default is used as per the original ALIGN implementation. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import AlignConfig, AlignModel + + >>> # Initializing a AlignConfig with kakaobrain/align-base style configuration + >>> configuration = AlignConfig() + + >>> # Initializing a AlignModel (with random weights) from the kakaobrain/align-base style configuration + >>> model = AlignModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a AlignConfig from a AlignTextConfig and a AlignVisionConfig + >>> from transformers import AlignTextConfig, AlignVisionConfig + + >>> # Initializing ALIGN Text and Vision configurations + >>> config_text = AlignTextConfig() + >>> config_vision = AlignVisionConfig() + + >>> config = AlignConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "align" + sub_configs = {"text_config": AlignTextConfig, "vision_config": AlignVisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + projection_dim=640, + temperature_init_value=1.0, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the AlignTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. Initializing the AlignVisionConfig with default values.") + + self.text_config = AlignTextConfig(**text_config) + self.vision_config = AlignVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.temperature_init_value = temperature_init_value + self.initializer_range = initializer_range + + +__all__ = ["AlignTextConfig", "AlignVisionConfig", "AlignConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/align/modeling_align.py b/venv/lib/python3.13/site-packages/transformers/models/align/modeling_align.py new file mode 100644 index 0000000000000000000000000000000000000000..c226a3b36ac6fb2d7f0621c68ffeb7acfe36c0b2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/align/modeling_align.py @@ -0,0 +1,1292 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch ALIGN model.""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndNoAttention, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging +from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +class AlignVisionModelOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +class AlignTextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring +class AlignOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The output of [`AlignVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`AlignTextModel`]. + vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`): + The output of the [`AlignVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device), label_smoothing=0.1) + + +def align_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision +def round_filters(config: AlignVisionConfig, num_channels: int): + r""" + Round number of filters based on depth multiplier. + """ + divisor = config.depth_divisor + num_channels *= config.width_coefficient + new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor) + + # Make sure that round down does not go down by more than 10%. + if new_dim < 0.9 * num_channels: + new_dim += divisor + + return int(new_dim) + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.correct_pad +def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True): + r""" + Utility function to get the tuple padding value for the depthwise convolution. + + Args: + kernel_size (`int` or `tuple`): + Kernel size of the convolution layers. + adjust (`bool`, *optional*, defaults to `True`): + Adjusts padding value to apply to right and bottom sides of the input. + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + if adjust: + return (correct[1] - 1, correct[1], correct[0] - 1, correct[0]) + else: + return (correct[1], correct[1], correct[0], correct[0]) + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetEmbeddings with EfficientNet->AlignVision +class AlignVisionEmbeddings(nn.Module): + r""" + A module that corresponds to the stem module of the original work. 
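+    It zero-pads the input on the right and bottom, then applies a stride-2 3x3 convolution followed by batch
+    normalization and the configured activation function.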
+ """ + + def __init__(self, config: AlignVisionConfig): + super().__init__() + + self.out_dim = round_filters(config, 32) + self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1)) + self.convolution = nn.Conv2d( + config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False + ) + self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + features = self.padding(pixel_values) + features = self.convolution(features) + features = self.batchnorm(features) + features = self.activation(features) + + return features + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseConv2d with EfficientNet->AlignVision +class AlignVisionDepthwiseConv2d(nn.Conv2d): + def __init__( + self, + in_channels, + depth_multiplier=1, + kernel_size=3, + stride=1, + padding=0, + dilation=1, + bias=True, + padding_mode="zeros", + ): + out_channels = in_channels * depth_multiplier + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode, + ) + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetExpansionLayer with EfficientNet->AlignVision +class AlignVisionExpansionLayer(nn.Module): + r""" + This corresponds to the expansion phase of each block in the original implementation. + """ + + def __init__(self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int): + super().__init__() + self.expand_conv = nn.Conv2d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + padding="same", + bias=False, + ) + self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps) + self.expand_act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + # Expand phase + hidden_states = self.expand_conv(hidden_states) + hidden_states = self.expand_bn(hidden_states) + hidden_states = self.expand_act(hidden_states) + + return hidden_states + + +# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseLayer with EfficientNet->AlignVision +class AlignVisionDepthwiseLayer(nn.Module): + r""" + This corresponds to the depthwise convolution phase of each block in the original implementation. 
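+    When `stride` is 2, the input is first zero-padded explicitly (see `correct_pad`) and the convolution uses
+    "valid" padding; stride-1 layers rely on "same" padding instead.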
+    """
+
+    def __init__(
+        self,
+        config: AlignVisionConfig,
+        in_dim: int,
+        stride: int,
+        kernel_size: int,
+        adjust_padding: bool,
+    ):
+        super().__init__()
+        self.stride = stride
+        conv_pad = "valid" if self.stride == 2 else "same"
+        padding = correct_pad(kernel_size, adjust=adjust_padding)
+
+        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
+        self.depthwise_conv = AlignVisionDepthwiseConv2d(
+            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
+        )
+        self.depthwise_norm = nn.BatchNorm2d(
+            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
+        )
+        self.depthwise_act = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+        # Depthwise convolution
+        if self.stride == 2:
+            hidden_states = self.depthwise_conv_pad(hidden_states)
+
+        hidden_states = self.depthwise_conv(hidden_states)
+        hidden_states = self.depthwise_norm(hidden_states)
+        hidden_states = self.depthwise_act(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetSqueezeExciteLayer with EfficientNet->AlignVision
+class AlignVisionSqueezeExciteLayer(nn.Module):
+    r"""
+    This corresponds to the Squeeze and Excitation phase of each block in the original implementation.
+    """
+
+    def __init__(self, config: AlignVisionConfig, in_dim: int, expand_dim: int, expand: bool = False):
+        super().__init__()
+        self.dim = expand_dim if expand else in_dim
+        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))
+
+        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
+        self.reduce = nn.Conv2d(
+            in_channels=self.dim,
+            out_channels=self.dim_se,
+            kernel_size=1,
+            padding="same",
+        )
+        self.expand = nn.Conv2d(
+            in_channels=self.dim_se,
+            out_channels=self.dim,
+            kernel_size=1,
+            padding="same",
+        )
+        self.act_reduce = ACT2FN[config.hidden_act]
+        self.act_expand = nn.Sigmoid()
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+        inputs = hidden_states
+        hidden_states = self.squeeze(hidden_states)
+        hidden_states = self.reduce(hidden_states)
+        hidden_states = self.act_reduce(hidden_states)
+
+        hidden_states = self.expand(hidden_states)
+        hidden_states = self.act_expand(hidden_states)
+        hidden_states = torch.mul(inputs, hidden_states)
+
+        return hidden_states
+
+
+class AlignVisionFinalBlockLayer(nn.Module):
+    r"""
+    This corresponds to the final phase of each block in the original implementation.
+    """
+
+    def __init__(
+        self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
+    ):
+        super().__init__()
+        self.apply_dropout = stride == 1 and not id_skip
+        self.project_conv = nn.Conv2d(
+            in_channels=in_dim,
+            out_channels=out_dim,
+            kernel_size=1,
+            padding="same",
+            bias=False,
+        )
+        self.project_bn = nn.BatchNorm2d(
+            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
+        )
+        self.dropout = nn.Dropout(p=drop_rate)
+
+    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
+        hidden_states = self.project_conv(hidden_states)
+        hidden_states = self.project_bn(hidden_states)
+
+        if self.apply_dropout:
+            hidden_states = self.dropout(hidden_states)
+            hidden_states = hidden_states + embeddings
+
+        return hidden_states
+
+
+class AlignVisionBlock(nn.Module):
+    r"""
+    This corresponds to the block module of the original EfficientNet vision encoder implementation.
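+    Each block chains an optional expansion layer, a depthwise convolution, squeeze-and-excitation, and a final
+    projection, with dropout and a residual connection applied whenever the block preserves the input shape
+    (stride 1 and not the first block of a stage).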
+ + Args: + config ([`AlignVisionConfig`]): + Model configuration class. + in_dim (`int`): + Number of input channels. + out_dim (`int`): + Number of output channels. + stride (`int`): + Stride size to be used in convolution layers. + expand_ratio (`int`): + Expand ratio to set the output dimensions for the expansion and squeeze-excite layers. + kernel_size (`int`): + Kernel size for the depthwise convolution layer. + drop_rate (`float`): + Dropout rate to be used in the final phase of each block. + id_skip (`bool`): + Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase + of each block. Set to `True` for the first block of each stage. + adjust_padding (`bool`): + Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution + operation, set to `True` for inputs with odd input sizes. + """ + + def __init__( + self, + config: AlignVisionConfig, + in_dim: int, + out_dim: int, + stride: int, + expand_ratio: int, + kernel_size: int, + drop_rate: float, + id_skip: bool, + adjust_padding: bool, + ): + super().__init__() + self.expand_ratio = expand_ratio + self.expand = self.expand_ratio != 1 + expand_in_dim = in_dim * expand_ratio + + if self.expand: + self.expansion = AlignVisionExpansionLayer( + config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride + ) + + self.depthwise_conv = AlignVisionDepthwiseLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + stride=stride, + kernel_size=kernel_size, + adjust_padding=adjust_padding, + ) + self.squeeze_excite = AlignVisionSqueezeExciteLayer( + config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand + ) + self.projection = AlignVisionFinalBlockLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + out_dim=out_dim, + stride=stride, + drop_rate=drop_rate, + id_skip=id_skip, + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + embeddings = hidden_states + # Expansion and depthwise convolution phase + if self.expand_ratio != 1: + hidden_states = self.expansion(hidden_states) + hidden_states = self.depthwise_conv(hidden_states) + + # Squeeze and excite phase + hidden_states = self.squeeze_excite(hidden_states) + hidden_states = self.projection(embeddings, hidden_states) + return hidden_states + + +class AlignVisionEncoder(nn.Module): + r""" + Forward propagates the embeddings through each vision encoder (EfficientNet) block. + + Args: + config ([`AlignVisionConfig`]): + Model configuration class. + """ + + def __init__(self, config: AlignVisionConfig): + super().__init__() + self.depth_coefficient = config.depth_coefficient + + def round_repeats(repeats): + # Round number of block repeats based on depth multiplier. 
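+            # Ceiling rounding ensures depth scaling never removes a block: e.g. with a
+            # hypothetical depth_coefficient of 1.2, 2 repeats become ceil(2.4) = 3.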
+ return int(math.ceil(self.depth_coefficient * repeats)) + + num_base_blocks = len(config.in_channels) + num_blocks = sum(round_repeats(n) for n in config.num_block_repeats) + + curr_block_num = 0 + blocks = [] + for i in range(num_base_blocks): + in_dim = round_filters(config, config.in_channels[i]) + out_dim = round_filters(config, config.out_channels[i]) + stride = config.strides[i] + kernel_size = config.kernel_sizes[i] + expand_ratio = config.expand_ratios[i] + + for j in range(round_repeats(config.num_block_repeats[i])): + id_skip = j == 0 + stride = 1 if j > 0 else stride + in_dim = out_dim if j > 0 else in_dim + adjust_padding = curr_block_num not in config.depthwise_padding + drop_rate = config.drop_connect_rate * curr_block_num / num_blocks + + block = AlignVisionBlock( + config=config, + in_dim=in_dim, + out_dim=out_dim, + stride=stride, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + drop_rate=drop_rate, + id_skip=id_skip, + adjust_padding=adjust_padding, + ) + blocks.append(block) + curr_block_num += 1 + + self.blocks = nn.ModuleList(blocks) + + def forward( + self, + hidden_states: torch.FloatTensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> BaseModelOutputWithPoolingAndNoAttention: + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for block in self.blocks: + hidden_states = block(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +class AlignTextEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model 
without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class AlignTextSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = 
attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->AlignText +class AlignTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class AlignTextAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = AlignTextSelfAttention(config) + self.output = AlignTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->AlignText +class AlignTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->AlignText +class AlignTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = 
nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class AlignTextLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = AlignTextAttention(config) + self.intermediate = AlignTextIntermediate(config) + self.output = AlignTextOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class AlignTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert -> AlignText +class AlignTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
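+        # For ALIGN's BERT-style text encoder the first token is the [CLS] token.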
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@auto_docstring
+class AlignPreTrainedModel(PreTrainedModel):
+    config: AlignConfig
+    base_model_prefix = "align"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights"""
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, AlignModel):
+            nn.init.xavier_uniform_(module.text_projection.weight)
+            module.text_projection.bias.data.zero_()
+            module.temperature.data.fill_(self.config.temperature_init_value)
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+@auto_docstring(
+    custom_intro="""
+    The text model from ALIGN without any head or projection on top.
+    """
+)
+class AlignTextModel(AlignPreTrainedModel):
+    config: AlignTextConfig
+    _no_split_modules = ["AlignTextEmbeddings"]
+
+    def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `True`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = AlignTextEmbeddings(config)
+        self.encoder = AlignTextEncoder(config)
+
+        self.pooler = AlignTextPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AlignTextModel
+
+        >>> model = AlignTextModel.from_pretrained("kakaobrain/align-base")
+        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
+
+        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled (CLS token) states
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The vision model from ALIGN without any head or projection on top. 
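+    It consists of the EfficientNet embedding stem and encoder blocks, followed by a mean or max pooling layer
+    selected via `config.pooling_type`.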
+    """
+)
+class AlignVisionModel(AlignPreTrainedModel):
+    config: AlignVisionConfig
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = False
+
+    def __init__(self, config: AlignVisionConfig):
+        super().__init__(config)
+        self.config = config
+        self.embeddings = AlignVisionEmbeddings(config)
+        self.encoder = AlignVisionEncoder(config)
+
+        # Final pooling layer
+        if config.pooling_type == "mean":
+            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
+        elif config.pooling_type == "max":
+            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
+        else:
+            raise ValueError(f"config.pooling_type must be one of ['mean', 'max'] got {config.pooling_type}")
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.convolution
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
+        r"""
+        Examples:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, AlignVisionModel
+
+        >>> model = AlignVisionModel.from_pretrained("kakaobrain/align-base")
+        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled states
+        ```"""
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values)
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+        # Apply pooling
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = self.pooler(last_hidden_state)
+        # Reshape (batch_size, projection_dim, 1, 1) -> (batch_size, projection_dim)
+        pooled_output = pooled_output.reshape(pooled_output.shape[:2])
+
+        return BaseModelOutputWithPoolingAndNoAttention(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@auto_docstring
+class AlignModel(AlignPreTrainedModel):
+    config: AlignConfig
+
+    def __init__(self, config: AlignConfig):
+        super().__init__(config)
+
+        if not isinstance(config.text_config, AlignTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type AlignTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, AlignVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type AlignVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+
+        self.text_model = AlignTextModel(text_config)
+        self.vision_model = AlignVisionModel(vision_config)
+
+        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim)
+        self.temperature = nn.Parameter(torch.tensor(self.config.temperature_init_value))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`AlignTextModel`].
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, AlignModel
+
+        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
+        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
+
+        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+        >>> with torch.inference_mode():
+        ...     text_features = model.get_text_features(**inputs)
+        ```"""
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+        )
+        last_hidden_state = text_outputs[0][:, 0, :]
+        text_features = self.text_projection(last_hidden_state)
+
+        return text_features
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
+        r"""
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+            applying the projection layer to the pooled output of [`AlignVisionModel`].
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoProcessor, AlignModel
+        >>> from transformers.image_utils import load_image
+
+        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
+        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = load_image(url)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+        >>> with torch.inference_mode():
+        ...     
image_features = model.get_image_features(**inputs) + ```""" + vision_outputs = self.vision_model(pixel_values=pixel_values) + image_features = vision_outputs.pooler_output + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, AlignOutput]: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AlignModel + >>> from transformers.image_utils import load_image + + >>> model = AlignModel.from_pretrained("kakaobrain/align-base") + >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor( + ... images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True + ... ) + + >>> with torch.inference_mode(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components. 
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        image_embeds = vision_outputs[1]
+        text_embeds = text_outputs[0][:, 0, :]
+        text_embeds = self.text_projection(text_embeds)
+
+        # normalized features
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / self.temperature
+        logits_per_image = logits_per_text.t()
+
+        loss = None
+        if return_loss:
+            loss = align_loss(logits_per_text)
+
+        return AlignOutput(
+            loss=loss,
+            logits_per_image=logits_per_image,
+            logits_per_text=logits_per_text,
+            text_embeds=text_embeds,
+            image_embeds=image_embeds,
+            text_model_output=text_outputs,
+            vision_model_output=vision_outputs,
+        )
+
+
+__all__ = ["AlignPreTrainedModel", "AlignTextModel", "AlignVisionModel", "AlignModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/align/processing_align.py b/venv/lib/python3.13/site-packages/transformers/models/align/processing_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbca27b2ff397f808d956f97d3c93369924140ae
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/align/processing_align.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for ALIGN
+"""
+
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+
+
+class AlignProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    _defaults = {
+        "text_kwargs": {
+            "padding": "max_length",
+            "max_length": 64,
+        },
+    }
+
+
+class AlignProcessor(ProcessorMixin):
+    r"""
+    Constructs an ALIGN processor which wraps [`EfficientNetImageProcessor`] and
+    [`BertTokenizer`]/[`BertTokenizerFast`] into a single processor that inherits both the image processor and
+    tokenizer functionalities. See the [`~AlignProcessor.__call__`] and [`~AlignProcessor.decode`] for more
+    information.
+    The preferred way of passing kwargs is as a dictionary per modality; see the usage example below.
+    ```python
+    from transformers import AlignProcessor
+    from PIL import Image
+    model_id = "kakaobrain/align-base"
+    processor = AlignProcessor.from_pretrained(model_id)
+
+    processor(
+        images=your_pil_image,
+        text=["What is that?"],
+        images_kwargs={"crop_size": {"height": 224, "width": 224}},
+        text_kwargs={"padding": "do_not_pad"},
+        common_kwargs={"return_tensors": "pt"},
+    )
+    ```
+
+    Args:
+        image_processor ([`EfficientNetImageProcessor`]):
+            The image processor is a required input.
+        tokenizer ([`BertTokenizer`, `BertTokenizerFast`]):
+            The tokenizer is a required input.
+
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "EfficientNetImageProcessor"
+    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+    valid_processor_kwargs = AlignProcessorKwargs
+
+    def __init__(self, image_processor, tokenizer):
+        super().__init__(image_processor, tokenizer)
+
+
+__all__ = ["AlignProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/altclip/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/altclip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30de8a2527567b521be3e9b60e99d025571cb18
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/altclip/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_altclip import *
+    from .modeling_altclip import *
+    from .processing_altclip import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/altclip/configuration_altclip.py b/venv/lib/python3.13/site-packages/transformers/models/altclip/configuration_altclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..474fc48081b5846695bb98cf89610bffb339584f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/altclip/configuration_altclip.py
@@ -0,0 +1,372 @@
+# coding=utf-8
+# Copyright 2022 WenXiang ZhongzhiCheng LedellWu LiuGuang BoWenZhang and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AltCLIP model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class AltCLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AltCLIPTextModel`]. It is used to instantiate a + AltCLIP text model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the AltCLIP + [BAAI/AltCLIP](https://huggingface.co/BAAI/AltCLIP) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 250002): + Vocabulary size of the AltCLIP model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AltCLIPTextModel`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 514): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 1): + The vocabulary size of the `token_type_ids` passed when calling [`AltCLIPTextModel`] + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.02): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 1): The id of the *padding* token. + bos_token_id (`int`, *optional*, defaults to 0): The id of the *beginning-of-sequence* token. + eos_token_id (`Union[int, list[int]]`, *optional*, defaults to 2): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + project_dim (`int`, *optional*, defaults to 768): + The dimensions of the teacher model before the mapping layer. + + Examples: + + ```python + >>> from transformers import AltCLIPTextModel, AltCLIPTextConfig + + >>> # Initializing a AltCLIPTextConfig with BAAI/AltCLIP style configuration + >>> configuration = AltCLIPTextConfig() + + >>> # Initializing a AltCLIPTextModel (with random weights) from the BAAI/AltCLIP style configuration + >>> model = AltCLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "altclip_text_model" + + def __init__( + self, + vocab_size=250002, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=514, + type_vocab_size=1, + initializer_range=0.02, + initializer_factor=0.02, + layer_norm_eps=1e-05, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + project_dim=768, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.project_dim = project_dim + + +class AltCLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AltCLIPModel`]. It is used to instantiate an + AltCLIP model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the AltCLIP + [BAAI/AltCLIP](https://huggingface.co/BAAI/AltCLIP) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. 
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 32):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+
+    Example:
+
+    ```python
+    >>> from transformers import AltCLIPVisionConfig, AltCLIPVisionModel
+
+    >>> # Initializing an AltCLIPVisionConfig with BAAI/AltCLIP style configuration
+    >>> configuration = AltCLIPVisionConfig()
+
+    >>> # Initializing an AltCLIPVisionModel (with random weights) from the BAAI/AltCLIP style configuration
+    >>> model = AltCLIPVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "altclip_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        projection_dim=512,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        num_channels=3,
+        image_size=224,
+        patch_size=32,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+
+class AltCLIPConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`AltCLIPModel`]. It is used to instantiate an
+    AltCLIP model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the AltCLIP
+    [BAAI/AltCLIP](https://huggingface.co/BAAI/AltCLIP) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`AltCLIPTextConfig`].
+ vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`AltCLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 768): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import AltCLIPConfig, AltCLIPModel + + >>> # Initializing a AltCLIPConfig with BAAI/AltCLIP style configuration + >>> configuration = AltCLIPConfig() + + >>> # Initializing a AltCLIPModel (with random weights) from the BAAI/AltCLIP style configuration + >>> model = AltCLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a AltCLIPConfig from a AltCLIPTextConfig and a AltCLIPVisionConfig + + >>> # Initializing a AltCLIPText and AltCLIPVision configuration + >>> config_text = AltCLIPTextConfig() + >>> config_vision = AltCLIPVisionConfig() + + >>> config = AltCLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "altclip" + sub_configs = {"text_config": AltCLIPTextConfig, "vision_config": AltCLIPVisionConfig} + + def __init__( + self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = AltCLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key != "transformers_version": + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `AltCLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. 
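+            # As in the text branch above, values from `vision_config_dict` take precedence over
+            # duplicate keys in `vision_config`.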
+            _vision_config_dict = AltCLIPVisionConfig(**vision_config_dict).to_dict()
+            # convert keys to string instead of integer
+            if "id2label" in _vision_config_dict:
+                _vision_config_dict["id2label"] = {
+                    str(key): value for key, value in _vision_config_dict["id2label"].items()
+                }
+
+            # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but are different.
+            for key, value in _vision_config_dict.items():
+                if key in vision_config and value != vision_config[key] and key != "transformers_version":
+                    # If specified in `vision_config_dict`
+                    if key in vision_config_dict:
+                        message = (
+                            f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+                            f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+                        )
+                    # If inferred from default argument values (just to be super careful)
+                    else:
+                        message = (
+                            f"`vision_config_dict` is provided which will be used to initialize `AltCLIPVisionConfig`. "
+                            f'The value `vision_config["{key}"]` will be overridden.'
+                        )
+                    logger.info(message)
+
+            # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+            vision_config.update(_vision_config_dict)
+
+        if text_config is None:
+            text_config = {}
+            logger.info("`text_config` is `None`. Initializing the `AltCLIPTextConfig` with default values.")
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info("`vision_config` is `None`. Initializing the `AltCLIPVisionConfig` with default values.")
+
+        self.text_config = AltCLIPTextConfig(**text_config)
+        self.vision_config = AltCLIPVisionConfig(**vision_config)
+
+        self.projection_dim = projection_dim
+        self.logit_scale_init_value = logit_scale_init_value
+        self.initializer_factor = 1.0
+
+
+__all__ = ["AltCLIPTextConfig", "AltCLIPVisionConfig", "AltCLIPConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/altclip/modeling_altclip.py b/venv/lib/python3.13/site-packages/transformers/models/altclip/modeling_altclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..61468141c570b04baaf4d5b3412ace0a4e25b485
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/altclip/modeling_altclip.py
@@ -0,0 +1,1388 @@
+# coding=utf-8
+# Copyright 2022 The BAAI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
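As a quick illustration of how the two configuration classes above compose into `AltCLIPConfig`, here is a minimal sketch (the values are illustrative; `to_dict()` is inherited from `PretrainedConfig`):

```python
from transformers import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig

# Build the composite config from explicit sub-configs (all values illustrative).
text_config = AltCLIPTextConfig()
vision_config = AltCLIPVisionConfig(image_size=224, patch_size=32)
config = AltCLIPConfig(
    text_config=text_config.to_dict(),
    vision_config=vision_config.to_dict(),
    projection_dim=768,
)
assert config.vision_config.patch_size == 32
```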
+"""PyTorch AltCLIP model.""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, + BaseModelOutputWithPoolingAndProjection, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int +from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig + + +logger = logging.get_logger(__name__) + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->AltCLIP +class AltCLIPOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`AltCLIPTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`AltCLIPVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->AltRoberta +class AltRobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
+    """
+
+    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+        # End copy
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        # Default `token_type_ids` to the all-zeros buffer registered in the constructor. This usually happens
+        # when token_type_ids are auto-generated; the registered buffer lets users trace the model without
+        # passing token_type_ids, and solves issue #5664.
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +class AltRobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
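+        # Illustration: attention_scores[b, h, i, j] = <q_i, k_j>. With a `relative_key`
+        # embedding type, the branch below also builds `distance = position_ids_l - position_ids_r`
+        # (for a toy 4-token case, the 4x4 matrix of i - j values in [-3, 3]) and shifts it by
+        # max_position_embeddings - 1 so it indexes the (2 * max_position_embeddings - 1)-row
+        # distance embedding table.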
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the AltRobertaModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput +class AltRobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ALT_ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": AltRobertaSelfAttention, +} + + +class AltRobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = AltRobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->AltRoberta +class AltRobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from 
transformers.models.roberta.modeling_roberta.RobertaOutput +class AltRobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta +class AltRobertaLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = AltRobertaAttention(config) + self.intermediate = AltRobertaIntermediate(config) + self.output = AltRobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta +class AltRobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, 
+ attentions=all_self_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler +class AltRobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class AltCLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask + else: + self.is_causal = causal_attention_mask is not None + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->AltCLIP +class AltCLIPMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class AltCLIPEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: AltCLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = AltCLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = AltCLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + 
attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class AltCLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is an
+    [`AltCLIPEncoderLayer`].
+
+    Args:
+        config: AltCLIPConfig
+    """
+
+    def __init__(self, config: AltCLIPConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    @can_return_tuple
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->AltCLIP +class AltCLIPVisionEmbeddings(nn.Module): + def __init__(self, config: AltCLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." + ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +@auto_docstring +class AltCLIPPreTrainedModel(PreTrainedModel): + config: AltCLIPConfig + base_model_prefix = "altclip" + supports_gradient_checkpointing = True + _no_split_module = [] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, AltCLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, AltCLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, AltCLIPMLP): + factor = self.config.initializer_factor + in_proj_std = 
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, AltCLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + module.text_projection._is_hf_initialized = True + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + module.visual_projection._is_hf_initialized = True + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class AltCLIPVisionTransformer(nn.Module): + def __init__(self, config: AltCLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = AltCLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = AltCLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> Union[tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class AltCLIPVisionModel(AltCLIPPreTrainedModel): + config: AltCLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: AltCLIPVisionConfig): + super().__init__(config) + self.vision_model = AltCLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AltCLIPVisionModel + + >>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + +@auto_docstring( + custom_intro=""" + The model behaves as an encoder following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762 + """ +) +class AltRobertaModel(AltCLIPPreTrainedModel): + config: AltCLIPTextConfig + + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = AltRobertaEmbeddings(config) + self.encoder = AltRobertaEncoder(config) + + self.pooler = AltRobertaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
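+        # For example, a 2D mask [[1, 1, 1, 0]] becomes a (batch, 1, 1, seq_len) additive
+        # float mask whose padded entries hold torch.finfo(dtype).min, so the subsequent
+        # softmax assigns them ~0 probability.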
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class AltCLIPTextModel(AltCLIPPreTrainedModel): + config: AltCLIPTextConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = AltRobertaModel(config, add_pooling_layer=False) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.roberta.embeddings.word_embeddings + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.roberta.embeddings.word_embeddings = value + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: + return super().resize_token_embeddings(new_num_tokens) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]: + r""" + Examples: + + ```python + >>> from transformers import AutoProcessor, AltCLIPTextModel + + >>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") + + >>> texts = ["it's a cat", "it's a dog"] + + >>> inputs = processor(text=texts, padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + # last module outputs + sequence_output = outputs[0] + + # project every module + sequence_output = self.pre_LN(sequence_output) + + # pooler + projection_state = self.transformation(sequence_output) + pooler_output = projection_state[:, 0] + + return BaseModelOutputWithPoolingAndProjection( + last_hidden_state=projection_state, + pooler_output=pooler_output, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AltCLIPModel(AltCLIPPreTrainedModel): + config: AltCLIPConfig + + def __init__(self, config: AltCLIPConfig): + super().__init__(config) + + if not isinstance(config.vision_config, AltCLIPVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type AltCLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + if not isinstance(config.text_config, AltCLIPTextConfig): + raise TypeError( + "config.text_config is expected to be of type AltCLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.project_dim + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = AltCLIPTextModel(text_config) + self.vision_model = AltCLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`AltCLIPTextModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AltCLIPModel + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> with torch.inference_mode(): + ... text_features = model.get_text_features(**inputs) + ```""" + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + pooled_output = text_outputs.pooler_output + text_features = self.text_projection(pooled_output) + + return text_features + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`AltCLIPVisionModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AltCLIPModel + >>> from transformers.image_utils import load_image + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.inference_mode(): + ... 
image_features = model.get_image_features(**inputs) + ```""" + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + pooled_output = vision_outputs.pooler_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, AltCLIPOutput]: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AltCLIPModel + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components. 
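+        # Sketch of the math below: both pooled outputs are projected and L2-normalized,
+        # logits_per_text = (text_embeds @ image_embeds.t()) * logit_scale.exp(), and with
+        # return_loss=True the symmetric cross-entropy `clip_loss` is taken over those logits.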
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        image_embeds = vision_outputs[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        text_embeds = text_outputs[1]
+        text_embeds = self.text_projection(text_embeds)
+
+        # normalized features
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+        logits_per_image = logits_per_text.T
+
+        loss = None
+        if return_loss:
+            loss = clip_loss(logits_per_text)
+
+        if not return_dict:
+            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+            return ((loss,) + output) if loss is not None else output
+
+        return AltCLIPOutput(
+            loss=loss,
+            logits_per_image=logits_per_image,
+            logits_per_text=logits_per_text,
+            text_embeds=text_embeds,
+            image_embeds=image_embeds,
+            text_model_output=text_outputs,
+            vision_model_output=vision_outputs,
+        )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids (`torch.Tensor`):
+            Indices of input sequence tokens.
+        padding_idx (`int`):
+            The id of the padding token.
+        past_key_values_length (`int`, *optional*, defaults to 0):
+            Length of the already-processed prefix, used to offset the new position ids.
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
+
+
+__all__ = ["AltCLIPPreTrainedModel", "AltCLIPVisionModel", "AltCLIPTextModel", "AltCLIPModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/altclip/processing_altclip.py b/venv/lib/python3.13/site-packages/transformers/models/altclip/processing_altclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..24631ecacbd72782ea5592159ddc4aae22094276
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/altclip/processing_altclip.py
@@ -0,0 +1,47 @@
+# coding=utf-8
+# Copyright 2022 WenXiang ZhongzhiCheng LedellWu LiuGuang BoWenZhang The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for AltCLIP
+"""
+
+from ...processing_utils import ProcessorMixin
+from ...utils.deprecation import deprecate_kwarg
+
+
+class AltCLIPProcessor(ProcessorMixin):
+    r"""
+    Constructs an AltCLIP processor which wraps a CLIP image processor and an XLM-Roberta tokenizer into a single
+    processor.
+
+    [`AltCLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`XLMRobertaTokenizerFast`]. See
+    the [`~AltCLIPProcessor.__call__`] and [`~AltCLIPProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`CLIPImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`XLMRobertaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
+    tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
+
+    @deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
+    def __init__(self, image_processor=None, tokenizer=None):
+        super().__init__(image_processor, tokenizer)
+
+
+__all__ = ["AltCLIPProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..618fceef70d3891f4313958cdadede27d605f12c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
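For orientation, the lazy-import scaffolding that follows defers the actual submodule imports until first attribute access. A small consumer-side sketch (assuming a standard `transformers` install with PyTorch available):

```python
from transformers import ASTConfig, ASTModel  # resolved lazily via _LazyModule

config = ASTConfig()      # defaults match the configuration file below
model = ASTModel(config)  # randomly initialized weights
```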
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_audio_spectrogram_transformer import *
+    from .feature_extraction_audio_spectrogram_transformer import *
+    from .modeling_audio_spectrogram_transformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecd8f4858fe7add5841fc2e4cdf4ed00bc480a92
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Audio Spectrogram Transformer (AST) model configuration"""
+
+from typing import Any
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ASTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`ASTModel`]. It is used to instantiate an AST
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the AST
+    [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        frequency_stride (`int`, *optional*, defaults to 10):
+            Frequency stride to use when patchifying the spectrograms.
+        time_stride (`int`, *optional*, defaults to 10):
+            Temporal stride to use when patchifying the spectrograms.
+        max_length (`int`, *optional*, defaults to 1024):
+            Temporal dimension of the spectrograms.
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Frequency dimension of the spectrograms (number of Mel-frequency bins).
+
+    Example:
+
+    ```python
+    >>> from transformers import ASTConfig, ASTModel
+
+    >>> # Initializing an AST MIT/ast-finetuned-audioset-10-10-0.4593 style configuration
+    >>> configuration = ASTConfig()
+
+    >>> # Initializing a model (with random weights) from the MIT/ast-finetuned-audioset-10-10-0.4593 style configuration
+    >>> model = ASTModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "audio-spectrogram-transformer"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        patch_size=16,
+        qkv_bias=True,
+        frequency_stride=10,
+        time_stride=10,
+        max_length=1024,
+        num_mel_bins=128,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.patch_size = patch_size
+        self.qkv_bias = qkv_bias
+        self.frequency_stride = frequency_stride
+        self.time_stride = time_stride
+        self.max_length = max_length
+        self.num_mel_bins = num_mel_bins
+
+    # Overwritten from the parent class: AST is not compatible with `generate`, but has a config parameter sharing the
+    # same name (`max_length`). Sharing the same name triggers checks regarding the config -> generation_config
+    # generative parameters deprecation cycle, overwriting this function prevents this from happening.
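+    # For illustration: `ASTConfig(max_length=1024)` sets the temporal size of the spectrogram, not a generation
+    # limit, so reporting an empty dict below amounts to declaring "this config defines no generation parameters".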
+    def _get_non_default_generation_parameters(self) -> dict[str, Any]:
+        return {}
+
+
+__all__ = ["ASTConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f56c0c3213b7e4a920b271df612e7aaad1934a9f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
@@ -0,0 +1,239 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for Audio Spectrogram Transformer.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...audio_utils import mel_filter_bank, spectrogram, window_function
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import TensorType, is_speech_available, is_torch_available, logging
+
+
+if is_speech_available():
+    import torchaudio.compliance.kaldi as ta_kaldi
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class ASTFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs an Audio Spectrogram Transformer (AST) feature extractor.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
+    otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 1):
+            The feature dimension of the extracted features.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of Mel-frequency bins.
+        max_length (`int`, *optional*, defaults to 1024):
+            Maximum length to which to pad/truncate the extracted features.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the log-Mel features using `mean` and `std`.
+        mean (`float`, *optional*, defaults to -4.2677393):
+            The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
+        std (`float`, *optional*, defaults to 4.5689974):
+            The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation
+            by default.
+        return_attention_mask (`bool`, *optional*, defaults to `False`):
+            Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
+ """ + + model_input_names = ["input_values", "attention_mask"] + + def __init__( + self, + feature_size=1, + sampling_rate=16000, + num_mel_bins=128, + max_length=1024, + padding_value=0.0, + do_normalize=True, + mean=-4.2677393, + std=4.5689974, + return_attention_mask=False, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.num_mel_bins = num_mel_bins + self.max_length = max_length + self.do_normalize = do_normalize + self.mean = mean + self.std = std + self.return_attention_mask = return_attention_mask + + if not is_speech_available(): + mel_filters = mel_filter_bank( + num_frequency_bins=257, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = mel_filters + self.window = window_function(400, "hann", periodic=False) + + def _extract_fbank_features( + self, + waveform: np.ndarray, + max_length: int, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + # waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers + if is_speech_available(): + waveform = torch.from_numpy(waveform).unsqueeze(0) + fbank = ta_kaldi.fbank( + waveform, + sample_frequency=self.sampling_rate, + window_type="hanning", + num_mel_bins=self.num_mel_bins, + ) + else: + waveform = np.squeeze(waveform) + fbank = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + + fbank = torch.from_numpy(fbank) + + n_frames = fbank.shape[0] + difference = max_length - n_frames + + # pad or truncate, depending on difference + if difference > 0: + pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference)) + fbank = pad_module(fbank) + elif difference < 0: + fbank = fbank[0:max_length, :] + + fbank = fbank.numpy() + + return fbank + + def normalize(self, input_values: np.ndarray) -> np.ndarray: + return (input_values - (self.mean)) / (self.std * 2) + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + sampling_rate: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. 
+ """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features and pad/truncate to max_length + features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech] + + # convert into BatchFeature + padded_inputs = BatchFeature({"input_values": features}) + + # make sure list is in array format + input_values = padded_inputs.get("input_values") + if isinstance(input_values[0], list): + padded_inputs["input_values"] = [np.asarray(feature, dtype=np.float32) for feature in input_values] + + # normalization + if self.do_normalize: + padded_inputs["input_values"] = [self.normalize(feature) for feature in input_values] + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["ASTFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbcaec5b4d18dc69ce58e575c7ed5895cff43e5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -0,0 +1,481 @@ +# coding=utf-8 +# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Audio Spectrogram Transformer (AST) model.""" + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import can_return_tuple, check_model_inputs +from .configuration_audio_spectrogram_transformer import ASTConfig + + +logger = logging.get_logger(__name__) + + +class ASTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + """ + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = ASTPatchEmbeddings(config) + + frequency_out_dimension, time_out_dimension = self.get_shape(config) + num_patches = frequency_out_dimension * time_out_dimension + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def get_shape(self, config): + # see Karpathy's cs231n blog on how to calculate the output dimensions + # https://cs231n.github.io/convolutional-networks/#conv + frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1 + time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1 + + return frequency_out_dimension, time_out_dimension + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + batch_size = input_values.shape[0] + embeddings = self.patch_embeddings(input_values) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + distillation_tokens = self.distillation_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +class ASTPatchEmbeddings(nn.Module): + """ + This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, + seq_length, hidden_size)` to be consumed by a Transformer. + """ + + def __init__(self, config: ASTConfig): + super().__init__() + + patch_size = config.patch_size + frequency_stride = config.frequency_stride + time_stride = config.time_stride + + self.projection = nn.Conv2d( + 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride) + ) + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + input_values = input_values.unsqueeze(1) + input_values = input_values.transpose(2, 3) + embeddings = self.projection(input_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST +class ASTSelfAttention(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def forward( + self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size + + key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) + query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST +class ASTSelfOutput(nn.Module): + """ + The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
+ """ + + def __init__(self, config: ASTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST +class ASTAttention(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.attention = ASTSelfAttention(config) + self.output = ASTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: set[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + self_attn_output, _ = self.attention(hidden_states, head_mask) + output = self.output(self_attn_output, hidden_states) + return output + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST +class ASTIntermediate(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST +class ASTOutput(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST +class ASTLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ASTConfig): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ASTAttention(config) + self.intermediate = ASTIntermediate(config) + self.output = ASTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, 
eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states_norm = self.layernorm_before(hidden_states) + attention_output = self.attention(hidden_states_norm, head_mask) + + # first residual connection + hidden_states = attention_output + hidden_states + + # in AST, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + return layer_output + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST +class ASTEncoder(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput: + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module(hidden_states, layer_head_mask) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +@auto_docstring +class ASTPreTrainedModel(PreTrainedModel): + config: ASTConfig + base_model_prefix = "audio_spectrogram_transformer" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": ASTLayer, + "attentions": ASTSelfAttention, + } + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ASTEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + module.distillation_token.data.zero_() + + +@auto_docstring +class ASTModel(ASTPreTrainedModel): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + self.config = config + + self.embeddings = ASTEmbeddings(config) + self.encoder = ASTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ASTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`): + Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a + `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library + (`pip install soundfile`). + To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the + mel features, padding and conversion into a tensor of type `torch.FloatTensor`. + See [`~ASTFeatureExtractor.__call__`] + """ + + if input_values is None: + raise ValueError("You have to specify input_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(input_values) + + encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask) + sequence_output = encoder_outputs.last_hidden_state + sequence_output = self.layernorm(sequence_output) + + pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2 + + return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output) + + +class ASTMLPHead(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.layernorm(hidden_state) + hidden_state = self.dense(hidden_state) + return hidden_state + + +@auto_docstring( + custom_intro=""" + Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled + output) e.g. for datasets like AudioSet, Speech Commands v2. + """ +) +class ASTForAudioClassification(ASTPreTrainedModel): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.audio_spectrogram_transformer = ASTModel(config) + + # Classifier head + self.classifier = ASTMLPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`): + Float values mel features extracted from the raw audio waveform. 
Raw audio waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via + the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the + mel features, padding and conversion into a tensor of type `torch.FloatTensor`. + See [`~ASTFeatureExtractor.__call__`] + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the audio classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs: BaseModelOutputWithPooling = self.audio_spectrogram_transformer( + input_values, head_mask=head_mask, **kwargs + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config, **kwargs) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["ASTForAudioClassification", "ASTModel", "ASTPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4c713f4698d7c901ed06738efc8983e317805c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bart import * + from .modeling_bart import * + from .modeling_flax_bart import * + from .modeling_tf_bart import * + from .tokenization_bart import * + from .tokenization_bart_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/configuration_bart.py b/venv/lib/python3.13/site-packages/transformers/models/bart/configuration_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..90781feab3b51e6fc30198794c47254f16222c60 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/configuration_bart.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BART model configuration"""
+
+import warnings
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
+from ...utils import TensorType, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the BART
+    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`BartModel`] or [`TFBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by multiplying them by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        num_labels (`int`, *optional*, defaults to 3):
+            The number of labels to use in [`BartForSequenceClassification`].
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import BartConfig, BartModel
+
+    >>> # Initializing a BART facebook/bart-large style configuration
+    >>> configuration = BartConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/bart-large style configuration
+    >>> model = BartModel(configuration)

+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "bart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        use_cache=True,
+        num_labels=3,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        is_encoder_decoder=True,
+        decoder_start_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+        super().__init__(
+            num_labels=num_labels,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+
eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " + "The config can simply be saved and uploaded again to be fixed." + ) + + +class BartOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: figure this case out. + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy 
past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + mask_dtype = common_inputs["attention_mask"].dtype + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. 
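+        # Rough walk-through (assuming the defaults used by the helpers below): batch_size=-1 and seq_length=-1
+        # are resolved to small fixed sizes by `compute_effective_axis_dimension`, and the dummy batch is built by
+        # repeating the tokenizer's `unk_token` so the resulting tensors match the exported input signature.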
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + else: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t + ) + + +__all__ = ["BartConfig", "BartOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_bart.py b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1f2451cff12042ff4de266523a9a27ece5fed7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_bart.py @@ -0,0 +1,1949 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
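+# A minimal usage sketch for the modeling code below (illustrative; `BartForConditionalGeneration` is assumed here
+# to be the generation head defined later in this file, and the checkpoint is the one cited in the config docstring):
+#
+#     from transformers import BartForConditionalGeneration, BartTokenizer
+#
+#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+#     model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+#     batch = tokenizer("Hello world", return_tensors="pt")
+#     generated_ids = model.generate(**batch, max_new_tokens=20)
+#     print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))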
+"""PyTorch BART model.""" + +import math +import warnings +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_bart import BartConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) + + return super().forward(position_ids + self.offset) + + +class BartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BartConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class BartEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BartDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = 
nn.LayerNorm(self.embed_dim) + self.encoder_attn = BartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring +class BartPreTrainedModel(PreTrainedModel): + config: BartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], 
device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
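+        # Shape sketch for the call below (assuming batch_size=2, sequence_length=3,
+        # target_length=5): a 2D `attention_mask` of shape (2, 5) with 1 = attend and
+        # 0 = padding becomes a 4D additive `causal_mask` of shape (2, 1, 3, 5) holding
+        # 0.0 at allowed positions and torch.finfo(dtype).min at disallowed ones.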
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
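+            # "Inverted form" means the mask is already an additive float mask of shape
+            # (batch_size, 1, query_length, key_value_length): 0.0 where attention is
+            # allowed and a large negative value where it is masked out.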
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PretrainedBartModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartPretrainedModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartEncoder(BartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BartEncoderLayer`]. 
+ + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BartDecoder(BartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + attention_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class BartModel(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + # Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247 + if self.shared.weight.device == torch.device( + "meta" + ) and self.decoder.embed_tokens.weight.device != torch.device("meta"): + self._tie_or_clone_weights(self.encoder.embed_tokens, self.decoder.embed_tokens) + self._tie_or_clone_weights(self.shared, self.decoder.embed_tokens) + else: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: 
Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BART Model with a language modeling head. Can be used for summarization. 
+ """ +) +class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self.model._tie_weights() + self._tie_or_clone_weights(self.lm_head, self.model.shared) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. 
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example summarization: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + + >>> TXT = "My friends are but they eat too many carbs." 
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['not', 'good', 'healthy', 'great', 'very'] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + +@auto_docstring( + custom_intro=""" + Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. 
+ """ +) +class BartForSequenceClassification(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BartModel(config) + self.classification_head = BartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring +class BartForQuestionAnswering(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BartModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + 
decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class BartDecoderWrapper(BartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@auto_docstring( + custom_intro=""" + BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings). 
+ """ +) +class BartForCausalLM(BartPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = BartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "BartForCausalLM", + "BartForConditionalGeneration", + "BartForQuestionAnswering", + "BartForSequenceClassification", + "BartModel", + "BartPreTrainedModel", + "BartPretrainedModel", + "PretrainedBartModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_flax_bart.py b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_flax_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..818254f3bfa1e95685fdcde8af8f074edb69761f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_flax_bart.py @@ -0,0 +1,2006 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax Bart model.""" + +import math +import random +from functools import partial +from typing import Callable, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, + FlaxSeq2SeqQuestionAnsweringModelOutput, + FlaxSeq2SeqSequenceClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_bart import BartConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-base" +_CONFIG_FOR_DOC = "BartConfig" + + +BART_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BartConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +BART_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +BART_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. 
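+
+    For example, with `decoder_start_token_id=2` and `pad_token_id=1`, the row `[5, 6, 7]` becomes `[2, 5, 6]`;
+    any `-100` values remaining after the shift are replaced with `pad_token_id`.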
+ """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxBartAttention(nn.Module): + config: BartConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
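+        # Positions allowed to attend (mask > 0) receive a bias of 0.0, while masked positions receive the most
+        # negative finite value of the computation dtype, so their attention weights vanish after the softmax.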
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxBartEncoderLayer(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class FlaxBartEncoderLayerCollection(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in 
self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxBartDecoderLayer(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, 
deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class FlaxBartDecoderLayerCollection(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + config: BartConfig + inner_dim: int + num_classes: int + pooler_dropout: float + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.dropout = nn.Dropout(rate=self.pooler_dropout) + self.out_proj = nn.Dense( + self.num_classes, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = jnp.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxBartEncoder(nn.Module): 
+ config: BartConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs + + return FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxBartDecoder(nn.Module): + config: BartConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxBartModule(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + 
cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxBartPreTrainedModel(FlaxPreTrainedModel): + config_class = BartConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: BartConfig, + input_shape: tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + # make sure initialization pass will work for FlaxBartForSequenceClassificationModule + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration + + >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration + + >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class FlaxBartModel(FlaxBartPreTrainedModel): + config: BartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxBartModule + + +append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + + +class FlaxBartForConditionalGenerationModule(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxBartModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + 
decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING +) +class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): + module_class = FlaxBartForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration + + >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> text = "My friends are cool but they eat too many carbs." 
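+        >>> # encode the prompt once; decode() below maps the decoder's hidden states to vocabulary logits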
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias.astype(self.dtype) + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated 
cache to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+
+FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
+    Returns:
+
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
+
+    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
+
+    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> import jax
+    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
+
+    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
+
+    >>> TXT = "My friends are <mask> but they eat too many carbs."
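+    >>> # <mask> is encoded as a single special token; it is located below via `tokenizer.mask_token_id`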
+    >>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]
+
+    >>> logits = model(input_ids).logits
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
+    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+    >>> values, predictions = jax.lax.top_k(probs, k=1)
+
+    >>> tokenizer.decode(predictions).split()
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxBartForConditionalGeneration, BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+class FlaxBartForSequenceClassificationModule(nn.Module):
+    config: BartConfig
+    dtype: jnp.dtype = jnp.float32
+    num_labels: Optional[int] = None
+
+    def setup(self):
+        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
+        self.classification_head = FlaxBartClassificationHead(
+            config=self.config,
+            inner_dim=self.config.d_model,
+            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
+            pooler_dropout=self.config.classifier_dropout,
+        )
+
+    def _get_encoder_module(self):
+        return self.model.encoder
+
+    def _get_decoder_module(self):
+        return self.model.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            position_ids=position_ids,
+            decoder_position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
+
+        # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
+        if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
+            if len(jnp.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+
+            if any(eos_mask.sum(1) == 0):
+                raise ValueError("There are missing <eos> tokens in input_ids")
+
+            # Ensure to keep 1 only for the last <eos> token for each example
+            eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
+            eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
+
+        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
+        logits = self.classification_head(sentence_representation, deterministic=deterministic)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return FlaxSeq2SeqSequenceClassifierOutput(
+            logits=logits,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
+    tasks.
+ """, + BART_START_DOCSTRING, +) +class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel): + module_class = FlaxBartForSequenceClassificationModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxBartForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBartForQuestionAnsweringModule(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 + num_labels = 2 + + def setup(self): + self.model = FlaxBartModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense( + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return output + + return FlaxSeq2SeqQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + BART_START_DOCSTRING, +) +class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel): + module_class = FlaxBartForQuestionAnsweringModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxBartForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel): + config_class = BartConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: BartConfig, + input_shape: tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + config.is_decoder = True + config.is_encoder_decoder = False + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + return module_init_outputs["params"] + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + past_key_values: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if encoder_hidden_states is not None and encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # prepare decoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxBartDecoderWrapper(nn.Module): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
+ """ + + config: BartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + embed_dim = self.config.d_model + embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype) + + def __call__(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class FlaxBartForCausalLMModule(nn.Module): + config: BartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings) + e.g for autoregressive tasks. + """, + BART_START_DOCSTRING, +) +class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel): + module_class = FlaxBartForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. 
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBartForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxBartDecoderPreTrainedModel", + "FlaxBartForCausalLM", + "FlaxBartForConditionalGeneration", + "FlaxBartForQuestionAnswering", + "FlaxBartForSequenceClassification", + "FlaxBartModel", + "FlaxBartPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_tf_bart.py b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_tf_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6d2317d696f9a9d1edce17f256ceccc8134e49 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/modeling_tf_bart.py @@ -0,0 +1,1713 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 Bart model.""" + +from __future__ import annotations + +import random + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, + TFSeq2SeqSequenceClassifierOutput, +) + +# Public API +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_bart import BartConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-large" +_CONFIG_FOR_DOC = "BartConfig" + + +LARGE_NEGATIVE = -1e8 + + +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + start_tokens = tf.fill( + (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) + ) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE + mask_cond = tf.range(shape_list(mask)[-1]) + + mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFBartLearnedPositionalEmbedding(keras.layers.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) + + def call( + self, + input_shape: tf.TensorShape | None = None, + past_key_values_length: int = 0, + position_ids: tf.Tensor | None = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(seq_len, delta=1, name="range") + position_ids += past_key_values_length + + offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32 + return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype)) + + +class TFBartAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, 
self.embed_dim]) + + +class TFBartEncoderLayer(keras.layers.Layer): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFBartAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: np.ndarray | tf.Tensor | None, + layer_head_mask: tf.Tensor | None, + training: bool | None = False, + ) -> tf.Tensor: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` + """ + residual = hidden_states + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask + ) + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, self_attn_weights + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFBartDecoderLayer(keras.layers.Layer): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + 
name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFBartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + cross_attn_layer_head_mask: tf.Tensor | None = None, + past_key_value: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFBartClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, inner_dim: int, num_classes: int, 
pooler_dropout: float, name: str, **kwargs):
+        super().__init__(name=name, **kwargs)
+        self.dense = keras.layers.Dense(inner_dim, name="dense")
+        self.dropout = keras.layers.Dropout(pooler_dropout)
+        self.out_proj = keras.layers.Dense(num_classes, name="out_proj")
+        self.input_dim = inner_dim
+        self.inner_dim = inner_dim
+
+    def call(self, inputs):
+        hidden_states = self.dropout(inputs)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = keras.activations.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.input_dim])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.inner_dim])
+
+
+class TFBartPretrainedModel(TFPreTrainedModel):
+    config_class = BartConfig
+    base_model_prefix = "model"
+
+    @property
+    def dummy_inputs(self):
+        dummy_inputs = super().dummy_inputs
+        # Dummy inputs should not contain the default val of 1
+        # as this is the padding token and some assertions check it
+        dummy_inputs["input_ids"] = dummy_inputs["input_ids"] * 2
+        if "decoder_input_ids" in dummy_inputs:
+            dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2
+        return dummy_inputs
+
+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "model.shared.weight":
+            return tf_weight, "model.decoder.embed_tokens.weight"
+        else:
+            return (tf_weight,)
+
+
+BART_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports!
If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Args:
+        config ([`BartConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+BART_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration
+
+    >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
+
+    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="tf")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5)
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
+    >>> TXT = "My friends are <mask> but they eat too many carbs."
+
+    >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
+    >>> input_ids = tokenizer([TXT], return_tensors="tf")["input_ids"]
+    >>> logits = model(input_ids).logits
+    >>> probs = tf.nn.softmax(logits[0])
+    >>> # probs[5] is associated with the mask token
+    ```
+"""
+
+
+BART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+            for denoising pre-training following the paper.
+        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            A default mask that ignores pad tokens will be made automatically. It is not recommended to set this
+            yourself for most use cases.
+        decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tf.FloatTensor`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder, of shape `(batch_size,
+            sequence_length, hidden_size)`. Used in the cross-attention of the decoder.
+        past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode; in graph mode the value will always be set to `True`.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFBartEncoder(keras.layers.Layer):
+    config_class = BartConfig
+    """
+    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+    [`TFBartEncoderLayer`].
+
+    Args:
+        config: BartConfig
+    """
+
+    def __init__(self, config: BartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.layerdrop = config.encoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+
+        self.embed_tokens = embed_tokens
+        self.embed_positions = TFBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+            name="embed_positions",
+        )
+        self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
+        self.embed_dim = config.d_model
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool | None = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        """
+        Args:
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + # check attention mask and invert + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask) + else: + attention_mask = None + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}." + ), + ) + + # encoder layers + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + ) + + if output_attentions: + all_attentions += (attn,) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layernorm_embedding", None) is not None: + with tf.name_scope(self.layernorm_embedding.name): + self.layernorm_embedding.build([None, None, self.embed_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFBartDecoder(keras.layers.Layer): + config_class = BartConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`TFBartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens: output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.embed_tokens = embed_tokens + self.layerdrop = config.decoder_layerdrop + self.embed_positions = TFBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") + + self.dropout = keras.layers.Dropout(config.dropout) + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+
+            cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`, with each tuple having 2 tuples, each of which contains 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+                decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+        # embed positions
+        if position_ids is None:
+            positions = self.embed_positions(input_shape, past_key_values_length)
+        else:
+            positions = self.embed_positions(input_shape, position_ids=position_ids)
+
+        if inputs_embeds is None:
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        hidden_states = inputs_embeds
+
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        else:
+            combined_attention_mask = _expand_mask(
+                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+            )
+
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
+
+        hidden_states = self.layernorm_embedding(hidden_states + positions)
+        hidden_states = self.dropout(hidden_states, training=training)
+
+        # decoder layers
+        all_hidden_states = () if
output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." + ), + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if encoder_hidden_states is not None: + all_cross_attns += (layer_cross_attn,) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layernorm_embedding", None) is not None: + with tf.name_scope(self.layernorm_embedding.name): + self.layernorm_embedding.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFBartMainLayer(keras.layers.Layer): + config_class = BartConfig + + def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), + name="model.shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix + + self.encoder = TFBartEncoder(config, self.shared, name="encoder") + self.decoder = TFBartDecoder(config, self.shared, name="decoder") + + def get_input_embeddings(self): + 
return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + **kwargs, + ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + 
cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. + with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare BART Model outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class TFBartModel(TFBartPretrainedModel): + _requires_load_weight_prefix = True + + def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + **kwargs, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = 
tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqModelOutput(
+            last_hidden_state=output.last_hidden_state,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+
+
+class BiasLayer(keras.layers.Layer):
+    """
+    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
+    so all weights have to be registered in a layer.
+    """
+
+    def __init__(self, shape, initializer, trainable, name, **kwargs):
+        super().__init__(name=name, **kwargs)
+        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+    def call(self, x):
+        return x + self.bias
+
+
+@add_start_docstrings(
+    "The BART Model with a language modeling head. Can be used for summarization.",
+    BART_START_DOCSTRING,
+)
+class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias"]
+    _requires_load_weight_prefix = True
+
+    def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
+        self.use_cache = config.use_cache
+        # final_logits_bias is registered as a buffer in PyTorch, so it is kept non-trainable here for consistency.
+        self.bias_layer = BiasLayer(
+            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+        )
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)
+
+    def get_bias(self):
+        return {"final_logits_bias": self.bias_layer.bias}
+
+    def set_bias(self, value):
+        # Replaces the existing layer containing the bias for correct (de)serialization.
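+        # A new BiasLayer is built rather than resized in place: the vocabulary size may have changed (for example
+        # after resizing the token embeddings), and the incoming value then overwrites the zero initializer below.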
+ vocab_size = value["final_logits_bias"].shape[-1] + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False + ) + self.bias_layer.bias.assign(value["final_logits_bias"]) + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(BART_GENERATION_EXAMPLE) + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+
+        Returns:
+
+        """
+
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
+        lm_logits = self.bias_layer(lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+        return TFSeq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,  # index 1 of decoder outputs
+            decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of decoder outputs
+            decoder_attentions=outputs.decoder_attentions,  # index 3 of decoder outputs
+            cross_attentions=outputs.cross_attentions,  # index 4 of decoder outputs
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
+            encoder_hidden_states=outputs.encoder_hidden_states,  # index 1 of encoder outputs
+            encoder_attentions=outputs.encoder_attentions,  # index 2 of encoder outputs
+        )
+
+    def serving_output(self, output):
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqLMOutput(
+            logits=output.logits,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        if decoder_attention_mask is not None:  # xla
+            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+        elif past_key_values is not None:  # no xla + past_key_values
+            decoder_position_ids =
past_key_values[0][0].shape[2] + else: # no xla + no past_key_values + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "bias_layer", None) is not None: + with tf.name_scope(self.bias_layer.name): + self.bias_layer.build(None) + + +@add_start_docstrings( + """ + Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + BART_START_DOCSTRING, +) +class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") + self.classification_head = TFBartClassificationHead( + config.d_model, config.num_labels, config.classifier_dropout, name="classification_head" + ) + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSeq2SeqSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        last_hidden_state = outputs[0]
+        eos_mask = tf.equal(input_ids, self.config.eos_token_id)
+        # Keep only the True entries of each row, then verify that the final entry of every row is True,
+        # i.e. that every example ends with an <eos> token.
+        self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1))
+        tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of <eos> tokens."])
+
+        masked = tf.reshape(
+            tf.boolean_mask(last_hidden_state, eos_mask),
+            (tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]),
+        )
+
+        sentence_representation = masked[:, -1, :]
+        logits = self.classification_head(sentence_representation)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSeq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def serving_output(self, output):
+        logits = tf.convert_to_tensor(output.logits)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqSequenceClassifierOutput(
+            logits=logits,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+        if getattr(self, "classification_head", None) is not None:
+            with tf.name_scope(self.classification_head.name):
+                self.classification_head.build(None)
+
+
+__all__ = ["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart.py b/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart.py
new file mode 100644
index 0000000000000000000000000000000000000000..f674afe1a4126781ef64f8b5e96436b0025490c2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart.py
@@ -0,0 +1,393 @@
+# coding=utf-8
+# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
+
+# See all BART models at https://huggingface.co/models?filter=bart
+
+
+@lru_cache
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+    whitespace/control characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large number of unicode characters in your
+    vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return the set of symbol pairs in a word.
+
+    Word is represented as a tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class BartTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a BART tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+    ```python
+    >>> from transformers import BartTokenizer
+
+    >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance
+    (see the example at the end of this docstring).
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated like any
+            other word. (The BART tokenizer detects the beginning of words by the preceding space.)
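+
+    Example with `add_prefix_space=True` (an illustrative sketch; the IDs shown follow the encoding of
+    `" Hello world"` above):
+
+    ```python
+    >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base", add_prefix_space=True)
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```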
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + 
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A BART sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
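+
+        Example (an illustrative sketch; assumes an instantiated `BartTokenizer` and two arbitrary content token IDs):
+
+        ```python
+        >>> tokenizer.get_special_tokens_mask([31414, 232])
+        [1, 0, 0, 1]
+        ```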
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + +__all__ = ["BartTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart_fast.py b/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..88b002f595299f8d16b605c0ce4fd59330e5c4c4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bart/tokenization_bart_fast.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_bart import BartTokenizer + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + +# See all BART models at https://huggingface.co/models?filter=bart + + +class BartTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" BART tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer, + using byte-level Byte-Pair-Encoding. 
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+    ```python
+    >>> from transformers import BartTokenizerFast
+
+    >>> tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated like any
+            other word. (The BART tokenizer detects the beginning of words by the preceding space.)
+        trim_offsets (`bool`, *optional*, defaults to `True`):
+            Whether the post-processing step should trim offsets to avoid including whitespaces.
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = BartTokenizer + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + trim_offsets=True, + **kwargs, + ): + # we have to specify that this tokens is special otherwise adding it will reset the normalized flag to `False` in `add_special_tokens` + mask_token = ( + AddedToken(mask_token, lstrip=True, normalized=True, special=True) + if isinstance(mask_token, str) + else mask_token + ) + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + **kwargs, + ) + + # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + @property + def mask_token(self) -> str: + """ + `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not + having been set. + + BART tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily + comprise the space before the **. + """ + if self._mask_token is None: + if self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on Bart. + """ + # Mask token behave like a normal word, i.e. include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + if is_split_into_words and not self.add_prefix_space: + raise ValueError( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." 
+ ) + + return super()._batch_encode_plus(*args, **kwargs) + + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + if is_split_into_words and not self.add_prefix_space: + raise ValueError( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._encode_plus(*args, **kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["BartTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/beit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f412a3500683fa9ac5ebac12c860ae91ab4422b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/beit/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
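+
+# This module follows the library's lazy-import pattern: the `_LazyModule` installed below resolves submodule
+# symbols on first attribute access instead of importing everything eagerly. For example, an import such as
+#
+#     from transformers.models.beit import BeitConfig
+#
+# is expected to go through the lazy machinery and only load `configuration_beit` at that point.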
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_beit import *
+    from .feature_extraction_beit import *
+    from .image_processing_beit import *
+    from .image_processing_beit_fast import *
+    from .modeling_beit import *
+    from .modeling_flax_beit import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/configuration_beit.py b/venv/lib/python3.13/site-packages/transformers/models/beit/configuration_beit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bea72b7bd0b514fbe8d9de83222ab5f3e03ed9f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/beit/configuration_beit.py
@@ -0,0 +1,229 @@
+# coding=utf-8
+# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BEiT model configuration"""
+
+import warnings
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+class BeitConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate a BEiT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the BEiT
+    [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
+            pre-training.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
+ add_fpn (`bool`, *optional*, defaults to `False`): + Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`]. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. Only relevant for [`BeitBackbone`]. + + Example: + + ```python + >>> from transformers import BeitConfig, BeitModel + + >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration + >>> configuration = BeitConfig() + + >>> # Initializing a model (with random weights) from the beit-base-patch16-224-pt22k style configuration + >>> model = BeitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "beit" + + def __init__( + self, + vocab_size=8192, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + semantic_loss_ignore_index=255, + out_features=None, + out_indices=None, + add_fpn=False, + reshape_hidden_states=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling + # decode head attributes (semantic segmentation) + self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + # handle backwards compatibility + if "segmentation_indices" in kwargs: + warnings.warn( + "The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.", + FutureWarning, + ) + out_indices = kwargs.pop("segmentation_indices") + + # backbone attributes + self.stage_names = ["stem"] + 
[f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.add_fpn = add_fpn + self.reshape_hidden_states = reshape_hidden_states + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class BeitOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["BeitConfig", "BeitOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/feature_extraction_beit.py b/venv/lib/python3.13/site-packages/transformers/models/beit/feature_extraction_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..7886897c6d1018e8b27e19170ef562cebc4462da --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/beit/feature_extraction_beit.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for BEiT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_beit import BeitImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class BeitFeatureExtractor(BeitImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class BeitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use BeitImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["BeitFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit.py b/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..c25880bcfadacdfc7a7a5663ad1766d7230f1df3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit.py @@ -0,0 +1,504 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Beit.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class BeitImageProcessor(BaseImageProcessor): + r""" + Constructs a BEiT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the + `preprocess` method. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. + Can be overridden by the `crop_size` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + The mean to use if normalizing the image. This is a float or list of floats of length of the number of + channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + The standard deviation to use if normalizing the image. This is a float or list of floats of length of the + number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. 
+            Usually used for datasets where 0 is used for background, and background itself is not included in all
+            classes of a dataset (e.g. ADE20k). The background label will be replaced by 255. Can be overridden by the
+            `do_reduce_labels` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    @filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_center_crop: bool = True,
+        crop_size: Optional[dict[str, int]] = None,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        do_reduce_labels: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 256, "width": 256}
+        size = get_size_dict(size)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.do_reduce_labels = do_reduce_labels
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to (size["height"], size["width"]).
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        size = get_size_dict(size, default_to_square=True, param_name="size")
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` argument must contain `height` and `width` keys.
Got {size.keys()}") + return resize( + image, + size=(size["height"], size["width"]), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+            )
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
+        image = self._preprocess(
+            image,
+            do_reduce_labels=False,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            input_data_format=input_data_format,
+        )
+        if data_format is not None:
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+        return image
+
+    def _preprocess_segmentation_map(
+        self,
+        segmentation_map: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Optional[dict[str, int]] = None,
+        do_reduce_labels: Optional[bool] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """Preprocesses a single segmentation map."""
+        # All transformations expect numpy arrays.
+        segmentation_map = to_numpy_array(segmentation_map)
+        # Add an axis to the segmentation maps for transformations.
+        if segmentation_map.ndim == 2:
+            segmentation_map = segmentation_map[None, ...]
+            added_dimension = True
+            input_data_format = ChannelDimension.FIRST
+        else:
+            added_dimension = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+        segmentation_map = self._preprocess(
+            image=segmentation_map,
+            do_reduce_labels=do_reduce_labels,
+            do_resize=do_resize,
+            resample=resample,
+            size=size,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_normalize=False,
+            do_rescale=False,
+            input_data_format=ChannelDimension.FIRST,
+        )
+        # Remove extra axis if added
+        if added_dimension:
+            segmentation_map = np.squeeze(segmentation_map, axis=0)
+        segmentation_map = segmentation_map.astype(np.int64)
+        return segmentation_map
+
+    def __call__(self, images, segmentation_maps=None, **kwargs):
+        # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
+        # be passed in as positional arguments.
+        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Optional[dict[str, int]] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        do_reduce_labels: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            segmentation_maps (`ImageInput`, *optional*):
+                Segmentation maps to preprocess.
+                Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images
+                with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
+                padded with zeros and then cropped.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+                is used for background, and background itself is not included in all classes of a dataset (e.g.
+                ADE20k). The background label will be replaced by 255.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
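+
+        Example (a minimal sketch; the random array stands in for a real image, and PyTorch is assumed to be
+        available for `return_tensors="pt"`):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import BeitImageProcessor
+
+        >>> image_processor = BeitImageProcessor()
+        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+        >>> inputs = image_processor.preprocess(image, return_tensors="pt")
+        >>> list(inputs["pixel_values"].shape)  # resized to 256, then center-cropped to 224
+        [1, 3, 224, 224]
+        ```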
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=True, param_name="size") + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + + images = make_flat_list_of_images(images) + + if segmentation_maps is not None: + segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + do_center_crop=do_center_crop, + do_rescale=do_rescale, + do_normalize=do_normalize, + resample=resample, + size=size, + rescale_factor=rescale_factor, + crop_size=crop_size, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + + data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_segmentation_map( + segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + resample=resample, + size=size, + do_center_crop=do_center_crop, + crop_size=crop_size, + ) + for segmentation_map in segmentation_maps + ] + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + """ + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`BeitForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`list[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. 
+ """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["BeitImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..4518043e6841d2b6f562027a99680f909989eefb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/beit/image_processing_beit_fast.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Beit.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + is_torch_tensor, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) + + +class BeitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. 
+ """ + + do_reduce_labels: Optional[bool] + + +@auto_docstring +class BeitImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = True + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = True + do_reduce_labels = False + valid_kwargs = BeitFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[BeitFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + return label + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[BeitFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[BeitFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + images_kwargs = kwargs.copy() + images_kwargs["do_reduce_labels"] = False + batch_feature = self._preprocess(images, **images_kwargs) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False}) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ).pixel_values + batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return batch_feature + + def _preprocess( + self, + images: list["torch.Tensor"], + do_reduce_labels: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + if do_reduce_labels: + images = self.reduce_label(images) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + 
+ # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + """ + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`BeitForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`list[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["BeitImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_beit.py b/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4e0d7126513c5aa9371b14d9eb906b17748537 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_beit.py @@ -0,0 +1,1547 @@ +# coding=utf-8 +# Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BEiT model.""" + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedLMOutput, + SemanticSegmenterOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging, torch_int +from ...utils.backbone_utils import BackboneMixin +from .configuration_beit import BeitConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`BeitModel`]. + """ +) +class BeitModelOutputWithPooling(BaseModelOutputWithPooling): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + """ + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class BeitDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class BeitEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. 
Optionally, also the mask token. + + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = BeitPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + ) -> torch.Tensor: + if self.position_embeddings is not None and interpolate_pos_encoding is not None: + warnings.warn( + "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always " + "interpolated to the input image size. The argument will be removed in transformers v4.51.0." 
+ ) + + _, _, height, width = pixel_values.shape + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings, (patch_height, patch_width) + + +class BeitPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + embeddings = self.projection(pixel_values) + patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, (patch_height, patch_width) + + +class BeitSelfAttention(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.has_relative_position_bias = bool(window_size) + if self.has_relative_position_bias: + self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Add relative position bias if present. + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attention_scores = attention_scores + self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
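+        # Illustrative note (an added clarification, not upstream commentary): with e.g.
+        # attention_probs_dropout_prob=0.1, nn.Dropout zeroes each post-softmax probability
+        # independently with p=0.1 during training and rescales the surviving entries by
+        # 1/0.9, so entire key positions are randomly hidden from a given query.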
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class BeitSdpaSelfAttention(BeitSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + attn_bias = None + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attn_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + if attn_bias is None: + attn_bias = relative_position_bias + else: + attn_bias += relative_position_bias + + scaling = 1 / math.sqrt(self.attention_head_size) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_bias, + dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=scaling, + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer, None + + +class BeitSelfOutput(nn.Module): + """ + The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
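+
+    The forward pass itself therefore only applies a dense projection followed by dropout; `BeitLayer`
+    applies the residual (and the optional LayerScale weighting) itself.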
+ """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +BEIT_SELF_ATTENTION_CLASSES = { + "eager": BeitSelfAttention, + "sdpa": BeitSdpaSelfAttention, +} + + +class BeitAttention(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size) + self.output = BeitSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BeitIntermediate(nn.Module): + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class BeitOutput(nn.Module): + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class BeitLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None: + super().__init__() + 
self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BeitAttention(config, window_size=window_size) + self.intermediate = BeitIntermediate(config) + self.output = BeitOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + if init_values > 0: + self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + else: + self.lambda_1, self.lambda_2 = None, None + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int, int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class BeitRelativePositionBias(nn.Module): + def __init__(self, config: BeitConfig, window_size: tuple) -> None: + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, config.num_attention_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + @compile_compatible_method_lru_cache(maxsize=10) + def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor: + """ + This method creates the relative position index, modified to support arbitrary window sizes, + as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460). 
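+
+        As an illustrative count: for a 14x14 window there are
+        (2 * 14 - 1) * (2 * 14 - 1) + 3 = 732 distinct relative distances to index, the three
+        extra entries covering the cls-to-token, token-to-cls and cls-to-cls interactions.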
+ """ + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = window_size[0] * window_size[1] + grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij") + coords = torch.stack(grid) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor: + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = nn.functional.interpolate( + old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear" + ) + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]] + ) + + relative_position_index = self.generate_relative_position_index(window_size) + relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)] + + # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads + relative_position_bias = relative_position_bias.view( + window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1 + ) + # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias.unsqueeze(0) + + +class BeitEncoder(nn.Module): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + self.has_relative_position_bias = config.use_shared_relative_position_bias + if self.has_relative_position_bias: + self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) + + # stochastic depth decay rule + dpr = 
[x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")] + self.layer = nn.ModuleList( + [ + BeitLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + ) + for i in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int, int]] = None, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + relative_position_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + else: + relative_position_bias = None + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class BeitPreTrainedModel(PreTrainedModel): + config: BeitConfig + base_model_prefix = "beit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["BeitLayer"] + _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BeitEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, BeitRelativePositionBias): + module.relative_position_bias_table.data.zero_() + elif isinstance(module, BeitLayer): + if module.lambda_1 is not None: + 
module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) + + +@auto_docstring +class BeitModel(BeitPreTrainedModel): + def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None: + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = BeitEmbeddings(config) + self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape) + + self.layernorm = ( + nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.pooler = BeitPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BeitModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
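+
+        Example (a minimal usage sketch; the checkpoint below is illustrative, any BEiT checkpoint works):
+
+        ```python
+        >>> from transformers import AutoImageProcessor, BeitModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
+        >>> model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> list(outputs.last_hidden_state.shape)  # 196 patch tokens + 1 [CLS] token for a 224x224 input
+        [1, 197, 768]
+        ```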
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        resolution = pixel_values.shape[2:]
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            resolution=resolution,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return BeitModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class BeitPooler(nn.Module):
+    def __init__(self, config: BeitConfig) -> None:
+        super().__init__()
+        self.layernorm = (
+            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.layernorm is not None:
+            # Mean pool the final hidden states of the patch tokens
+            patch_tokens = hidden_states[:, 1:, :]
+            pooled_output = self.layernorm(patch_tokens.mean(1))
+        else:
+            # Pool by simply taking the final hidden state of the [CLS] token
+            pooled_output = hidden_states[:, 0]
+
+        return pooled_output
+
+
+@auto_docstring(
+    custom_intro="""
+    Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
+    visual tokens of a Vector-Quantized Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and
+    DeiT predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`],
+    so you will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with
+    BEiT.
+ """ +) +class BeitForMaskedImageModeling(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # Classifier head + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return None + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, logits = outputs.loss, outputs.logits + >>> list(logits.shape) + [1, 196, 8192] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + sequence_output = self.layernorm(sequence_output) + prediction_scores = self.lm_head(sequence_output[:, 1:]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Beit Model transformer with an image 
classification head on top (a linear layer on top of the average of the final + hidden states of the patch tokens) e.g. for ImageNet. + """ +) +class BeitForImageClassification(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=True) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BeitConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
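+
+    Example (illustrative): `BeitConvModule(768, 512, kernel_size=3, padding=1)` applies a 3x3
+    convolution, then `BatchNorm2d`, then `ReLU`, preserving the spatial resolution.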
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple[int, int]], + padding: Union[int, tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + + return output + + +class BeitPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + BeitConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class BeitPyramidPoolingModule(nn.Module): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels) + self.blocks.append(block) + self.add_module(str(i), block) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + ppm_outs = [] + for ppm in self.blocks: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class BeitUperHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://huggingface.co/papers/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. 
[768, 768, 768, 768]
+        self.channels = config.hidden_size
+        self.align_corners = False
+        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+        # PSP Module
+        self.psp_modules = BeitPyramidPoolingModule(
+            self.pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            align_corners=self.align_corners,
+        )
+        self.bottleneck = BeitConvModule(
+            self.in_channels[-1] + len(self.pool_scales) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
+            fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = BeitConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def psp_forward(self, inputs):
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        # build laterals
+        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+        laterals.append(self.psp_forward(encoder_hidden_states))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
+                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+            )
+
+        # build outputs
+        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = nn.functional.interpolate(
+                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+            )
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.classifier(output)
+
+        return output
+
+
+class BeitFCNHead(nn.Module):
+    """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+    [FCNNet](https://huggingface.co/papers/1411.4038).
+
+    Args:
+        config (BeitConfig): Configuration.
+        in_index (int): Index of the encoder feature map to use. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
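+
+    Example (illustrative): with `config.auxiliary_num_convs=1` and `config.auxiliary_concat_input=False`,
+    the head reduces to a single 3x3 `BeitConvModule` followed by a 1x1 classifier convolution.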
+ """ + + def __init__( + self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, tuple[int, int]] = 1 + ) -> None: + super().__init__() + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + BeitConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + BeitConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = BeitConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@auto_docstring +class BeitForSemanticSegmentation(BeitPreTrainedModel): + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # FPNs + if len(self.config.out_indices) != 4: + raise ValueError( + "BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, " + "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of " + "a base-sized architecture." 
+ ) + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # Semantic segmentation head(s) + self.decode_head = BeitUperHead(config) + self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + main_loss = loss_fct(upsampled_logits, labels) + loss = main_loss + if auxiliary_logits is not None: + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss += self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, BeitForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] + batch_size = pixel_values.shape[0] + patch_resolution = self.config.image_size // self.config.patch_size + features = [ + x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features + ] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + logits = self.decode_head(features) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + BEiT backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class BeitBackbone(BeitPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = BeitEmbeddings(config) + self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape) + + if config.add_fpn: + if len(self.config.out_indices) != 4: + raise ValueError( + "BeitBackbone requires config.out_indices to be a list of 4 integers, " + "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of " + "a base-sized architecture." 
+ ) + hidden_size = config.hidden_size + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(hidden_size, eps=config.batch_norm_eps), + nn.GELU(), + nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential(nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @auto_docstring + def forward( + self, + pixel_values: Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 14, 14] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + batch_size = pixel_values.shape[0] + embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values) + resolution = pixel_values.shape[2:] + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + resolution=resolution, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:, :] + hidden_state = hidden_state.permute(0, 2, 1) + hidden_state = hidden_state.reshape(batch_size, -1, patch_height, patch_width) + + feature_maps += (hidden_state,) + + if self.config.add_fpn: + feature_maps = [ + self.fpn1(feature_maps[0]), + self.fpn2(feature_maps[1]), + self.fpn3(feature_maps[2]), + self.fpn4(feature_maps[3]), + ] + feature_maps = tuple(feature_maps) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "BeitForImageClassification", + "BeitForMaskedImageModeling", + "BeitForSemanticSegmentation", + "BeitModel", + "BeitPreTrainedModel", + "BeitBackbone", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_flax_beit.py b/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_flax_beit.py new file 
mode 100644
index 0000000000000000000000000000000000000000..c80deace6b39dc79968eff060a15573b0e2ea5ad
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/beit/modeling_flax_beit.py
@@ -0,0 +1,956 @@
+# coding=utf-8
+# Copyright 2021 Microsoft Research and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Callable, Optional
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutput,
+    FlaxBaseModelOutputWithPooling,
+    FlaxMaskedLMOutput,
+    FlaxSequenceClassifierOutput,
+)
+from ...modeling_flax_utils import (
+    ACT2FN,
+    FlaxPreTrainedModel,
+    append_replace_return_docstrings,
+    overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from .configuration_beit import BeitConfig
+
+
+@flax.struct.dataclass
+class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
+    """
+    Class for outputs of [`FlaxBeitModel`].
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]*
+            token will be returned.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
+            the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+    """
+
+
+BEIT_START_DOCSTRING = r"""
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+    This model is also a
+    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it
+    as a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
+    behavior.
+ + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BeitConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BEIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+"""
+
+
+def relative_position_index_init(window_size: tuple[int, int]) -> jnp.ndarray:
+    """
+    Get the pair-wise relative position index for each token inside the window.
+    """
+    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+
+    coords_h = np.arange(window_size[0])
+    coords_w = np.arange(window_size[1])
+    coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
+    coords_flatten = np.reshape(coords, (2, -1))
+    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = np.transpose(relative_coords, (1, 2, 0))  # Wh*Ww, Wh*Ww, 2
+    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+    relative_coords[:, :, 1] += window_size[1] - 1
+    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+
+    relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+    relative_position_index[0, 0:] = num_relative_distance - 3
+    relative_position_index[0:, 0] = num_relative_distance - 2
+    relative_position_index[0, 0] = num_relative_distance - 1
+    return jnp.array(relative_position_index)
+
+
+def ones_with_scale(key, shape, scale, dtype=jnp.float32):
+    return jnp.ones(shape, dtype) * scale
+
+
+class FlaxBeitDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    rate: float
+
+    @nn.module.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = True):
+        if self.rate == 0.0:
+            return inputs
+        keep_prob = 1.0 - self.rate
+        if deterministic:
+            return inputs
+        else:
+            shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+            rng = self.make_rng("droppath")
+            random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
+            binary_tensor = jnp.floor(random_tensor)
+            output = inputs / keep_prob * binary_tensor
+            return output
+
+
+class FlaxBeitPatchEmbeddings(nn.Module):
+    config: BeitConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.num_channels = self.config.num_channels
+        image_size = self.config.image_size
+        patch_size = self.config.patch_size
+        num_patches = (image_size // patch_size) * (image_size // patch_size)
+        patch_shape = (image_size // patch_size, image_size // patch_size)
+        self.num_patches = num_patches
+        self.patch_shape = patch_shape
+        self.projection = nn.Conv(
+            self.config.hidden_size,
+            kernel_size=(patch_size, patch_size),
+            strides=(patch_size, patch_size),
+            padding="VALID",
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+
+    def __call__(self, pixel_values):
+        num_channels = pixel_values.shape[-1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) + + +class FlaxBeitEmbeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + if self.config.use_mask_token: + self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + if self.config.use_absolute_position_embeddings: + self.position_embeddings = self.param( + "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True): + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.shape + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + cls_tokens = cls_tokens.astype(embeddings.dtype) + + if bool_masked_pos is not None: + mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size)) + mask_tokens = mask_tokens.astype(embeddings.dtype) + # replace the masked visual tokens by mask_tokens + w = jnp.expand_dims(bool_masked_pos, axis=-1) + embeddings = embeddings * (1 - w) + mask_tokens * w + + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + + if self.config.use_absolute_position_embeddings: + embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype) + + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +class FlaxBeitRelativePositionBias(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3 + self.relative_position_bias_table = self.param( + "relative_position_bias_table", + nn.initializers.zeros, + (num_relative_distance, self.config.num_attention_heads), + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + self.relative_position_index = relative_position_index_init(self.window_size) + + def __call__(self): + index = self.relative_position_index.reshape(-1) + shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) + relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH + return jnp.transpose(relative_position_bias, (2, 0, 1)) + + +class FlaxBeitSelfAttention(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr( + self.config, "embedding_size" + ): + raise ValueError( + f"The hidden size {self.config.hidden_size} is not a multiple of the number of attention " + f"heads {self.config.num_attention_heads}." 
+ ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=False, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.relative_position_bias = ( + FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype) + if self.window_size + else None + ) + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False + ): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attention_bias = jnp.array(0.0, dtype=self.dtype) + # Add relative position bias if present. + if self.relative_position_bias is not None: + attention_bias = jnp.expand_dims(self.relative_position_bias(), 0) + attention_bias = attention_bias.astype(query_states.dtype) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype) + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxBeitSelfOutput(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxBeitAttention(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype) + self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False + ): + attn_outputs = self.attention( + hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] + attn_output = 
self.output(attn_output, deterministic=deterministic) + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxBeitIntermediate(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +class FlaxBeitOutput(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxBeitLayer(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + drop_path_rate: float + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype) + self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBeitOutput(self.config, dtype=self.dtype) + self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate) + self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + self.init_values = self.config.layer_scale_init_value + if self.init_values > 0: + self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values) + self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values) + else: + self.lambda_1 = None + self.lambda_2 = None + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False + ): + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention + relative_position_bias, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, deterministic=deterministic) + + # apply lambda_2 if present + if self.lambda_2 is not None: + layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states + + outputs = (layer_output,) + + if output_attentions: + outputs += (self_attention_outputs[1],) + + 
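+        # Recap of the BEiT block computed above (pre-LayerNorm formulation):
+        #   x = x + DropPath(lambda_1 * SelfAttention(LayerNorm(x)))
+        #   x = x + DropPath(lambda_2 * MLP(LayerNorm(x)))
+        # where lambda_1 / lambda_2 are the optional LayerScale parameters and DropPath is
+        # stochastic depth applied independently to each residual branch.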
return outputs + + +class FlaxBeitLayerCollection(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + drop_path_rates: list[float] + relative_position_bias: Callable[[], jnp.ndarray] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBeitLayer( + self.config, + window_size=self.window_size if self.config.use_relative_position_bias else None, + drop_path_rate=self.drop_path_rates[i], + name=str(i), + dtype=self.dtype, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None + layer_outputs = layer( + hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxBeitEncoder(nn.Module): + config: BeitConfig + window_size: tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.use_shared_relative_position_bias: + self.relative_position_bias = FlaxBeitRelativePositionBias( + config=self.config, window_size=self.window_size, dtype=self.dtype + ) + + # stochastic depth decay rule + drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)) + self.layer = FlaxBeitLayerCollection( + self.config, + window_size=self.window_size, + drop_path_rates=drop_path_rates, + relative_position_bias=self.relative_position_bias + if self.config.use_shared_relative_position_bias + else None, + dtype=self.dtype, + ) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BeitConfig + base_model_prefix = "beit" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: BeitConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + bool_masked_pos=None, + params: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs["dropout"] = dropout_rng + rngs["droppath"] = droppath_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + bool_masked_pos, + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxBeitPooler(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.use_mean_pooling: + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + if self.config.use_mean_pooling: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + +class FlaxBeitModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxBeitEncoder( + self.config, 
window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype + ) + if not self.config.use_mean_pooling: + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if not self.config.use_mean_pooling: + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBeitModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", + BEIT_START_DOCSTRING, +) +class FlaxBeitModel(FlaxBeitPreTrainedModel): + module_class = FlaxBeitModule + + +FLAX_BEIT_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxBeitModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") + >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig) + + +class FlaxBeitForMaskedImageModelingModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype) + + # Classifier head + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + pixel_values=None, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + bool_masked_pos, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + sequence_output = self.layernorm(sequence_output) + prediction_scores = 
self.lm_head(sequence_output[:, 1:]) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return output + + return FlaxMaskedLMOutput( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", + BEIT_START_DOCSTRING, +) +class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel): + module_class = FlaxBeitForMaskedImageModelingModule + + +FLAX_BEIT_MLM_DOCSTRING = """ + bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` +""" + +overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING) +append_replace_return_docstrings( + FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig +) + + +class FlaxBeitForImageClassificationModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True) + self.classifier = nn.Dense( + self.config.num_labels, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + pixel_values=None, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + logits = self.classifier(pooled_output) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final + hidden states of the patch tokens) e.g. for ImageNet. 
+ """, + BEIT_START_DOCSTRING, +) +class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel): + module_class = FlaxBeitForImageClassificationModule + + +FLAX_BEIT_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") + >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + ``` +""" + +overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig +) + + +__all__ = [ + "FlaxBeitForImageClassification", + "FlaxBeitForMaskedImageModeling", + "FlaxBeitModel", + "FlaxBeitPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_generation/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f83b1f6e5bba323d5636820ef89026f072be816 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bert_generation import * + from .modeling_bert_generation import * + from .tokenization_bert_generation import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_generation/configuration_bert_generation.py b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/configuration_bert_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cf054cc5e27a5475b16e84f54e1a5091db62d0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/configuration_bert_generation.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BertGeneration model configuration""" + +from ...configuration_utils import PretrainedConfig + + +class BertGenerationConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BertGenerationPreTrainedModel`]. It is used to + instantiate a BertGeneration model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the BertGeneration + [google/bert_for_seq_generation_L-24_bbc_encoder](https://huggingface.co/google/bert_for_seq_generation_L-24_bbc_encoder) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50358): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BertGeneration`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often called feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. 
For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + + Examples: + + ```python + >>> from transformers import BertGenerationConfig, BertGenerationEncoder + + >>> # Initializing a BertGeneration config + >>> configuration = BertGenerationConfig() + + >>> # Initializing a model (with random weights) from the config + >>> model = BertGenerationEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "bert-generation" + + def __init__( + self, + vocab_size=50358, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + position_embedding_type="absolute", + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + + +__all__ = ["BertGenerationConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_generation/modeling_bert_generation.py b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/modeling_bert_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..4be87a0cd5446a1996f951f2577b2df13000e880 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/modeling_bert_generation.py @@ -0,0 +1,880 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BERT model specific for generation.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_bert_generation import BertGenerationConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BertGeneration +class BertGenerationSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->BertGeneration +class BertGenerationSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, 
self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
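+        # query_layer and key_layer both have shape (batch_size, num_attention_heads, seq_len, head_size),
+        # so the matmul below produces raw attention scores of shape
+        # (batch_size, num_attention_heads, query_len, key_len).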
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertGenerationModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
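+        # Each dropped entry removes one key position from one query's attention distribution;
+        # the surviving probabilities are rescaled by 1 / (1 - p), so the expected attention mass is unchanged.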
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +BERT_GENERATION_SELF_ATTENTION_CLASSES = { + "eager": BertGenerationSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION +class BertGenerationAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = BertGenerationSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BertGeneration +class BertGenerationIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BertGeneration +class BertGenerationOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = 
nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration +class BertGenerationLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertGenerationAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertGenerationAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) + self.intermediate = BertGenerationIntermediate(config) + self.output = BertGenerationOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertGenerationLayer(config, layer_idx=i) for i in 
range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +def load_tf_weights_in_bert_generation( + model, tf_hub_path, model_class, is_encoder_named_decoder=False, is_encoder=False +): + try: + import numpy as np + import tensorflow.compat.v1 as tf + import tensorflow_hub as hub + import tensorflow_text # noqa: F401 + + tf.disable_eager_execution() + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_model = hub.Module(tf_hub_path) + init = tf.global_variables_initializer() + with tf.Session() as sess: + init.run() + all_variables = tf_model.variable_map + keep_track_variables = all_variables.copy() + for key in list(all_variables.keys()): + if "global" in key: + logger.info(f"Skipping {key}...") + continue + if not is_encoder: + model_pointer = getattr(model, model_class) + else: + model_pointer = model + is_embedding = False + logger.info(f"Trying to match {key}...") + # remove start_string = "module/bert/" + sub_layers = key.split("/")[2:] + if is_encoder_named_decoder and sub_layers[0] == "encoder": + logger.info(f"Skipping encoder layer {key} for decoder") + continue + if is_encoder and sub_layers[0] == "decoder": + logger.info(f"Skipping decoder layer {key} for encoder") + continue + for i, sub_layer in enumerate(sub_layers): + if sub_layer == "embeddings": + is_embedding = True + elif sub_layer == "LayerNorm": + is_embedding = False + if "layer" in sub_layer: + model_pointer = model_pointer.layer[int(sub_layer.split("_")[-1])] + elif sub_layer in ["kernel", "gamma"]: + model_pointer = model_pointer.weight + elif sub_layer == "beta": + model_pointer = model_pointer.bias + elif sub_layer == "encdec": + model_pointer = model_pointer.crossattention.self + elif sub_layer == "encdec_output": + model_pointer = model_pointer.crossattention.output + elif is_encoder_named_decoder and sub_layer == "decoder": + model_pointer = model_pointer.encoder + else: + if sub_layer == "attention" and "encdec" in sub_layers[i + 1]: + continue + try: + model_pointer = getattr(model_pointer, sub_layer) + except AttributeError: + logger.info(f"Skipping to initialize {key} at {sub_layer}...") + raise AttributeError + + array = np.asarray(sess.run(all_variables[key])) + if not is_embedding: + logger.info(f"Transposing numpy weight of shape {array.shape} for {key}") + array = np.transpose(array) + else: + model_pointer = model_pointer.weight + + if model_pointer.shape != array.shape: + raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched") + logger.info(f"Initialize PyTorch weight {key}") + + model_pointer.data = torch.from_numpy(array.astype(np.float32)) + keep_track_variables.pop(key, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(keep_track_variables.keys())}") + return model + + +class BertGenerationEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = 
self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = inputs_embeds + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +@auto_docstring +class BertGenerationPreTrainedModel(PreTrainedModel): + config: BertGenerationConfig + base_model_prefix = "bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BertGenerationOnlyLMHead): + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The bare BertGeneration model transformer outputting raw hidden-states without any specific head on top. + """ +) +class BertGenerationEncoder(BertGenerationPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + This model should be used when leveraging Bert or Roberta checkpoints for the [`EncoderDecoderModel`] class as + described in [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://huggingface.co/papers/1907.12461) + by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = BertGenerationEmbeddings(config) + self.encoder = BertEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
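+        # For a 2D mask, get_extended_attention_mask broadcasts it to (batch_size, 1, 1, seq_length) and
+        # converts it to an additive mask: 0.0 for positions to keep and the dtype minimum for masked
+        # positions (a causal mask is folded in as well when the config is a decoder), so it can simply
+        # be added to the attention scores.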
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertGenerationOnlyLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, hidden_states): + logits = self.decoder(hidden_states) + return logits + + def _tie_weights(self): + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias + + +@auto_docstring( + custom_intro=""" + BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True.`") + + self.bert = BertGenerationEncoder(config) + self.lm_head = BertGenerationOnlyLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertGenerationDecoder, BertGenerationConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder") + >>> config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder") + >>> config.is_decoder = True + >>> model = BertGenerationDecoder.from_pretrained( + ... "google/bert_for_seq_generation_L-24_bbc_encoder", config=config + ... 
) + + >>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "BertGenerationDecoder", + "BertGenerationEncoder", + "BertGenerationPreTrainedModel", + "load_tf_weights_in_bert_generation", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_generation/tokenization_bert_generation.py b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/tokenization_bert_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..baeba9e4b38849b6d325fbccd525129cc2955271 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bert_generation/tokenization_bert_generation.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model BertGeneration.""" + +import os +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +@requires(backends=("sentencepiece",)) +class BertGenerationTokenizer(PreTrainedTokenizer): + """ + Construct a BertGeneration tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. 
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning-of-sequence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end-of-sequence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        sep_token (`str`, *optional*, defaults to `"<::::>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    prefix_tokens: list[int] = []
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        sep_token="<::::>",
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.vocab_file = vocab_file
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        # Add extra_ids to the special token list
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            sep_token=sep_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def _tokenize(self, text: str) -> list[str]:
+        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["BertGenerationTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87419e69e5c7f0572076832237e10f236186ec33
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/__init__.py
@@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_big_bird import * + from .modeling_big_bird import * + from .modeling_flax_big_bird import * + from .tokenization_big_bird import * + from .tokenization_big_bird_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/configuration_big_bird.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/configuration_big_bird.py new file mode 100644 index 0000000000000000000000000000000000000000..8e29439bc4bd54f06689b7a3e93008baf93427ad --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/configuration_big_bird.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BigBird model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BigBirdConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BigBirdModel`]. It is used to instantiate an + BigBird model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the BigBird + [google/bigbird-roberta-base](https://huggingface.co/google/bigbird-roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50358): + Vocabulary size of the BigBird model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BigBirdModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. 
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 1024 or 2048 or 4096).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`BigBirdModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        attention_type (`str`, *optional*, defaults to `"block_sparse"`):
+            Whether to use block sparse attention (with n complexity) as introduced in the paper, or the original
+            attention layer (with n^2 complexity). Possible values are `"original_full"` and `"block_sparse"`.
+        use_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in query, key, value.
+        rescale_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to rescale embeddings with (hidden_size ** 0.5).
+        block_size (`int`, *optional*, defaults to 64):
+            Size of each block. Useful only when `attention_type == "block_sparse"`.
+        num_random_blocks (`int`, *optional*, defaults to 3):
+            Each query will attend to this many random blocks. Useful only when `attention_type ==
+            "block_sparse"`.
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+ + Example: + + ```python + >>> from transformers import BigBirdConfig, BigBirdModel + + >>> # Initializing a BigBird google/bigbird-roberta-base style configuration + >>> configuration = BigBirdConfig() + + >>> # Initializing a model (with random weights) from the google/bigbird-roberta-base style configuration + >>> model = BigBirdModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "big_bird" + + def __init__( + self, + vocab_size=50358, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu_new", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=4096, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sep_token_id=66, + attention_type="block_sparse", + use_bias=True, + rescale_embeddings=False, + block_size=64, + num_random_blocks=3, + classifier_dropout=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + + self.rescale_embeddings = rescale_embeddings + self.attention_type = attention_type + self.use_bias = use_bias + self.block_size = block_size + self.num_random_blocks = num_random_blocks + self.classifier_dropout = classifier_dropout + + +class BigBirdOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["BigBirdConfig", "BigBirdOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_big_bird.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_big_bird.py new file mode 100644 index 0000000000000000000000000000000000000000..eb89d9872be8816c83158faa4ac81c2238c634d7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_big_bird.py @@ -0,0 +1,2957 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BigBird model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_big_bird import BigBirdConfig + + +logger = logging.get_logger(__name__) + + +_TRIVIA_QA_MAPPING = { + "big_bird_attention": "attention/self", + "output_layer_norm": "output/LayerNorm", + "attention_output": "attention/output/dense", + "output": "output/dense", + "self_attention_layer_norm": "attention/output/LayerNorm", + "intermediate": "intermediate/dense", + "word_embeddings": "bert/embeddings/word_embeddings", + "position_embedding": "bert/embeddings/position_embeddings", + "type_embeddings": "bert/embeddings/token_type_embeddings", + "embeddings": "bert/embeddings", + "layer_normalization": "output/LayerNorm", + "layer_norm": "LayerNorm", + "trivia_qa_head": "qa_classifier", + "dense": "intermediate/dense", + "dense_1": "qa_outputs", +} + + +def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False): + """Load tf checkpoints in a pytorch model.""" + + def load_tf_weights_bert(init_vars, tf_path): + names = [] + tf_weights = {} + + for name, shape in init_vars: + array = tf.train.load_variable(tf_path, name) + name = name.replace("bert/encoder/LayerNorm", "bert/embeddings/LayerNorm") + logger.info(f"Loading TF weight {name} with shape {shape}") + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + def load_tf_weights_trivia_qa(init_vars): + names = [] + tf_weights = {} + + for i, var in enumerate(init_vars): + name_items = var.name.split("/") + + if "transformer_scaffold" in name_items[0]: + layer_name_items = name_items[0].split("_") + if len(layer_name_items) < 3: + layer_name_items += [0] + + name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}" + + name = "/".join([_TRIVIA_QA_MAPPING.get(x, x) for x in name_items])[:-2] # remove last :0 in variable + + if "self/attention/output" in name: + name = name.replace("self/attention/output", "output") + + if i >= len(init_vars) - 2: + name = name.replace("intermediate", "output") + + logger.info(f"Loading TF weight {name} with shape {var.shape}") + array = var.value().numpy() + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + + # Load weights from TF model + init_vars = tf.saved_model.load(tf_path).variables if is_trivia_qa else tf.train.list_variables(tf_path) + + if len(init_vars) <= 0: + raise ValueError("Loaded trained variables cannot be empty.") + + pt_names = list(model.state_dict().keys()) + + if is_trivia_qa: + names, tf_weights = load_tf_weights_trivia_qa(init_vars) + else: + names, tf_weights = load_tf_weights_bert(init_vars, tf_path) + + for txt_name in names: + array = tf_weights[txt_name] + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + pt_name = [] + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + pt_name.append("bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + pt_name.append("classifier") + elif scope_names[0] == "transform": + pointer = getattr(pointer, "transform") + pt_name.append("transform") + if ("bias" in name) or ("kernel" in name): + pointer = getattr(pointer, "dense") + pt_name.append("dense") + elif ("beta" in name) or ("gamma" in name): + pointer = getattr(pointer, "LayerNorm") + pt_name.append("LayerNorm") + else: + try: + pointer = getattr(pointer, scope_names[0]) + pt_name.append(f"{scope_names[0]}") + except AttributeError: + logger.info(f"Skipping {m_name}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + pt_name.append(f"{num}") + if m_name[-11:] == "_embeddings" or m_name == "embeddings": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if len(array.shape) > len(pointer.shape) and math.prod(array.shape) == math.prod(pointer.shape): + # print(txt_name, array.shape) + if ( + txt_name.endswith("attention/self/key/kernel") + or txt_name.endswith("attention/self/query/kernel") + or txt_name.endswith("attention/self/value/kernel") + ): + array = array.transpose(1, 0, 2).reshape(pointer.shape) + elif txt_name.endswith("attention/output/dense/kernel"): + array = array.transpose(0, 2, 1).reshape(pointer.shape) + else: + array = array.reshape(pointer.shape) + + if pointer.shape != array.shape: + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}." 
+ ) + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + pt_weight_name = ".".join(pt_name) + logger.info(f"Initialize PyTorch weight {pt_weight_name} from {txt_name}.") + pointer.data = torch.from_numpy(array) + tf_weights.pop(txt_name, None) + pt_names.remove(pt_weight_name) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + logger.info(f"Weights not initialized in PyTorch model: {', '.join(pt_names)}.") + return model + + +class BigBirdEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + # End copy + + self.rescale_embeddings = config.rescale_embeddings + self.hidden_size = config.hidden_size + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.rescale_embeddings: + inputs_embeds = inputs_embeds * (self.hidden_size**0.5) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.dropout(embeddings) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BigBirdSelfAttention(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 
and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=False, + cache_position=None, + ): + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if is_cross_attention and past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0: + # reuse k,v, cross_attentions + key_layer = past_key_values.layers[self.layer_idx].keys + value_layer = past_key_values.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + key_layer, value_layer = past_key_values.update( + key_layer, + value_layer, + self.layer_idx, + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BigBirdModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
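+        # In other words, zeroing an entry of `attention_probs` removes one key/value position from this
+        # query's context for the current forward pass; the remaining weights are not re-normalized,
+        # although `nn.Dropout` rescales the surviving entries by 1 / (1 - p) during training so the
+        # expected total attention mass is preserved.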
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + + +class BigBirdBlockSparseAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + + self.max_seqlen = config.max_position_embeddings + self.seed = seed + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.num_random_blocks = config.num_random_blocks + self.block_size = config.block_size + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + def forward( + self, + hidden_states, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions=None, + ): + # Currently this `class` can't be used in decoder. + + batch_size, seqlen, _ = hidden_states.size() + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = self.block_size + + if from_seq_length % from_block_size != 0: + raise ValueError("Query sided sequence length must be multiple of block size") + + if to_seq_length % to_block_size != 0: + raise ValueError("Key/Value sided sequence length must be multiple of block size") + + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + context_layer, attention_probs = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + self.num_attention_heads, + self.num_random_blocks, + self.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=self.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + return context_layer, attention_probs + + @staticmethod + def torch_bmm_nd(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication""" + # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") + return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( + inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) + ) + + @staticmethod + def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication 
with transpose""" + # faster replacement of torch.einsum (bhqd,bhkd->bhqk) + return torch.bmm( + inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) + ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + n_rand_blocks, + attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_len, + to_seq_len, + seed, + plan_from_length, + plan_num_rand_blocks, + output_attentions, + ): + # BigBird block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) + # hence following code can be divided into 5 parts. + + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rsqrt_d = 1 / math.sqrt(attention_head_size) + bsz = batch_size + attn_mask_penalty = -10000.0 + + # generate random attention and corresponding masks + np.random.seed(seed) + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + rand_attn = [ + self._bigbird_block_rand_mask( + self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + ) + + rand_attn = np.stack(rand_attn, axis=0) + rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) + rand_attn.unsqueeze_(0) + rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + # preparing block for randn attn + gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) + gathered_key = gathered_key.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + gathered_value = 
self.torch_gather_b2(blocked_value_matrix, rand_attn) + gathered_value = gathered_value.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... ) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * attn_mask_penalty + first_attn_weights = nn.functional.softmax( + first_product, dim=-1 + ) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) + first_context_layer.unsqueeze_(2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) + second_seq_pad = torch.cat( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, 0], + ], + dim=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty + second_attn_weights = nn.functional.softmax( + second_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) + + second_context_layer.unsqueeze_(2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = torch.cat( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = torch.cat( + [blocked_value_matrix[:, :, 1:-3], 
blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + dim=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + first_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + last_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty + + # completing attention scores matrix for all q[-2:2] + band_product = torch.cat( + [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = nn.functional.softmax( + band_product, dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contribution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = self.torch_bmm_nd( + attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += self.torch_bmm_nd( + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 + ) + # ==> [bsz, n_heads, 
from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) + second_last_seq_pad = torch.cat( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_last_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, -1], + ], + dim=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty + second_last_attn_weights = nn.functional.softmax( + second_last_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) + second_last_context_layer.unsqueeze_(2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * attn_mask_penalty + last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) + last_context_layer.unsqueeze_(2) + + # combining representations of all tokens + context_layer = torch.cat( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + dim=2, + ) + context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask + context_layer = torch.transpose(context_layer, 1, 2) + + # this is just for visualizing; forward pass doesn't depend on following code + if output_attentions: + # TODO(PVP): need to verify if below code is correct + attention_probs = torch.zeros( + bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device + ) + + # 1st query block + # corresponding to `first_context_layer` + attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global + + # 2nd query block + # corresponding to `second_context_layer` + attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ + :, :, :, : 3 * to_block_size + ] # 1st three key blocks (global + sliding) + attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ + :, :, :, 3 * to_block_size : 4 * to_block_size + ] # last key block (global) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Middle query blocks + # corresponding to `context_layer` + # sliding keys + for q_idx in range(from_seq_len // from_block_size - 4): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + )[:, :, 2:-2, :, 1:-1, :] + right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] + attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( + bsz, n_heads, from_block_size, 3, to_block_size + ) # inner_band_product + # global keys (corresponding to 1st key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ + :, :, :, :, :to_block_size + ].view(bsz, n_heads, -1, to_block_size) # first_band_product + # global keys (corresponding to last key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ + :, :, :, :, -to_block_size: + ].view(bsz, n_heads, -1, to_block_size) # last_band_product + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + for q_idx in range(1, len(i2) - 1): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] + attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Second-last query block + # corresponding to `second_last_context_layer` + attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ + :, :, :, :to_block_size + ] # 1st key block (global) + attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = ( + second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size] + ) # last three blocks (global + sliding) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # last query block + # corresponding to `last_context_layer` + attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global + + else: + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def torch_gather_b2(params, indices): + # this operation is equivalent to tf.gather when batch_dims=2 + + if params.shape[:2] != indices.shape[:2]: + raise ValueError( + "Make sure that the first two dimensions of params and indices are identical, but" + f" they are params: {params.shape[:2]} vs. indices: {indices.shape[:2]}" + ) + num_indices_to_gather = indices.shape[-2] * indices.shape[-1] + num_indices_to_pick_from = params.shape[2] + + shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) + indices_shift = torch.div(shift, num_indices_to_gather, rounding_mode="floor") * num_indices_to_pick_from + + flattened_indices = indices.view(-1) + indices_shift + flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) + + out_flattened = flattened_params.index_select(0, flattened_indices) + + out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) + return out + + @staticmethod + def _create_rand_mask_from_inputs( + from_blocked_mask, + to_blocked_mask, + rand_attn, + num_attention_heads, + num_rand_blocks, + batch_size, + from_seq_length, + from_block_size, + ): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + rand_attn: [batch_size, num_attention_heads, + from_seq_length//from_block_size-2, num_rand_blocks] + num_attention_heads: int. Number of attention heads. + num_rand_blocks: int. Number of random chunks per row. + batch_size: int. Batch size for computation. + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + + Returns: + float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, + from_block_size, num_rand_blocks*to_block_size]. + """ + num_windows = from_seq_length // from_block_size - 2 + rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) + rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) + rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. 
+ + Returns: + plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for + each block + """ + + plan_from_length = [] + plan_num_rand_blocks = [] + if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(0) + elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks // 2) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) + else: + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks) + + return plan_from_length, plan_num_rand_blocks + + def _bigbird_block_rand_mask( + self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1 + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_rand_blocks: int. Number of random chunks per row. + last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, + if positive then num_rand_blocks blocks chosen only up to last_idx. + + Returns: + adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks + """ + # using this method when from_seq_length in [1024, 3072, 4096] + + if from_seq_length // from_block_size != to_seq_length // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) + # During inference (eval) no randomness + if not self.training: + return rand_attn + middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) + last = to_seq_length // to_block_size - 1 + if last_idx > (2 * to_block_size): + last = (last_idx // to_block_size) - 1 + + r = num_rand_blocks # shorthand + for i in range(1, from_seq_length // from_block_size - 1): + start = i - 2 + end = i + if i == 1: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] + elif i == 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] + elif i == from_seq_length // from_block_size - 3: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -3: should have been sliced till last-3 + elif i == from_seq_length // from_block_size - 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -4: should have been sliced till last-4 + else: + if start > last: + start = last + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + elif (end + 1) == last: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + else: + rand_attn[i - 1, :] = np.random.permutation( + np.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) + )[:r] + return rand_attn + + def _bigbird_block_rand_mask_with_head( + self, + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_heads, + plan_from_length, + plan_num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_top=1, + global_block_bottom=1, + global_block_left=1, + global_block_right=1, + ): + """ + 
Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_heads: int. total number of heads. + plan_from_length: list. plan from length where num_random_blocks are chosen from. + plan_num_rand_blocks: list. number of rand blocks within the plan. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_top: int. number of blocks at the top. + global_block_bottom: int. number of blocks at the bottom. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by + num_rand_blocks + """ + # using this method when from_seq_length not in [1024, 3072, 4096] + + if from_seq_length // from_block_size != to_seq_length // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + if from_seq_length not in plan_from_length: + raise ValueError("Error from sequence length not in plan!") + + # Total number of blocks in the mmask + num_blocks = from_seq_length // from_block_size + # Number of blocks per plan + plan_block_length = np.array(plan_from_length) // from_block_size + # till when to follow plan + max_plan_idx = plan_from_length.index(from_seq_length) + + # Random Attention adjacency list + rand_attn = [ + np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32) + for i in range(num_heads) + ] + # During inference (eval) no randomness + if not self.training: + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + return rand_attn + + # We will go iteratively over the plan blocks and pick random number of + # Attention blocks from the legally allowed blocks + for plan_idx in range(max_plan_idx + 1): + rnd_r_cnt = 0 + if plan_idx > 0: + # set the row for all from_blocks starting from 0 to + # plan_block_length[plan_idx-1] + # column indx start from plan_block_length[plan_idx-1] and ends at + # plan_block_length[plan_idx] + if plan_num_rand_blocks[plan_idx] > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=plan_block_length[plan_idx - 1], + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for pl_id in range(plan_idx): + if plan_num_rand_blocks[pl_id] == 0: + continue + for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): + rnd_r_cnt = 0 + to_start_block_id = 0 + if pl_id > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id])) + to_start_block_id = plan_block_length[pl_id - 1] + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1])) + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = 
self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention column start id. + to_end_block_id: int. random attention column end id. + num_rand_blocks: int. number of random blocks to be selected. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + row containing the random attention vector of size num_rand_blocks. 
+ """ + # list of to_blocks from which to choose random attention + to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) + # permute the blocks + perm_block = np.random.permutation(to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blocks = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blocks.append(perm_block[i]) + if len(selected_random_blocks) == num_rand_blocks: + break + return np.array(selected_random_blocks, dtype=np.int32) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BigBird +class BigBirdSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BigBirdAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + self.attention_type = config.attention_type + self.config = config + self.seed = seed + + if self.config.attention_type == "original_full": + self.self = BigBirdSelfAttention(config, layer_idx=seed) + elif self.config.attention_type == "block_sparse": + self.self = BigBirdBlockSparseAttention(config, seed) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}" + ) + + self.output = BigBirdSelfOutput(config) + + def set_attention_type(self, value: str, layer_idx=None): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + + self.attention_type = value + if value == "original_full": + # copy all weights to new full attention class + attn_weights = BigBirdSelfAttention(self.config, layer_idx=layer_idx) + else: + # copy all weights to new sparse attention class + attn_weights = BigBirdBlockSparseAttention(self.config, self.seed) + + attn_weights.query = self.self.query + attn_weights.value = self.self.value + attn_weights.key = self.self.key + self.self = attn_weights + if not self.training: + self.self.eval() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=False, + # block_sparse config + band_mask=None, + 
from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + cache_position=None, + ): + # fp16 compatibility + if band_mask is not None: + band_mask = band_mask.to(hidden_states.dtype) + if from_mask is not None: + from_mask = from_mask.to(hidden_states.dtype) + if to_mask is not None: + to_mask = to_mask.to(hidden_states.dtype) + if self.attention_type == "original_full": + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + else: + if encoder_hidden_states is not None: + raise ValueError("BigBird cannot be used as a decoder when config.attention_type != 'original_full'") + self_outputs = self.self( + hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions + ) + + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BigBird +class BigBirdIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BigBird +class BigBirdOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BigBirdLayer(GradientCheckpointingLayer): + def __init__(self, config, seed=None): + super().__init__() + self.config = config + self.attention_type = config.attention_type + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BigBirdAttention(config, seed=seed) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise TypeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BigBirdAttention(config, seed=seed) + self.intermediate = BigBirdIntermediate(config) + self.output = BigBirdOutput(config) + + def set_attention_type(self, value: str, layer_idx=None): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.attention.set_attention_type(value, layer_idx=layer_idx) + + if 
self.add_cross_attention: + self.crossattention.set_attention_type(value, layer_idx=layer_idx) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + band_mask=None, + from_mask=None, + to_mask=None, + blocked_encoder_mask=None, + past_key_values=None, + output_attentions=False, + cache_position=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + " cross-attention layers by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + return (layer_output,) + outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BigBirdEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.attention_type = config.attention_type + + self.layer = nn.ModuleList( + [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + for i, layer in enumerate(self.layer): + layer.set_attention_type(value, layer_idx=i) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + band_mask=None, + from_mask=None, + to_mask=None, + blocked_encoder_mask=None, + return_dict=True, + cache_position=None, + ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if 
output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + past_key_values, + output_attentions, + cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->BigBird +class BigBirdPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->BigBird +class BigBirdLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BigBirdPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
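# Illustrative note (not part of the upstream code): "tied" here means the decoder's weight
# matrix is shared with the input word-embedding matrix through the model-level
# `get_output_embeddings` / `_tied_weights_keys` machinery used further below, while the
# per-token bias is a separate parameter that `_tie_weights` re-links so it resizes with the
# vocabulary. A minimal sketch of the idea, assuming a `config` with `hidden_size` and
# `vocab_size` attributes:
#
#   embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
#   decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
#   decoder.weight = embeddings.weight          # shared storage, not a copy
#   bias = nn.Parameter(torch.zeros(config.vocab_size))
#   logits = decoder(hidden_states) + bias      # tied lm_head with its own output bias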
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->BigBird +class BigBirdOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BigBirdLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->BigBird +class BigBirdOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->BigBird +class BigBirdPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BigBirdLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@auto_docstring +class BigBirdPreTrainedModel(PreTrainedModel): + config: BigBirdConfig + load_tf_weights = load_tf_weights_in_big_bird + base_model_prefix = "bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BigBirdLMPredictionHead): + module.bias.data.zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`BigBirdForPreTraining`]. + """ +) +class BigBirdForPreTrainingOutput(ModelOutput): + r""" + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). 
+ """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of question answering models. + """ +) +class BigBirdForQuestionAnsweringModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + pooler_output (`torch.FloatTensor` of shape `(batch_size, 1)`): + pooler output from BigBigModel + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring +class BigBirdModel(BigBirdPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.attention_type = self.config.attention_type + self.config = config + + self.block_size = self.config.block_size + + self.embeddings = BigBirdEmbeddings(config) + self.encoder = BigBirdEncoder(config) + + if add_pooling_layer: + self.pooler = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + else: + self.pooler = None + self.activation = None + + if self.attention_type != "original_full" and config.add_cross_attention: + logger.warning( + "When using `BigBirdForCausalLM` as decoder, then `attention_type` must be `original_full`. 
Setting" + " `attention_type=original_full`" + ) + self.set_attention_type("original_full") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.encoder.set_attention_type(value) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # in order to use block_sparse attention, sequence_length has to be at least + # bigger than all global attentions: 2 * block_size + # + sliding tokens: 3 * block_size + # + random tokens: 2 * num_random_blocks * block_size + 
max_tokens_to_attend = (5 + 2 * self.config.num_random_blocks) * self.config.block_size + if self.attention_type == "block_sparse" and seq_length <= max_tokens_to_attend: + # change attention_type from block_sparse to original_full + sequence_length = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1) + logger.warning( + "Attention type 'block_sparse' is not possible if sequence_length: " + f"{sequence_length} <= num global tokens: 2 * config.block_size " + "+ min. num sliding tokens: 3 * config.block_size " + "+ config.num_random_blocks * config.block_size " + "+ additional buffer: config.num_random_blocks * config.block_size " + f"= {max_tokens_to_attend} with config.block_size " + f"= {self.config.block_size}, config.num_random_blocks " + f"= {self.config.num_random_blocks}. " + "Changing attention type to 'original_full'..." + ) + self.set_attention_type("original_full") + + if self.attention_type == "block_sparse": + ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) = self._pad_to_block_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + else: + padding_len = 0 + + if self.attention_type == "block_sparse": + blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( + attention_mask, self.block_size + ) + extended_attention_mask = None + + elif self.attention_type == "original_full": + blocked_encoder_mask = None + band_mask = None + from_mask = None + to_mask = None + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
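# Sketch of what the extended mask roughly looks like (assuming the standard behaviour of
# `get_extended_attention_mask` from `ModuleUtilsMixin`): a 2D `attention_mask` of shape
# [batch_size, seq_length] with 1s for real tokens and 0s for padding is broadcast to
# [batch_size, 1, 1, seq_length] and converted to additive form, e.g.
#
#   mask = attention_mask[:, None, None, :]            # [bsz, 1, 1, seq_len]
#   mask = (1.0 - mask) * torch.finfo(dtype).min       # 0.0 where attended, large negative otherwise
#
# so it can simply be added to the raw attention scores before the softmax.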
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.attention_type}" + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + blocked_encoder_mask=blocked_encoder_mask, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + + pooler_output = self.activation(self.pooler(sequence_output[:, 0, :])) if (self.pooler is not None) else None + + # undo padding + if padding_len > 0: + # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) + sequence_output = sequence_output[:, :-padding_len] + + if not return_dict: + return (sequence_output, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooler_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + @staticmethod + def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int): + batch_size, seq_length = attention_mask.size() + if seq_length % block_size != 0: + raise ValueError( + f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block" + f" size is {block_size}." + ) + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. 
+ + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. + """ + exp_blocked_to_pad = torch.cat( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + return band_mask + + blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + + from_mask = attention_mask.view(batch_size, 1, seq_length, 1) + to_mask = attention_mask.view(batch_size, 1, 1, seq_length) + + return blocked_encoder_mask, band_mask, from_mask, to_mask + + def _pad_to_block_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + pad_token_id: int, + ): + """A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention.""" + # padding + block_size = self.config.block_size + + input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (block_size - seq_len % block_size) % block_size + if padding_len > 0: + logger.warning_once( + f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " + f"`config.block_size`: {block_size}" + ) + if input_ids is not None: + input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_bigbird.BigBirdEmbeddings + position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = inputs_embeds.new_full( + (batch_size, padding_len), + self.config.pad_token_id, + dtype=torch.long, + ) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + + attention_mask = nn.functional.pad( + attention_mask, (0, padding_len), value=False + ) # no attention on the padding tokens + token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + + +class BigBirdForPreTraining(BigBirdPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BigBirdModel(config, add_pooling_layer=True) + self.cls = BigBirdPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + next_sentence_label: Optional[torch.LongTensor] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BigBirdForPreTrainingOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. If specified, nsp loss will be + added to masked_lm loss. Input should be a sequence pair (see `input_ids` docstring) Indices should be in + `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BigBirdForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") + >>> model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if next_sentence_label is not None and total_loss is not None: + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = total_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BigBirdForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BigBirdForMaskedLM(BigBirdPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BigBirdForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.bert = BigBirdModel(config) + self.cls = BigBirdOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, BigBirdForMaskedLM + >>> from datasets import load_dataset + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") + >>> model = BigBirdForMaskedLM.from_pretrained("google/bigbird-roberta-base") + >>> squad_ds = load_dataset("rajpurkar/squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> # select random long article + >>> LONG_ARTICLE_TARGET = squad_ds[81514]["context"] + >>> # select random sentence + >>> LONG_ARTICLE_TARGET[332:398] + 'the highest values are very close to the theoretical maximum value' + + >>> # add mask_token + >>> LONG_ARTICLE_TO_MASK = LONG_ARTICLE_TARGET.replace("maximum", "[MASK]") + >>> inputs = tokenizer(LONG_ARTICLE_TO_MASK, return_tensors="pt") + >>> # long article input + >>> list(inputs["input_ids"].shape) + [1, 919] + + >>> with torch.no_grad(): + ... 
logits = model(**inputs).logits + >>> # retrieve index of [MASK] + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'maximum' + ``` + + ```python + >>> labels = tokenizer(LONG_ARTICLE_TARGET, return_tensors="pt")["input_ids"] + >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(outputs.loss.item(), 2) + 1.99 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@auto_docstring( + custom_intro=""" + BigBird Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BigBirdForCausalLM` as a standalone, add `is_decoder=True.`") + + self.bert = BigBirdModel(config) + self.cls = BigBirdOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[CausalLMOutputWithCrossAttentions, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class BigBirdClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring( + custom_intro=""" + BigBird Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class BigBirdForSequenceClassification(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.bert = BigBirdModel(config) + self.classifier = BigBirdClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, BigBirdForSequenceClassification + >>> from datasets import load_dataset + + >>> tokenizer = AutoTokenizer.from_pretrained("l-yohai/bigbird-roberta-base-mnli") + >>> model = BigBirdForSequenceClassification.from_pretrained("l-yohai/bigbird-roberta-base-mnli") + >>> squad_ds = load_dataset("rajpurkar/squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> LONG_ARTICLE = squad_ds[81514]["context"] + >>> inputs = tokenizer(LONG_ARTICLE, return_tensors="pt") + >>> # long input article + >>> list(inputs["input_ids"].shape) + [1, 919] + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + >>> predicted_class_id = logits.argmax().item() + >>> model.config.id2label[predicted_class_id] + 'LABEL_0' + ``` + + ```python + >>> num_labels = len(model.config.id2label) + >>> model = BigBirdForSequenceClassification.from_pretrained( + ... "l-yohai/bigbird-roberta-base-mnli", num_labels=num_labels + ... ) + >>> labels = torch.tensor(1) + >>> loss = model(**inputs, labels=labels).loss + >>> round(loss.item(), 2) + 1.13 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BigBirdForMultipleChoice(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BigBirdModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MultipleChoiceModelOutput, tuple[torch.FloatTensor]]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BigBirdForTokenClassification(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + 
self.num_labels = config.num_labels + + self.bert = BigBirdModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BigBirdForQuestionAnsweringHead(nn.Module): + """Head for question answering tasks.""" + + def __init__(self, config): + super().__init__() + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.intermediate = BigBirdIntermediate(config) + self.output = BigBirdOutput(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, encoder_output): + hidden_states = self.dropout(encoder_output) + hidden_states = self.intermediate(hidden_states) + hidden_states = self.output(hidden_states, encoder_output) + hidden_states = self.qa_outputs(hidden_states) + return hidden_states + + +@auto_docstring +class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): + def __init__(self, config, add_pooling_layer=False): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + self.sep_token_id = config.sep_token_id + + self.bert = BigBirdModel(config, add_pooling_layer=add_pooling_layer) + self.qa_classifier = BigBirdForQuestionAnsweringHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + question_lengths: Optional[torch.LongTensor] = None, + token_type_ids: 
Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BigBirdForQuestionAnsweringModelOutput, tuple[torch.FloatTensor]]: + r""" + question_lengths (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + The lengths of the questions in the batch. + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, BigBirdForQuestionAnswering + >>> from datasets import load_dataset + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") + >>> model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base") + >>> squad_ds = load_dataset("rajpurkar/squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> # select random article and question + >>> LONG_ARTICLE = squad_ds[81514]["context"] + >>> QUESTION = squad_ds[81514]["question"] + >>> QUESTION + 'During daytime how high can the temperatures reach?' + + >>> inputs = tokenizer(QUESTION, LONG_ARTICLE, return_tensors="pt") + >>> # long article and question input + >>> list(inputs["input_ids"].shape) + [1, 929] + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + >>> predict_answer_token_ids = inputs.input_ids[0, answer_start_index : answer_end_index + 1] + >>> predict_answer_token = tokenizer.decode(predict_answer_token_ids) + ``` + + ```python + >>> target_start_index, target_end_index = torch.tensor([130]), torch.tensor([132]) + >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) + >>> loss = outputs.loss + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + seqlen = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1) + + if question_lengths is None and input_ids is not None: + # assuming input_ids format: context + question_lengths = torch.argmax(input_ids.eq(self.sep_token_id).int(), dim=-1) + 1 + question_lengths.unsqueeze_(1) + + logits_mask = None + if question_lengths is not None: + # setting lengths logits to `-inf` + logits_mask = self.prepare_question_mask(question_lengths, seqlen) + if token_type_ids is None: + token_type_ids = torch.ones(logits_mask.size(), dtype=int, device=logits_mask.device) - logits_mask + logits_mask[:, 0] = False + logits_mask.unsqueeze_(2) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.qa_classifier(sequence_output) + + if logits_mask is not None: + # removing question tokens from the competition + logits = logits - logits_mask * 1e6 + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + 
if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BigBirdForQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int): + # q_lengths -> (bz, 1) + mask = torch.arange(0, maxlen).to(q_lengths.device) + mask.unsqueeze_(0) # -> (1, maxlen) + mask = torch.where(mask < q_lengths, 1, 0) + return mask + + +__all__ = [ + "BigBirdForCausalLM", + "BigBirdForMaskedLM", + "BigBirdForMultipleChoice", + "BigBirdForPreTraining", + "BigBirdForQuestionAnswering", + "BigBirdForSequenceClassification", + "BigBirdForTokenClassification", + "BigBirdLayer", + "BigBirdModel", + "BigBirdPreTrainedModel", + "load_tf_weights_in_big_bird", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_flax_big_bird.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_flax_big_bird.py new file mode 100644 index 0000000000000000000000000000000000000000..11dcb30f3d47e3ecf6d90d7a7305682d687e5828 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/modeling_flax_big_bird.py @@ -0,0 +1,2648 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Optional + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_big_bird import BigBirdConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base" +_CONFIG_FOR_DOC = "BigBirdConfig" + +remat = nn_partitioning.remat + + +@flax.struct.dataclass +class FlaxBigBirdForPreTrainingOutput(ModelOutput): + """ + Output type of [`BigBirdForPreTraining`]. + + Args: + prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + prediction_logits: jnp.ndarray = None + seq_relationship_logits: jnp.ndarray = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + pooled_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + pooled_output returned by FlaxBigBirdModel. 
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + pooled_output: jnp.ndarray = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +BIG_BIRD_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BigBirdConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BIG_BIRD_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + +""" + + +class FlaxBigBirdEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.setup + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + if self.config.rescale_embeddings: + inputs_embeds *= self.config.hidden_size**0.5 + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird +class FlaxBigBirdSelfAttention(nn.Module): + config: BigBirdConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + 
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxBigBirdBlockSparseAttention(nn.Module): + config: BigBirdConfig + block_sparse_seed: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + use_bias=self.config.use_bias, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + use_bias=self.config.use_bias, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + use_bias=self.config.use_bias, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + @staticmethod + def transpose_for_scores(x, n_heads, head_size): + new_x_shape = x.shape[:-1] + (n_heads, head_size) + x = x.reshape(*new_x_shape) + return jnp.transpose(x, axes=(0, 2, 1, 3)) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic=True, + output_attentions=False, + ): + n_heads = self.config.num_attention_heads + head_size = self.config.hidden_size // n_heads + + blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( + attention_mask, self.config.block_size + ) + + query_layer = self.transpose_for_scores(self.query(hidden_states), n_heads, head_size) + key_layer = self.transpose_for_scores(self.key(hidden_states), n_heads, head_size) + value_layer = self.transpose_for_scores(self.value(hidden_states), n_heads, head_size) + + indices_prng_key = None + if not deterministic: + indices_prng_key = self.make_rng("indices") + + attn_output, attn_weights = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + blocked_encoder_mask, + n_heads, + head_size, + indices_prng_key=indices_prng_key, + deterministic=deterministic, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + @staticmethod + def create_masks_for_block_sparse_attn(attention_mask, block_size: int): + batch_size, seq_length = attention_mask.shape + if seq_length % block_size != 0: + raise ValueError( + f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block" + f" size is 
{block_size}." + ) + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. + """ + exp_blocked_to_pad = jnp.concatenate( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], axis=2 + ) + band_mask = jnp.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask = jnp.expand_dims(band_mask, 1) + return band_mask + + blocked_encoder_mask = attention_mask.reshape(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + + from_mask = attention_mask.reshape(batch_size, 1, seq_length, 1) + to_mask = attention_mask.reshape(batch_size, 1, 1, seq_length) + + return blocked_encoder_mask, band_mask, from_mask, to_mask + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + head_size, + indices_prng_key: Optional[jax.random.PRNGKey] = None, + deterministic: Optional[bool] = True, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=None, + ): + # BigBird block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of + # shifting tokens (for calculating sliding attention). hence following code can be divided into 5 parts. 
+ + bsz, _, from_seq_len, _ = query_layer.shape + to_seq_len = key_layer.shape[2] + from_block_size = to_block_size = self.config.block_size + + if from_seq_len % from_block_size != 0: + raise ValueError("Query sided sequence length must be multiple of block size") + + if to_seq_len % to_block_size != 0: + raise ValueError("Key/Value sided sequence length must be multiple of block size") + + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + n_rand_blocks = self.config.num_random_blocks + rsqrt_d = 1 / jnp.sqrt(head_size) + attn_mask_penalty = -10000.0 + + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + max_seqlen = self.config.max_position_embeddings + rand_attn = [ + self._bigbird_block_rand_mask( + max_seqlen, + max_seqlen, + from_block_size, + to_block_size, + n_rand_blocks, + indices_prng_key=indices_prng_key, + deterministic=deterministic, + last_idx=1024, + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + indices_prng_key=indices_prng_key, + ) + + rand_attn = jnp.stack(rand_attn, axis=0) + rand_attn = jnp.broadcast_to(rand_attn, (bsz,) + rand_attn.shape) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.reshape(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.reshape(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.reshape(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + shape = (bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1) + gathered_key = self.jax_gather(blocked_key_matrix, rand_attn, batch_dims=2).reshape(*shape) + gathered_value = self.jax_gather(blocked_value_matrix, rand_attn, batch_dims=2).reshape(*shape) + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, 0], key_layer) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * attn_mask_penalty + first_attn_weights = jax.nn.softmax(first_product, axis=-1) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = jnp.einsum("bhqk,bhkd->bhqd", first_attn_weights, value_layer) + first_context_layer = jnp.expand_dims(first_context_layer, 2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = jnp.concatenate( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + axis=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = jnp.concatenate( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + axis=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + # ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, 1], second_key_mat) + second_seq_pad = jnp.concatenate( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + jnp.ones([bsz, 1, 1, n_rand_blocks * to_block_size], dtype=to_mask.dtype), + ], + axis=3, + ) + second_rand_pad = jnp.concatenate( + [ + jnp.ones([bsz, n_heads, from_block_size, 4 * to_block_size], dtype=rand_mask.dtype), + rand_mask[:, :, 0], + ], + axis=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - jnp.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty + second_attn_weights = jax.nn.softmax( + second_product, axis=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+r)*to_block_size] x [bsz, n_heads, (4+r)*to_block_size, -1] + # ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = jnp.einsum("bhqk,bhkd->bhqd", second_attn_weights, second_value_mat) + second_context_layer = jnp.expand_dims(second_context_layer, 2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = jnp.concatenate( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], axis=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = jnp.concatenate( + [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + axis=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, 
from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + inner_band_product = jnp.einsum("bhlqd,bhlkd->bhlqk", middle_query_matrix, exp_blocked_key_matrix) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + # x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = jnp.einsum("bhlqd,bhlkd->bhlqk", middle_query_matrix, gathered_key[:, :, 1:-1]) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = jnp.einsum("bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0]) + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = jnp.einsum("bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1]) + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - jnp.expand_dims(to_mask[:, :, :, :to_block_size], 3)) * attn_mask_penalty + last_band_product += (1.0 - jnp.expand_dims(to_mask[:, :, :, -to_block_size:], 3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty + + # completing attention scores matrix for all q[-2:2] + band_product = jnp.concatenate( + [first_band_product, inner_band_product, rand_band_product, last_band_product], axis=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = jax.nn.softmax( + band_product, axis=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contribution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] + # x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = jnp.einsum( + "bhlqk,bhlkd->bhlqd", attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + # x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += jnp.einsum( + "bhlqk,bhlkd->bhlqd", + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], + gathered_value[:, :, 1:-1], + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] + # ==> [bsz, n_heads, 
from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += jnp.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += jnp.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = jnp.concatenate( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + axis=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = jnp.concatenate( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + axis=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + # ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, -2], second_last_key_mat) + second_last_seq_pad = jnp.concatenate( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + jnp.ones([bsz, 1, 1, n_rand_blocks * to_block_size], dtype=to_mask.dtype), + ], + axis=3, + ) + second_last_rand_pad = jnp.concatenate( + [ + jnp.ones([bsz, n_heads, from_block_size, 4 * to_block_size], dtype=rand_mask.dtype), + rand_mask[:, :, -1], + ], + axis=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - jnp.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty + second_last_attn_weights = jax.nn.softmax( + second_last_product, axis=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + # ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = jnp.einsum("bhqk,bhkd->bhqd", second_last_attn_weights, second_last_value_mat) + second_last_context_layer = jnp.expand_dims(second_last_context_layer, 2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, -1], key_layer) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * attn_mask_penalty + last_attn_weights = jax.nn.softmax(last_product, axis=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = jnp.einsum("bhqk,bhkd->bhqd", last_attn_weights, value_layer) + last_context_layer = jnp.expand_dims(last_context_layer, 2) + + # combining representations of all tokens + context_layer = jnp.concatenate( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + axis=2, + ) + context_layer = context_layer.reshape(bsz, n_heads, from_seq_len, -1) * from_mask + context_layer = jnp.transpose(context_layer, axes=(0, 2, 1, 3)).reshape(bsz, from_seq_len, -1) + + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def jax_gather(params, indices, batch_dims=2): + """ + Gather the indices from params correctly (equivalent to tf.gather but with modifications) + + Args: + params: (bsz, n_heads, num_blocks, block_size, head_dim) + indices: (bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. + + Returns: + plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for + each block + """ + + plan_from_length = [] + plan_num_rand_blocks = [] + if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(0) + elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks // 2) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) + else: + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks) + + return plan_from_length, plan_num_rand_blocks + + @staticmethod + def _bigbird_block_rand_mask( + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_rand_blocks, + indices_prng_key: Optional[jax.random.PRNGKey] = None, + deterministic: Optional[bool] = True, + last_idx: Optional[int] = -1, + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_rand_blocks: int. Number of random chunks per row. + indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations. + deterministic: bool. When False random attention will be used. 
+ last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, + if positive then num_rand_blocks blocks chosen only up to last_idx. + + Returns: + adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks + """ + # using this method when from_seq_length in [1024, 3072, 4096] + + if from_seq_length // from_block_size != to_seq_length // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + rand_attn = jnp.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=jnp.int32) + # deterministic nor randomness + if deterministic: + return rand_attn + + middle_seq = jnp.arange(1, to_seq_length // to_block_size - 1, dtype=jnp.int32) + last = to_seq_length // to_block_size - 1 + if last_idx > (2 * to_block_size): + last = (last_idx // to_block_size) - 1 + + r = num_rand_blocks # shorthand + for i in range(1, from_seq_length // from_block_size - 1): + start = i - 2 + end = i + if i == 1: + seq_values = jax.random.permutation(indices_prng_key, middle_seq[2:last])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + elif i == 2: + seq_values = jax.random.permutation(indices_prng_key, middle_seq[3:last])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + elif i == from_seq_length // from_block_size - 3: + seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + # Missing -3: should have been sliced till last-3 + elif i == from_seq_length // from_block_size - 2: + seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + # Missing -4: should have been sliced till last-4 + else: + if start > last: + start = last + seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + elif (end + 1) == last: + seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + else: + concat_values = jnp.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) + seq_values = jax.random.permutation(indices_prng_key, concat_values)[:r] + rand_attn = rand_attn.at[i - 1].set(seq_values) + return rand_attn + + def _bigbird_block_rand_mask_with_head( + self, + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_heads, + plan_from_length, + plan_num_rand_blocks, + indices_prng_key: Optional[jax.random.PRNGKey] = None, + deterministic: Optional[bool] = True, + window_block_left=1, + window_block_right=1, + global_block_top=1, + global_block_bottom=1, + global_block_left=1, + global_block_right=1, + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_heads: int. total number of heads. + plan_from_length: list. plan from length where num_random_blocks are chosen from. + plan_num_rand_blocks: list. number of rand blocks within the plan. + indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations. + deterministic: bool. When False random attention will be used. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_top: int. number of blocks at the top. + global_block_bottom: int. 
number of blocks at the bottom. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by + num_rand_blocks + """ + # using this method when from_seq_length not in [1024, 3072, 4096] + + if from_seq_length // from_block_size != to_seq_length // to_block_size: + raise ValueError("Error: the number of blocks needs to be the same!") + + if from_seq_length not in plan_from_length: + raise ValueError("Error: from sequence length not in plan!") + + # Total number of blocks in the mask + num_blocks = from_seq_length // from_block_size + # Number of blocks per plan + plan_block_length = jnp.array(plan_from_length) // from_block_size + # how far into the plan to go + max_plan_idx = plan_from_length.index(from_seq_length) + + # Random Attention adjacency list + rand_attn = [ + jnp.zeros((num_blocks, sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=jnp.int32) + for i in range(num_heads) + ] + + # deterministic mode: return the all-zero adjacency lists, trimmed of the global rows + if deterministic: + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + return rand_attn + + # We will iterate over the plan blocks and pick a random number of + # attention blocks from the legally allowed blocks + for plan_idx in range(max_plan_idx + 1): + rnd_r_cnt = 0 + if plan_idx > 0: + # set the row for all from_blocks starting from 0 to + # plan_block_length[plan_idx-1] + # column indices start from plan_block_length[plan_idx-1] and end at + # plan_block_length[plan_idx] + if plan_num_rand_blocks[plan_idx] > 0: + rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx])) + curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1])) + for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): + for h in range(num_heads): + single_block_row_attention = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=plan_block_length[plan_idx - 1], + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + indices_prng_key=indices_prng_key, + ) + rand_attn[h] = ( + rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) + ) + + for pl_id in range(plan_idx): + if plan_num_rand_blocks[pl_id] == 0: + continue + for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): + rnd_r_cnt = 0 + to_start_block_id = 0 + if pl_id > 0: + rnd_r_cnt = int(sum(plan_num_rand_blocks[:pl_id])) + to_start_block_id = plan_block_length[pl_id - 1] + curr_r_cnt = int(sum(plan_num_rand_blocks[: pl_id + 1])) + for h in range(num_heads): + single_block_row_attention = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + indices_prng_key=indices_prng_key, + ) + rand_attn[h] = ( + rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1])) +
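# Editor's note (illustrative, not upstream code): the two loops above back-fill, for the current plan segment, (1) the earlier rows against the new column range and (2) the new rows against all earlier column ranges; the block below handles the remaining diagonal case, the new rows against the new columns. + # For example, with from_seq_length=1024, from_block_size=64 and num_rand_blocks=7, _get_rand_attn_plan returns plan_from_length=[768, 1024] and plan_num_rand_blocks=[3, 4], so at plan_idx=1 the rows for blocks 12..15 receive their four extra random columns in slice 3:7. +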
from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + single_block_row_attention = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + indices_prng_key=indices_prng_key, + ) + rand_attn[h] = rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + indices_prng_key: Optional[jax.random.PRNGKey] = None, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention column start id. + to_end_block_id: int. random attention column end id. + num_rand_blocks: int. number of random blocks to be selected. + indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + row containing the random attention vector of size num_rand_blocks. 
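+ + Illustrative example (editor's note, not part of the upstream docstring): with block_id=3, + to_start_block_id=0, to_end_block_id=8, num_rand_blocks=2 and the default window/global sizes of 1, + the illegal set is {2, 3, 4} (window), {0} (global left) and {7} (global right); the permutation of + 0..7 is therefore filtered down to {1, 5, 6} and the first two surviving entries are returned.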
+ """ + # list of to_blocks from which to choose random attention + to_block_list = jnp.arange(to_start_block_id, to_end_block_id, dtype=jnp.int32) + # permute the blocks + perm_block = jax.random.permutation(indices_prng_key, to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blocks = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blocks.append(perm_block[i]) + if len(selected_random_blocks) == num_rand_blocks: + break + return jnp.array(selected_random_blocks, dtype=jnp.int32) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->BigBird +class FlaxBigBirdSelfOutput(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxBigBirdAttention(nn.Module): + config: BigBirdConfig + layer_id: Optional[int] = None + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + if self.config.attention_type == "original_full": + self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + elif self.config.attention_type == "block_sparse": + self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype) + else: + raise ValueError( + f"Your `config.attention_type` is {self.config.attention_type} but it can either be `original_full` or" + " `block_sparse`" + ) + + self.output = FlaxBigBirdSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + if self.config.attention_type == "original_full": + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + else: + attn_outputs = self.self( + hidden_states, + attention_mask, + 
deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->BigBird +class FlaxBigBirdIntermediate(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->BigBird +class FlaxBigBirdOutput(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxBigBirdLayer(nn.Module): + config: BigBirdConfig + layer_id: Optional[int] = None + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxBigBirdAttention( + self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype + ) + self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + 
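# editor's note: the outputs tuple is laid out as (hidden_states, self-attention weights?, + # cross-attention weights?); the optional entries below are appended only when + # output_attentions is True and, for the last one, when encoder_hidden_states is given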
+ if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +class FlaxBigBirdLayerCollection(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->BigBird +class FlaxBigBirdEncoder(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxBigBirdLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->BigBird +class FlaxBigBirdPredictionHeadTransform(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return self.LayerNorm(hidden_states) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->BigBird, np.ndarray->jnp.ndarray +class FlaxBigBirdLMPredictionHead(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.transform = FlaxBigBirdPredictionHeadTransform(self.config, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.transform(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->BigBird +class FlaxBigBirdOnlyMLMHead(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) + return hidden_states + + +class FlaxBigBirdPreTrainingHeads(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype) + self.seq_relationship = nn.Dense(2, dtype=self.dtype) + + def __call__(self, hidden_states, pooled_output, shared_embedding=None): + prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BigBirdConfig + base_model_prefix = "bert" + module_class: nn.Module = None + + def __init__( + self, + config: BigBirdConfig, + input_shape: Optional[tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + if config.attention_type == "block_sparse" and input_shape is None: + input_shape = (1, 12 * config.block_size) + elif input_shape is None: + input_shape = (1, 1) + + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng, indices_rng = jax.random.split(rng, num=3) + rngs = {"params": params_rng, "dropout": dropout_rng, "indices": indices_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + return_dict=False, + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: Optional[dict] = None, + dropout_rng: Optional[jax.random.PRNGKey] = None, + indices_rng: Optional[jax.random.PRNGKey] = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: Optional[dict] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if indices_rng is not None: + rngs["indices"] = indices_rng + + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBigBirdAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +class FlaxBigBirdModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxBigBirdEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.pooler = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + pooled = nn.tanh(self.pooler(hidden_states[:, 0, :])) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + 
cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.", + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModel with Bert->BigBird +class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdModule + + +append_call_sample_docstring(FlaxBigBirdModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingModule with Bert->BigBird +class FlaxBigBirdForPreTrainingModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + hidden_states = outputs[0] + pooled_output = outputs[1] + + prediction_scores, seq_relationship_score = self.cls( + hidden_states, pooled_output, shared_embedding=shared_embedding + ) + + if not return_dict: + return (prediction_scores, seq_relationship_score) + outputs[2:] + + return FlaxBigBirdForPreTrainingOutput( + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """, + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForPreTraining with Bert->BigBird +class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForPreTrainingModule + + +FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxBigBirdForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") + >>> model = FlaxBigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` +""" + +overwrite_call_docstring( + FlaxBigBirdForPreTraining, + BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBigBirdForPreTraining, output_type=FlaxBigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLMModule with Bert->BigBird +class FlaxBigBirdForMaskedLMModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLM with Bert->BigBird +class FlaxBigBirdForMaskedLM(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForMaskedLMModule + + +append_call_sample_docstring(FlaxBigBirdForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) + + +class FlaxBigBirdClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Dense(self.config.num_labels, 
dtype=self.dtype) + + def __call__(self, features, deterministic=True): + x = features[:, 0, :] # take <s> token (equiv. to [CLS]) + x = self.dropout(x, deterministic=deterministic) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x, deterministic=deterministic) + x = self.out_proj(x) + return x + + +class FlaxBigBirdForSequenceClassificationModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForSequenceClassification with Bert->BigBird +class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxBigBirdForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->BigBird +class FlaxBigBirdForMultipleChoiceModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, +
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForMultipleChoiceModule + + def __init__( + self, + config: BigBirdConfig, + input_shape: Optional[tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if config.attention_type == "block_sparse" and input_shape is None: + input_shape = (1, 1, 12 * config.block_size) + elif input_shape is None: + input_shape = (1, 1) + super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + +overwrite_call_docstring( + FlaxBigBirdForMultipleChoice, BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxBigBirdForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->BigBird +class FlaxBigBirdForTokenClassificationModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassification with Bert->BigBird +class FlaxBigBirdForTokenClassification(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForTokenClassificationModule + + +append_call_sample_docstring( + FlaxBigBirdForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBigBirdForQuestionAnsweringHead(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__(self, encoder_output, deterministic=True): + hidden_states = self.dropout(encoder_output, deterministic=deterministic) + hidden_states = self.intermediate(hidden_states) + hidden_states = self.output(hidden_states, encoder_output) + hidden_states = self.qa_outputs(hidden_states) + return hidden_states + + +class FlaxBigBirdForQuestionAnsweringModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + add_pooling_layer: bool = False + gradient_checkpointing: bool = False + + def setup(self): + self.config.num_labels = 2 + self.bert = FlaxBigBirdModule( + self.config, + dtype=self.dtype, + add_pooling_layer=self.add_pooling_layer, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + logits_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + pooled_output = outputs[1] if self.add_pooling_layer else None + logits = self.qa_classifier(hidden_states, deterministic=deterministic) + + if logits_mask is not None: + # removing question tokens from the competition + logits = logits - logits_mask * 1e6 + + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxBigBirdForQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + pooled_output=pooled_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + BIG_BIRD_START_DOCSTRING, +) +class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForQuestionAnsweringModule + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + question_lengths=None, + params: Optional[dict] = None, + dropout_rng: Optional[jax.random.PRNGKey] = None, + indices_rng: Optional[jax.random.PRNGKey] = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + if question_lengths is None and input_ids is not None: + # assuming input_ids format: context + question_lengths = jnp.argmax((input_ids == self.config.sep_token_id).astype("i4"), axis=-1) + 1 + question_lengths = jnp.expand_dims(question_lengths, axis=1) + + seqlen = input_ids.shape[1] + + logits_mask = None + if question_lengths is not None: + # setting lengths logits to `-inf` + logits_mask = self.prepare_question_mask(question_lengths, seqlen) + if token_type_ids is None: + token_type_ids = (~logits_mask).astype("i4") + logits_mask = jnp.expand_dims(logits_mask, axis=2) + logits_mask = logits_mask.at[:, 0].set(False) + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + if indices_rng is not None: + rngs["indices"] = indices_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids, + jnp.array(position_ids, dtype="i4"), + jnp.array(head_mask, dtype="i4"), + logits_mask, + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + @staticmethod + def prepare_question_mask(q_lengths, maxlen: int): + # q_lengths -> (bz, 1) + mask = jnp.arange(0, maxlen) + mask = jnp.expand_dims(mask, axis=0) < q_lengths + return mask + + +append_call_sample_docstring( + FlaxBigBirdForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxBigBirdForQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBigBirdForCausalLMModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: 
Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for + autoregressive tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->BigBird +class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway.
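+ # (editor's illustrative aside: with attention_mask=[[1, 1, 0, 0]] the cumsum trick below + # yields position_ids=[[0, 1, 1, 1]], and lax.dynamic_update_slice writes the real mask into + # the first seq_length columns of the static ones-mask)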
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBigBirdForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxBigBirdForCausalLM", + "FlaxBigBirdForMaskedLM", + "FlaxBigBirdForMultipleChoice", + "FlaxBigBirdForPreTraining", + "FlaxBigBirdForQuestionAnswering", + "FlaxBigBirdForSequenceClassification", + "FlaxBigBirdForTokenClassification", + "FlaxBigBirdModel", + "FlaxBigBirdPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird.py new file mode 100644 index 0000000000000000000000000000000000000000..56680972a63e6360eaaefc023e57bd6beff71b9d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BigBird.""" + +import os +import re +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +@requires(backends=("sentencepiece",)) +class BigBirdTokenizer(PreTrainedTokenizer): + """ + Construct a BigBird tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `"</s>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice) + using the forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + prefix_tokens: list[int] = [] + + def __init__( + self, + vocab_file, + unk_token="<unk>", + bos_token="<s>", + eos_token="</s>", + pad_token="<pad>", + sep_token="[SEP]", + mask_token="[MASK]", + cls_token="[CLS]", + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + + # Mask token behaves like a normal word, i.e.
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + sep_token=sep_token, + mask_token=mask_token, + cls_token=cls_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def _decode( + self, + token_ids: list[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. 
https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + if token in self.added_tokens_encoder: + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + + # Mimic the behavior of the Rust tokenizer: + # No space before [MASK] and [SEP] + if spaces_between_special_tokens: + text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts)) + else: + text = "".join(sub_texts) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A Big Bird sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
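+
+        For example, assuming the inputs contain no special tokens yet, a `token_ids_0` of length 3 and a
+        `token_ids_1` of length 2 yield the mask `[1, 0, 0, 0, 1, 0, 0, 1]`.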
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + +__all__ = ["BigBirdTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird_fast.py b/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..6148585a40b10385098c69714f650e1f42c22b4f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/big_bird/tokenization_big_bird_fast.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Big Bird model.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_big_bird import BigBirdTokenizer +else: + BigBirdTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +SPIECE_UNDERLINE = "▁" + + +class BigBirdTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" BigBird tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This + tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. .. note:: When building a sequence using special tokens, this is not the token + that is used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. 
two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = BigBirdTokenizer + model_input_names = ["input_ids", "attention_mask"] + prefix_tokens: list[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sep_token="[SEP]", + mask_token="[MASK]", + cls_token="[CLS]", + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + self.vocab_file = vocab_file + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An BigBird sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. 
This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Set to True if the token list is already formatted with special tokens for the model + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0] + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["BigBirdTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d203b2a578f9a4c6125cffb6041a627167518680 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
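+
+# The block below mirrors the lazy-import scheme used across the library's model packages: under TYPE_CHECKING the
+# submodules are imported directly so static type checkers see the real symbols, while at runtime this module is
+# replaced by a _LazyModule, so configuration_bigbird_pegasus and modeling_bigbird_pegasus are only imported when
+# one of their attributes is first accessed.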
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bigbird_pegasus import * + from .modeling_bigbird_pegasus import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..29b481c78ad1a99c7c720a257211bc996ba4e714 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -0,0 +1,413 @@ +# coding=utf-8 +# Copyright Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BigBirdPegasus model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import Any, Optional + +from ... import PreTrainedTokenizer +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension +from ...utils import TensorType, is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +class BigBirdPegasusConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BigBirdPegasusModel`]. It is used to instantiate + an BigBirdPegasus model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BigBirdPegasus + [google/bigbird-pegasus-large-arxiv](https://huggingface.co/google/bigbird-pegasus-large-arxiv) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 96103): + Vocabulary size of the BigBirdPegasus model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`BigBirdPegasusModel`]. + d_model (`int`, *optional*, defaults to 1024): + Dimension of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 16): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 16): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. 
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 1024 or 2048 or 4096).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/value attention states (not used by all models).
+        attention_type (`str`, *optional*, defaults to `"block_sparse"`):
+            Whether to use the block sparse attention introduced in the paper (with complexity linear in the sequence
+            length) or the original attention layer (with quadratic complexity) in the encoder. Possible values are
+            `"original_full"` and `"block_sparse"`.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use bias in the query, key and value projections.
+        block_size (`int`, *optional*, defaults to 64):
+            Size of each block. Useful only when `attention_type == "block_sparse"`.
+        num_random_blocks (`int`, *optional*, defaults to 3):
+            Each query attends to this many random blocks. Useful only when `attention_type == "block_sparse"`.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Whether to rescale embeddings by (hidden_size ** 0.5).
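+
+    Note: as in BigBird, when `attention_type == "block_sparse"` the encoder pads its inputs so that the sequence
+    length is a multiple of `block_size`, and it falls back to `"original_full"` attention (with a warning) when the
+    sequence is too short to benefit from block-sparse attention.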
+ + Example: + + ```python + >>> from transformers import BigBirdPegasusConfig, BigBirdPegasusModel + + >>> # Initializing a BigBirdPegasus bigbird-pegasus-base style configuration + >>> configuration = BigBirdPegasusConfig() + + >>> # Initializing a model (with random weights) from the bigbird-pegasus-base style configuration + >>> model = BigBirdPegasusModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "bigbird_pegasus" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + "attention_probs_dropout_prob": "attention_dropout", + } + + def __init__( + self, + vocab_size=96103, + max_position_embeddings=4096, + encoder_layers=16, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=16, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu_new", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + attention_type="block_sparse", # only for encoder + block_size=64, + num_random_blocks=3, + use_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + # extra config + self.attention_type = attention_type + self.block_size = block_size + self.num_random_blocks = num_random_blocks + self.use_bias = use_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig +class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: 
figure this case out. + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + 
torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + mask_dtype = common_inputs["attention_mask"].dtype + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. 
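+        # The two compute_effective_axis_dimension calls below resolve dynamic axes: a batch_size or seq_length of -1
+        # is replaced by OnnxConfig's fixed defaults, and the sequence length additionally reserves room for the
+        # special tokens the tokenizer will add, so the tokenized dummy batch ends up with the intended shape.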
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + else: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t + ) + + +__all__ = ["BigBirdPegasusConfig", "BigBirdPegasusOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..e419af75da3878553dc755776550d305e4d717ba --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -0,0 +1,3033 @@ +# coding=utf-8 +# Copyright 2021 Google Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BigBirdPegasus model.""" + +import math +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_bigbird_pegasus import BigBirdPegasusConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_EXPECTED_OUTPUT_SHAPE = [1, 7, 1024] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings, embedding_dim) + + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) + + +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus +class BigBirdPegasusScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus +class BigBirdPegasusSelfAttention(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=False, + cache_position=None, + ): + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if is_cross_attention and past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0: + # reuse k,v, cross_attentions + key_layer = past_key_values.layers[self.layer_idx].keys + value_layer = past_key_values.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + key_layer, value_layer = past_key_values.update( + key_layer, + value_layer, + self.layer_idx, + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BigBirdPegasusModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention with BigBird->BigBirdPegasus +class BigBirdPegasusBlockSparseAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + + self.max_seqlen = config.max_position_embeddings + self.seed = seed + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.num_random_blocks = config.num_random_blocks + self.block_size = config.block_size + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + def forward( + self, + hidden_states, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions=None, + ): + # Currently this `class` can't be used in decoder. 
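+        # Expected mask shapes (as typically prepared by the calling encoder's block-sparse mask helper):
+        #   band_mask:            [batch, 1, seq_len // block_size - 4, block_size, 3 * block_size]
+        #   from_mask:            [batch, 1, seq_len, 1]
+        #   to_mask:              [batch, 1, 1, seq_len]
+        #   from/to_blocked_mask: [batch, seq_len // block_size, block_size]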
+ + batch_size, seqlen, _ = hidden_states.size() + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = self.block_size + + if from_seq_length % from_block_size != 0: + raise ValueError("Query sided sequence length must be multiple of block size") + + if to_seq_length % to_block_size != 0: + raise ValueError("Key/Value sided sequence length must be multiple of block size") + + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + context_layer, attention_probs = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + self.num_attention_heads, + self.num_random_blocks, + self.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=self.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + return context_layer, attention_probs + + @staticmethod + def torch_bmm_nd(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication""" + # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") + return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( + inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) + ) + + @staticmethod + def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication with transpose""" + # faster replacement of torch.einsum (bhqd,bhkd->bhqk) + return torch.bmm( + inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) + ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + n_rand_blocks, + attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_len, + to_seq_len, + seed, + plan_from_length, + plan_num_rand_blocks, + output_attentions, + ): + # BigBirdPegasus block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) + # hence following code can be divided into 5 parts. 
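+        # Overview of the five parts computed below:
+        #   1) q[0]    (first, global block): attends to every key block
+        #   2) q[1]:                          attends to key blocks 0-2, the last key block, and its random blocks
+        #   3) q[2:-2] (middle blocks):       attend to a 3-block sliding window, their random blocks, and the
+        #                                     first/last (global) key blocks
+        #   4) q[-2]:                         attends to key block 0, the last three key blocks, and its random blocks
+        #   5) q[-1]   (last, global block):  attends to every key block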
+ + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rsqrt_d = 1 / math.sqrt(attention_head_size) + bsz = batch_size + attn_mask_penalty = -10000.0 + + # generate random attention and corresponding masks + np.random.seed(seed) + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + rand_attn = [ + self._bigbird_block_rand_mask( + self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + ) + + rand_attn = np.stack(rand_attn, axis=0) + rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) + rand_attn.unsqueeze_(0) + rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + # preparing block for randn attn + gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) + gathered_key = gathered_key.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + gathered_value = self.torch_gather_b2(blocked_value_matrix, rand_attn) + gathered_value = gathered_value.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * attn_mask_penalty + first_attn_weights = nn.functional.softmax( + first_product, dim=-1 + ) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) + first_context_layer.unsqueeze_(2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) + second_seq_pad = torch.cat( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, 0], + ], + dim=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty + second_attn_weights = nn.functional.softmax( + second_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) + + second_context_layer.unsqueeze_(2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = torch.cat( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = torch.cat( + [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + dim=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 
3*to_block_size, -1] + inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + first_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + last_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty + + # completing attention scores matrix for all q[-2:2] + band_product = torch.cat( + [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = nn.functional.softmax( + band_product, dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contribution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = self.torch_bmm_nd( + attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += self.torch_bmm_nd( + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, 
from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) + second_last_seq_pad = torch.cat( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_last_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, -1], + ], + dim=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty + second_last_attn_weights = nn.functional.softmax( + second_last_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) + second_last_context_layer.unsqueeze_(2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * attn_mask_penalty + last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) + last_context_layer.unsqueeze_(2) + + # combining representations of all tokens + context_layer = torch.cat( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + dim=2, + ) + context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask + context_layer = torch.transpose(context_layer, 1, 2) + + # this is just for visualizing; forward pass doesn't depend on following code + if output_attentions: + # TODO(PVP): need to verify if below code is correct + attention_probs = torch.zeros( + bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device + ) + + # 1st query block + # corresponding to `first_context_layer` + attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global + + # 2nd query block + # corresponding to `second_context_layer` + attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ + :, :, :, : 3 * to_block_size + ] # 1st three key blocks (global + sliding) + attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ + :, :, :, 3 * to_block_size : 4 * to_block_size + ] # last key block (global) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Middle query blocks + # corresponding to `context_layer` + # sliding keys + for q_idx in range(from_seq_len // from_block_size - 4): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + )[:, :, 2:-2, :, 1:-1, :] + right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] + attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( + bsz, n_heads, from_block_size, 3, to_block_size + ) # inner_band_product + # global keys (corresponding to 1st key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ + :, :, :, :, :to_block_size + ].view(bsz, n_heads, -1, to_block_size) # first_band_product + # global keys (corresponding to last key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ + :, :, :, :, -to_block_size: + ].view(bsz, n_heads, -1, to_block_size) # last_band_product + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + for q_idx in range(1, len(i2) - 1): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] + attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Second-last query block + # corresponding to `second_last_context_layer` + attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ + :, :, :, :to_block_size + ] # 1st key block (global) + attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = ( + second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size] + ) # last three blocks (global + sliding) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # last query block + # corresponding to `last_context_layer` + attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global + + else: + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def torch_gather_b2(params, indices): + # this operation is equivalent to tf.gather when batch_dims=2 + + if params.shape[:2] != indices.shape[:2]: + raise ValueError( + "Make sure that the first two dimensions of params and indices are identical, but" + f" they are params: {params.shape[:2]} vs. indices: {indices.shape[:2]}" + ) + num_indices_to_gather = indices.shape[-2] * indices.shape[-1] + num_indices_to_pick_from = params.shape[2] + + shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) + indices_shift = torch.div(shift, num_indices_to_gather, rounding_mode="floor") * num_indices_to_pick_from + + flattened_indices = indices.view(-1) + indices_shift + flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) + + out_flattened = flattened_params.index_select(0, flattened_indices) + + out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) + return out + + @staticmethod + def _create_rand_mask_from_inputs( + from_blocked_mask, + to_blocked_mask, + rand_attn, + num_attention_heads, + num_rand_blocks, + batch_size, + from_seq_length, + from_block_size, + ): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + rand_attn: [batch_size, num_attention_heads, + from_seq_length//from_block_size-2, num_rand_blocks] + num_attention_heads: int. Number of attention heads. + num_rand_blocks: int. Number of random chunks per row. + batch_size: int. Batch size for computation. + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + + Returns: + float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, + from_block_size, num_rand_blocks*to_block_size]. + """ + num_windows = from_seq_length // from_block_size - 2 + rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) + rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) + rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. 
+
+        Returns:
+            plan_from_length: list. Ending location (in tokens) of each segment of the from-block plan.
+            plan_num_rand_blocks: list. Number of random blocks to use within each segment of the plan.
+        """
+
+        plan_from_length = []
+        plan_num_rand_blocks = []
+        if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size):
+            plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size))
+            plan_num_rand_blocks.append(num_rand_blocks)
+            plan_from_length.append(from_seq_length)
+            plan_num_rand_blocks.append(0)
+        elif (num_rand_blocks + 5) < (from_seq_length // from_block_size):
+            plan_from_length.append(int((num_rand_blocks + 5) * from_block_size))
+            plan_num_rand_blocks.append(num_rand_blocks // 2)
+            plan_from_length.append(from_seq_length)
+            plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2))
+        else:
+            plan_from_length.append(from_seq_length)
+            plan_num_rand_blocks.append(num_rand_blocks)
+
+        return plan_from_length, plan_num_rand_blocks
+
+    def _bigbird_block_rand_mask(
+        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
+    ):
+        """
+        Create adjacency list of random attention.
+
+        Args:
+            from_seq_length: int. length of from sequence.
+            to_seq_length: int. length of to sequence.
+            from_block_size: int. size of block in from sequence.
+            to_block_size: int. size of block in to sequence.
+            num_rand_blocks: int. Number of random chunks per row.
+            last_idx: if -1 then num_rand_blocks blocks are chosen anywhere in the to sequence,
+                if positive then num_rand_blocks blocks are chosen only up to last_idx.
+
+        Returns:
+            adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
+        """
+        # this method is used when from_seq_length is in [1024, 3072, 4096]
+
+        if from_seq_length // from_block_size != to_seq_length // to_block_size:
+            raise ValueError("Error: the number of blocks needs to be the same!")
+
+        rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
+        # During inference (eval) no randomness
+        if not self.training:
+            return rand_attn
+        middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
+        last = to_seq_length // to_block_size - 1
+        if last_idx > (2 * to_block_size):
+            last = (last_idx // to_block_size) - 1
+
+        r = num_rand_blocks  # shorthand
+        for i in range(1, from_seq_length // from_block_size - 1):
+            start = i - 2
+            end = i
+            if i == 1:
+                rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
+            elif i == 2:
+                rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
+            elif i == from_seq_length // from_block_size - 3:
+                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+            # Missing -3: should have been sliced till last-3
+            elif i == from_seq_length // from_block_size - 2:
+                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+            # Missing -4: should have been sliced till last-4
+            else:
+                if start > last:
+                    start = last
+                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                elif (end + 1) == last:
+                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                else:
+                    rand_attn[i - 1, :] = np.random.permutation(
+                        np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
+                    )[:r]
+        return rand_attn
+
+    def _bigbird_block_rand_mask_with_head(
+        self,
+        from_seq_length,
+        to_seq_length,
+        from_block_size,
+        to_block_size,
+        num_heads,
+        plan_from_length,
+        plan_num_rand_blocks,
+        window_block_left=1,
+        window_block_right=1,
+        global_block_top=1,
+        global_block_bottom=1,
+        global_block_left=1,
+        global_block_right=1,
+    ):
+        """
+        Create adjacency list of random attention.
+
+        Args:
+            from_seq_length: int. length of from sequence.
+            to_seq_length: int. length of to sequence.
+            from_block_size: int. size of block in from sequence.
+            to_block_size: int. size of block in to sequence.
+            num_heads: int. total number of heads.
+            plan_from_length: list. Plan of from-sequence lengths from which the random blocks are chosen.
+            plan_num_rand_blocks: list. Number of random blocks within each plan segment.
+            window_block_left: int. number of blocks of window to left of a block.
+            window_block_right: int. number of blocks of window to right of a block.
+            global_block_top: int. number of blocks at the top.
+            global_block_bottom: int. number of blocks at the bottom.
+            global_block_left: int. Number of blocks globally used to the left.
+            global_block_right: int. Number of blocks globally used to the right.
+
+        Returns:
+            adjacency list of size num_heads where each element is of size from_seq_length//from_block_size-2 by
+            num_rand_blocks
+        """
+        # this method is used when from_seq_length is not in [1024, 3072, 4096]
+
+        if from_seq_length // from_block_size != to_seq_length // to_block_size:
+            raise ValueError("Error: the number of blocks needs to be the same!")
+
+        if from_seq_length not in plan_from_length:
+            raise ValueError("Error: from sequence length not in plan!")
+
+        # Total number of blocks in the mask
+        num_blocks = from_seq_length // from_block_size
+        # Number of blocks per plan
+        plan_block_length = np.array(plan_from_length) // from_block_size
+        # till when to follow plan
+        max_plan_idx = plan_from_length.index(from_seq_length)
+
+        # Random Attention adjacency list
+        rand_attn = [
+            np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32)
+            for i in range(num_heads)
+        ]
+        # During inference (eval) no randomness
+        if not self.training:
+            for nh in range(num_heads):
+                rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :]
+            return rand_attn
+
+        # We will go iteratively over the plan blocks and pick a random number of
+        # attention blocks from the legally allowed blocks
+        for plan_idx in range(max_plan_idx + 1):
+            rnd_r_cnt = 0
+            if plan_idx > 0:
+                # set the row for all from_blocks starting from 0 to
+                # plan_block_length[plan_idx-1]
+                # column idx starts from plan_block_length[plan_idx-1] and ends at
+                # plan_block_length[plan_idx]
+                if plan_num_rand_blocks[plan_idx] > 0:
+                    rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
+                    curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1]))
+                    for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]):
+                        for h in range(num_heads):
+                            rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention(
+                                block_id=blk_rw_idx,
+                                to_start_block_id=plan_block_length[plan_idx - 1],
+                                to_end_block_id=plan_block_length[plan_idx],
+                                num_rand_blocks=plan_num_rand_blocks[plan_idx],
+                                window_block_left=window_block_left,
+                                window_block_right=window_block_right,
+                                global_block_left=global_block_left,
+                                global_block_right=global_block_right,
+                            )
+
+                for pl_id in range(plan_idx):
+                    if plan_num_rand_blocks[pl_id] == 0:
+                        continue
+                    for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]):
+                        rnd_r_cnt = 0
+                        to_start_block_id = 0
+                        if pl_id > 0:
+                            rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id]))
+                            to_start_block_id = plan_block_length[pl_id - 1]
+                        curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1]))
+                        for h in range(num_heads):
+                            rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = 
self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention column start id. + to_end_block_id: int. random attention column end id. + num_rand_blocks: int. number of random blocks to be selected. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + row containing the random attention vector of size num_rand_blocks. 
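+
+        Example (illustrative): with block_id=4, to_start_block_id=0, to_end_block_id=8, num_rand_blocks=2
+            and one window/global block on each side, the window blocks {3, 4, 5} and the global blocks
+            {0, 7} are illegal, so a valid output could be np.array([2, 6], dtype=np.int32).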
+ """ + # list of to_blocks from which to choose random attention + to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) + # permute the blocks + perm_block = np.random.permutation(to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blocks = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blocks.append(perm_block[i]) + if len(selected_random_blocks) == num_rand_blocks: + break + return np.array(selected_random_blocks, dtype=np.int32) + + +class BigBirdPegasusEncoderAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + self.config = config + self.seed = seed + + self.attention_type = config.attention_type + + if self.attention_type == "original_full": + self.self = BigBirdPegasusSelfAttention(config) + elif self.attention_type == "block_sparse": + self.self = BigBirdPegasusBlockSparseAttention(config, seed) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}" + ) + + self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + + if value == "original_full": + # copy all weights to new full attention class + attn_weights = BigBirdPegasusSelfAttention(self.config) + else: + # copy all weights to new sparse attention class + attn_weights = BigBirdPegasusBlockSparseAttention(self.config, self.seed) + + attn_weights.query = self.self.query + attn_weights.value = self.self.value + attn_weights.key = self.self.key + self.self = attn_weights + self.attention_type = value + + if not self.training: + self.self.eval() + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + ): + # Expand dims to enable multiplication in the self-attention module + head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None + + if self.attention_type == "original_full": + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + else: + self_outputs = self.self( + hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions + ) + + attention_output = self.output(self_outputs[0]) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: 
nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder +class BigBirdPegasusDecoderAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BigBirdPegasusConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BigBirdPegasusConfig, seed=None): + super().__init__() + self.attention_type = config.attention_type + self.embed_dim = config.d_model + self.self_attn = BigBirdPegasusEncoderAttention(config, seed=seed) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + self_attention_outputs = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=from_blocked_mask, + to_blocked_mask=to_blocked_mask, + ) + hidden_states = self_attention_outputs[0] + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attention_outputs[1],) + + return outputs + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.self_attn.set_attention_type(value) + + +class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = BigBirdPegasusDecoderAttention( + embed_dim=self.embed_dim, + 
num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=config.use_bias, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BigBirdPegasusDecoderAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=config.use_bias, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->BigBirdPegasus +class BigBirdPegasusClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring +class BigBirdPegasusPreTrainedModel(PreTrainedModel): + config: BigBirdPegasusConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_param_buffer_assignment = False + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], 
device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+    ):
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            # Other attention flavors support in-built causal (when `mask is None`)
+            # while we need to create our specific block mask regardless
+            elif attention_mask is None:
+                # `attention_mask` is None on this branch, so the all-ones mask must be created on the device
+                # of `input_tensor` (referencing `attention_mask.device` here would raise an AttributeError)
+                attention_mask = make_flex_block_causal_mask(
+                    torch.ones(
+                        size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+                    )
+                )
+            return attention_mask
+
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+        # When output_attentions is True, the sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype = input_tensor.dtype
+        sequence_length = input_tensor.shape[1]
+        if using_compilable_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
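+            # Illustrative example (assuming the additive-mask convention used above): with left padding, a row
+            # of causal_mask can be `min_dtype` everywhere; softmax over such a row produces NaN in the
+            # memory-efficient kernel, so `_unmask_unattended` re-opens those rows. This is safe because the
+            # corresponding query positions are padding and their outputs are discarded downstream.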
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
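+                # Illustrative example (assuming the usual additive-mask convention): a 2D padding mask
+                # [[1, 1, 0]] over 3 encoder tokens expands to a [bsz, 1, tgt_len, 3] tensor that is 0.0 where
+                # attention is allowed and dtype-min where it is masked, so it can be added to the attention scores.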
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BigBirdPegasusEncoderLayer`]. + + Args: + config: BigBirdPegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.attention_type = config.attention_type + self.block_size = config.block_size + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BigBirdPegasusScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BigBirdPegasusEncoderLayer(config, seed=i) for i in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=hidden_states.device) + attention_mask = attention_mask.long() + + # in order to use block_sparse attention, sequence_length has to be at least + # bigger than all global attentions: 2 * block_size + # + sliding tokens: 3 * block_size + # + random tokens: 2 * num_random_blocks * block_size + max_tokens_to_attend = (5 + 2 * self.config.num_random_blocks) * self.config.block_size + if self.attention_type == "block_sparse" and input_shape[1] <= max_tokens_to_attend: + # change attention_type from block_sparse to original_full + sequence_length = input_shape[1] + logger.warning( + "Attention type 'block_sparse' is not possible if sequence_length: " + f"{sequence_length} <= num global tokens: 2 * config.block_size " + "+ min. num sliding tokens: 3 * config.block_size " + "+ config.num_random_blocks * config.block_size " + "+ additional buffer: config.num_random_blocks * config.block_size " + f"= {max_tokens_to_attend} with config.block_size " + f"= {self.config.block_size}, config.num_random_blocks " + f"= {self.config.num_random_blocks}. " + "Changing attention type to 'original_full'..." 
+ ) + self.set_attention_type("original_full") + + if self.attention_type == "block_sparse": + padding_len, hidden_states, attention_mask = self._pad_to_block_size(hidden_states, attention_mask) + else: + padding_len = 0 + + # expand attention_mask + if self.attention_type == "original_full": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + blocked_encoder_mask = band_mask = from_mask = to_mask = None + elif self.attention_type == "block_sparse": + blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( + attention_mask, self.block_size + ) + attention_mask = None + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.attention_type}" + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layernorm_embedding(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if padding_len > 0: + # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) + hidden_states = hidden_states[:, :-padding_len] + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + for layer in self.layers: + layer.set_attention_type(value) + + @staticmethod # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdModel.create_masks_for_block_sparse_attn + def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int): + batch_size, seq_length = attention_mask.size() + if seq_length % block_size != 0: + raise ValueError( + f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block" + f" size is {block_size}." 
+ ) + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. + """ + exp_blocked_to_pad = torch.cat( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + return band_mask + + blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + + from_mask = attention_mask.view(batch_size, 1, seq_length, 1) + to_mask = attention_mask.view(batch_size, 1, 1, seq_length) + + return blocked_encoder_mask, band_mask, from_mask, to_mask + + def _pad_to_block_size(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + """A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention.""" + # padding + block_size = self.config.block_size + batch_size, seq_len = hidden_states.shape[:2] + + padding_len = (block_size - seq_len % block_size) % block_size + if padding_len > 0: + logger.warning_once( + f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " + f"`config.block_size`: {block_size}" + ) + pad_id = self.config.pad_token_id + device = hidden_states.device + input_ids_padding = torch.ones((batch_size, padding_len), dtype=torch.long, device=device) * pad_id + inputs_embeds_padding = self.embed_tokens(input_ids_padding) + hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2) + + attention_mask = nn.functional.pad( + attention_mask, (0, padding_len), value=0 + ) # no attention on the padding tokens + + return padding_len, hidden_states, attention_mask + + +class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`BigBirdPegasusDecoderLayer`] + + Args: + config: BigBirdPegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BigBirdPegasusScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList( + [BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+                )
+                use_cache = False
+
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # initialize `past_key_values`
+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None
+                else DynamicCache(config=self.config)
+            )
+        if use_cache and isinstance(past_key_values, tuple):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        batch_size, seq_length = inputs_embeds.size()[:-1]
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+            )
+
+        if attention_mask is None and not is_torchdynamo_compiling():
+            # required mask seq length can be calculated via length of past cache
+            mask_seq_length = past_key_values_length + seq_length
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+        self_attn_cache = (
+            past_key_values.self_attention_cache
+            if isinstance(past_key_values, EncoderDecoderCache)
+            else past_key_values
+        )
+
+        attention_mask = self._update_causal_mask(
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            self_attn_cache,
+        )
+        encoder_attention_mask = self._update_cross_attn_mask(
+            encoder_hidden_states,
+            encoder_attention_mask,
+            input_shape,
+            inputs_embeds,
+        )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
+        positions = positions.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + positions
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
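When `cache_position` is not supplied, the code above derives it from the cache length, so the indices are the absolute positions of the tokens currently being processed. For example:

```python
import torch

past_key_values_length, seq_length = 6, 1  # decoding step 7: one new token, six cached
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length)
print(cache_position)  # tensor([6])
```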
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@auto_docstring
+class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: BigBirdPegasusConfig):
+        super().__init__(config)
+
+        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = BigBirdPegasusScaledWordEmbedding(
+            vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
+        )
+
+        self.encoder = BigBirdPegasusEncoder(config, self.shared)
+        self.decoder = BigBirdPegasusDecoder(config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
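The loop above implements LayerDrop: during training, each decoder layer is skipped independently with probability `config.decoder_layerdrop`. A quick sanity check of that skip rate in isolation:

```python
import torch

layerdrop = 0.1
skipped = sum(int(torch.rand([]) < layerdrop) for _ in range(10_000))
print(skipped)  # roughly 1000, i.e. each layer call is dropped ~10% of the time
```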
cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the `input_ids` to the right, following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_bigbird_pegasus._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in + [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + # different to other models, BigBirdPegasus automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + 
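The `shift_tokens_right` helper used above is defined outside this hunk; a minimal re-implementation of the usual BART-family behavior it refers to (prepend `decoder_start_token_id`, drop the last token, and replace any `-100` loss-masking values with the pad id):

```python
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 only marks ignored loss positions
    return shifted


labels = torch.tensor([[10, 11, 12, 2]])
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2))
# tensor([[ 2, 10, 11, 12]])
```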
last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BigBirdPegasus Model with a language modeling head. Can be used for summarization. + """ +) +# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS +class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BigBirdPegasusConfig): + super().__init__(config) + self.model = BigBirdPegasusModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self.model._tie_weights() + self._tie_or_clone_weights(self.lm_head, self.model.shared) + + @auto_docstring + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for translation and summarization training. 
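`_resize_final_logits_bias` above truncates the bias buffer when the vocabulary shrinks and zero-pads it when the vocabulary grows, so it always matches the embedding size. The core tensor manipulation, in isolation:

```python
import torch

final_logits_bias = torch.zeros(1, 4)  # one bias entry per vocabulary token
new_num_tokens = 6
if new_num_tokens <= final_logits_bias.shape[-1]:
    new_bias = final_logits_bias[:, :new_num_tokens]
else:
    extra_bias = torch.zeros(1, new_num_tokens - final_logits_bias.shape[-1])
    new_bias = torch.cat([final_logits_bias, extra_bias], dim=1)
print(new_bias.shape)  # torch.Size([1, 6])
```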
By default, the model will create this tensor by + shifting the `input_ids` to the right, following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_bigbird_pegasus._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in + [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example summarization: + + ```python + >>> from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration + + >>> model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv") + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "The dominant sequence transduction models are based on complex recurrent or convolutional neural " + ... "networks in an encoder-decoder configuration. The best performing models also connect the encoder " + ... "and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, " + ... "based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. " + ... "Experiments on two machine translation tasks show these models to be superior in quality " + ... "while being more parallelizable and requiring significantly less time to train." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=4096, return_tensors="pt", truncation=True) + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=15) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'dominant sequence models are based on recurrent or convolutional neural networks .' 
+ ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + +@auto_docstring( + custom_intro=""" + BigBirdPegasus model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. + for GLUE tasks. 
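As the forward pass above shows, passing `labels` both derives `decoder_input_ids` (by right-shifting the labels) and returns a cross-entropy loss over the vocabulary. A hedged sketch of a single supervised step (checkpoint name as in the docstring example):

```python
from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration

tok = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv")

batch = tok(["a long source document"], text_target=["a short summary"], return_tensors="pt")
out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
print(out.loss)          # scalar cross-entropy loss, ready for .backward()
print(out.logits.shape)  # (batch_size, target_length, vocab_size)
```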
+ """ +) +class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BigBirdPegasusConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BigBirdPegasusModel(config) + self.classification_head = BigBirdPegasusClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the `input_ids` to the right, following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_bigbird_pegasus._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in + [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring +class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BigBirdPegasusModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: 
Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the `input_ids` to the right, following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_bigbird_pegasus._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in + [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + 
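The span head above is one linear layer producing two logits per token, split into start and end scores; gold positions outside the sequence are clamped to `ignored_index` so they drop out of the loss. The shape bookkeeping, in isolation:

```python
import torch
from torch import nn

sequence_output = torch.randn(2, 7, 16)  # (batch_size, seq_len, hidden_size)
qa_outputs = nn.Linear(16, 2)
start_logits, end_logits = qa_outputs(sequence_output).split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
print(start_logits.shape, end_logits.shape)  # torch.Size([2, 7]) torch.Size([2, 7])
```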
outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +# Copied from transformers.models.pegasus.modeling_pegasus.PegasusDecoderWrapper with Pegasus->BigBirdPegasus +class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BigBirdPegasusDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = BigBirdPegasusDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, BigBirdPegasusForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv") + >>> model = BigBirdPegasusForCausalLM.from_pretrained( + ... "google/bigbird-pegasus-large-arxiv", add_cross_attention=False + ... ) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "BigBirdPegasusForCausalLM", + "BigBirdPegasusForConditionalGeneration", + "BigBirdPegasusForQuestionAnswering", + "BigBirdPegasusForSequenceClassification", + "BigBirdPegasusModel", + "BigBirdPegasusPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/blip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..952de2f855a71591e950b31751af1fe6cfbae20a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_blip import * + from .image_processing_blip import * + from .image_processing_blip_fast import * + from .modeling_blip import * + from .modeling_blip_text import * + from .modeling_tf_blip import * + from .modeling_tf_blip_text import * + from .processing_blip import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/configuration_blip.py b/venv/lib/python3.13/site-packages/transformers/models/blip/configuration_blip.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0a5590c3f769896675e801f76ecfa9234eb8be --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/configuration_blip.py @@ -0,0 +1,317 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Blip model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BlipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BlipTextModel`]. It is used to instantiate a BLIP + text model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the `BlipText` used by the [base + architectures](https://huggingface.co/Salesforce/blip-vqa-base). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30524): + Vocabulary size of the `Blip` text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`BlipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + encoder_hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers from the vision model. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). 
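The `_LazyModule` registration above defers all heavy imports until an attribute is first accessed; the `TYPE_CHECKING` branch exists only so static analyzers still see the concrete symbols. For example:

```python
import transformers.models.blip as blip

# Nothing model-related has been imported yet at this point.
cfg_cls = blip.BlipTextConfig  # first attribute access triggers the real submodule import
print(cfg_cls.model_type)      # "blip_text_model"
```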
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        bos_token_id (`int`, *optional*, defaults to 30522):
+            The id of the `beginning-of-sequence` token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the `end-of-sequence` token.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the `padding` token.
+        sep_token_id (`int`, *optional*, defaults to 102):
+            The id of the `separator` token.
+        is_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as a decoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        label_smoothing (`float`, *optional*, defaults to 0.0):
+            A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no
+            smoothing. The targets become a mixture of the original ground truth and a uniform distribution as
+            described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567).
+
+    Example:
+
+    ```python
+    >>> from transformers import BlipTextConfig, BlipTextModel
+
+    >>> # Initializing a BlipTextConfig with Salesforce/blip-vqa-base style configuration
+    >>> configuration = BlipTextConfig()
+
+    >>> # Initializing a BlipTextModel (with random weights) from the Salesforce/blip-vqa-base style configuration
+    >>> model = BlipTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "blip_text_model"
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=30524,
+        hidden_size=768,
+        encoder_hidden_size=768,
+        intermediate_size=3072,
+        projection_dim=768,
+        num_hidden_layers=12,
+        num_attention_heads=8,
+        max_position_embeddings=512,
+        hidden_act="gelu",
+        layer_norm_eps=1e-12,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        bos_token_id=30522,
+        eos_token_id=2,
+        pad_token_id=0,
+        sep_token_id=102,
+        is_decoder=True,
+        use_cache=True,
+        label_smoothing=0.0,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            sep_token_id=sep_token_id,
+            **kwargs,
+        )
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.encoder_hidden_size = encoder_hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.is_decoder = is_decoder
+        self.use_cache = use_cache
+        self.label_smoothing = label_smoothing
+
+
+class BlipVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BlipVisionModel`]. It is used to instantiate a
+    BLIP vision model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Blip-base
+    [Salesforce/blip-vqa-base](https://huggingface.co/Salesforce/blip-vqa-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        image_size (`int`, *optional*, defaults to 384):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 1e-10):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import BlipVisionConfig, BlipVisionModel
+
+    >>> # Initializing a BlipVisionConfig with Salesforce/blip-vqa-base style configuration
+    >>> configuration = BlipVisionConfig()
+
+    >>> # Initializing a BlipVisionModel (with random weights) from the Salesforce/blip-vqa-base style configuration
+    >>> model = BlipVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "blip_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        projection_dim=512,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        image_size=384,
+        patch_size=16,
+        hidden_act="gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=1e-10,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.projection_dim = projection_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+
+class BlipConfig(PretrainedConfig):
+    r"""
+    [`BlipConfig`] is the configuration class to store the configuration of a [`BlipModel`]. It is used to instantiate
+    a BLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the BLIP-base
+    [Salesforce/blip-vqa-base](https://huggingface.co/Salesforce/blip-vqa-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`BlipTextConfig`].
+        vision_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`BlipVisionConfig`].
+        projection_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of text and vision projection layers.
+        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The initial value of the *logit_scale* parameter. Default is used as per the original BLIP implementation.
+        image_text_hidden_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the hidden state of the image-text fusion layer.
+        label_smoothing (`float`, *optional*, defaults to 0.0):
+            A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no
+            smoothing. The targets become a mixture of the original ground truth and a uniform distribution as
+            described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567).
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import BlipConfig, BlipModel
+
+    >>> # Initializing a BlipConfig with Salesforce/blip-vqa-base style configuration
+    >>> configuration = BlipConfig()
+
+    >>> # Initializing a BlipModel (with random weights) from the Salesforce/blip-vqa-base style configuration
+    >>> model = BlipModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a BlipConfig from a BlipTextConfig and a BlipVisionConfig
+
+    >>> # Initializing a BLIPText and BLIPVision configuration
+    >>> config_text = BlipTextConfig()
+    >>> config_vision = BlipVisionConfig()
+
+    >>> config = BlipConfig.from_text_vision_configs(config_text, config_vision)
+    ```"""
+
+    model_type = "blip"
+    sub_configs = {"text_config": BlipTextConfig, "vision_config": BlipVisionConfig}
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        projection_dim=512,
+        logit_scale_init_value=2.6592,
+        image_text_hidden_size=256,
+        label_smoothing=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if text_config is None:
+            text_config = {}
+            logger.info("`text_config` is `None`. Initializing the `BlipTextConfig` with default values.")
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info("`vision_config` is `None`.
Initializing the `BlipVisionConfig` with default values.") + + self.text_config = BlipTextConfig(**text_config) + self.vision_config = BlipVisionConfig(**vision_config) + + self.text_config.encoder_hidden_size = self.vision_config.hidden_size + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + self.image_text_hidden_size = image_text_hidden_size + self.label_smoothing = label_smoothing + + +__all__ = ["BlipConfig", "BlipTextConfig", "BlipVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip.py b/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip.py new file mode 100644 index 0000000000000000000000000000000000000000..78a152374fd0089846f4431c8b408da4c5c9ff54 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. 
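Note the coupling enforced in `BlipConfig.__init__` above: after both sub-configs are built, the text config's `encoder_hidden_size` is overwritten with the vision config's `hidden_size`, because the text model cross-attends over vision features. For example:

```python
from transformers import BlipConfig

config = BlipConfig(vision_config={"hidden_size": 1024})
print(config.text_config.encoder_hidden_size)  # 1024, taken from the vision sub-config
```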
Can be
+            overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 384, "width": 384}
+        size = get_size_dict(size, default_to_square=True)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+        self.do_convert_rgb = do_convert_rgb
+
+    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to `(size["height"], size["width"])`.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+        output_size = (size["height"], size["width"])
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        do_convert_rgb: Optional[bool] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Controls the size of the image after `resize`. The image is resized to
+                `(size["height"], size["width"])`.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to normalize the image by if `do_normalize` is set to `True`.
+            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
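+            # Note: the channel format is inferred from the first image only. If the images in a batch use different
+            # layouts, pass `input_data_format` explicitly rather than relying on this inference.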
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + + return encoded_outputs + + +__all__ = ["BlipImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip_fast.py b/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1d14f139049162590bcaabf923a08b194aaae5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/image_processing_blip_fast.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for BLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import auto_docstring + + +@auto_docstring +class BlipImageProcessorFast(BaseImageProcessorFast): + # To be checked against the slow image processor + # None values left after checking can be removed + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["BlipImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip.py b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip.py new file mode 100644 index 0000000000000000000000000000000000000000..8a58add38af3a68befc369376f4023b461c512b4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip.py @@ -0,0 +1,1308 @@ +# coding=utf-8 +# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BLIP model.""" + +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +from torch import nn +from torch.nn.functional import normalize + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.generic import check_model_inputs +from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig +from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip +def blip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +@auto_docstring( + custom_intro=""" + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + """ +) +class BlipForConditionalGenerationModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Language modeling loss from the text decoder. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head of the text decoder model. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*): + The image embeddings obtained after applying the Vision Transformer model to the input image. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + @property + def decoder_logits(self): + warnings.warn( + "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the `logits` attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.logits + + +@dataclass +@auto_docstring( + custom_intro=""" + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + """ +) +class BlipTextVisionModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + """ +) +class BlipImageTextMatchingModelOutput(ModelOutput): + r""" + itm_score (`torch.FloatTensor`): + The image-text similarity scores. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the vision of the vision-only branch of the model. + question_embeds (`torch.FloatTensor`): + The question embeddings obtained by the text projection layer. + """ + + itm_score: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_pooler_output: Optional[torch.FloatTensor] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + question_embeds: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring +class BlipOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. 
This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`BlipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`BlipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class BlipVisionEmbeddings(nn.Module): + def __init__(self, config: BlipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
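+        Only the patch position embeddings are interpolated (bicubically) to the new grid size; the class token's
+        position embedding is left unchanged.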
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip +class BlipTextEmbeddings(nn.Module): + def __init__(self, config: BlipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + 
position_embeddings + + return embeddings + + +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = nn.Dropout(config.attention_dropout) + + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim) + + self.projection = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = ( + self.qkv(hidden_states) + .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.projection(context_layer) + + return output, attention_probs + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip +class BlipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class BlipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BlipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = BlipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = BlipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + **kwargs, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + return hidden_states + + +@auto_docstring +class BlipPreTrainedModel(PreTrainedModel): + config: BlipConfig + base_model_prefix = "blip" + supports_gradient_checkpointing = True + _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"] + _skip_keys_device_placement = ["past_key_values"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_range + if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)): + module.weight.data.normal_(mean=0.0, std=factor) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + if isinstance(module, BlipVisionEmbeddings): + if hasattr(self.config, "vision_config"): + factor = self.config.vision_config.initializer_range + nn.init.trunc_normal_( + module.position_embedding, + mean=0.0, + std=factor, + ) + + nn.init.trunc_normal_( + module.class_embedding, + mean=0.0, + std=factor, + ) + + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`BlipEncoderLayer`]. + + Args: + config (`BlipConfig`): + The corresponding vision configuration for the `BlipEncoder`. 
+ """ + + def __init__(self, config: BlipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class BlipVisionModel(BlipPreTrainedModel): + main_input_name = "pixel_values" + config: BlipVisionConfig + _can_record_outputs = { + "hidden_states": BlipEncoderLayer, + "attentions": BlipAttention, + } + + def __init__(self, config: BlipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = BlipVisionEmbeddings(config) + self.encoder = BlipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPooling]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + def get_input_embeddings(self): + return self.embeddings + + +@auto_docstring( + custom_intro=""" + This model is going to be deprecated in future versions. Please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase. + """ +) +class BlipModel(BlipPreTrainedModel): + config: BlipConfig + + def __init__(self, config: BlipConfig): + super().__init__(config) + + if not isinstance(config.text_config, BlipTextConfig): + raise TypeError( + "config.text_config is expected to be of type BlipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, BlipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type BlipVisionConfig but is of type" + f" {type(config.vision_config)}." 
+ ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = BlipTextModel(text_config) + self.vision_model = BlipVisionModel(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + logger.warning( + "`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase." + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + @auto_docstring + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`BlipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @auto_docstring + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`BlipVisionModel`]. 
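+            Note that the returned features are not L2-normalized; [`BlipModel.forward`] normalizes the projected
+            embeddings before computing the similarity logits.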
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @auto_docstring + def get_multimodal_features( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The multimodal embeddings + obtained by applying the image embeddings to the text encoder using the cross-attention mechanism. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of a cat", "a photo of a dog"] + >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt") + + >>> multimodal_features = model.get_multimodal_features(**inputs) + ```""" + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + image_embeds = vision_outputs[0] + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + ) + + pooled_output = text_outputs[1] # pooled_output + multimodal_features = self.text_projection(pooled_output) + + return multimodal_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BlipOutput]: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... 
text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + image_embeds = vision_outputs.pooler_output + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp().to(device=text_embeds.device) + image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype) + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = blip_loss(logits_per_text) + + return BlipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@auto_docstring( + custom_intro=""" + BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass + `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise, + the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption + from the text input. If no text input is provided, the decoder will start with the [BOS] token only. 
+ """ +) +class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): + config: BlipConfig + _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] + main_input_name = "pixel_values" + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_decoder = BlipTextLMHeadModel(config.text_config) + + self.decoder_input_ids = config.text_config.bos_token_id + self.decoder_pad_token_id = config.text_config.pad_token_id + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_decoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_decoder.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BlipForConditionalGenerationModelOutput]: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A picture of" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + + >>> outputs = model(**inputs) + ```""" + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + image_embeds = vision_outputs.last_hidden_state + + outputs = self.text_decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + labels=labels, + reduction="mean", + **kwargs, + ) + + return BlipForConditionalGenerationModelOutput( + loss=outputs.loss, + logits=outputs.logits, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **generate_kwargs, + ) -> torch.LongTensor: + r""" + Overrides *generate* function to be able to use the model as a conditional generator + + Parameters: + pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*: + Input image to be processed + input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + The sequence used as a prompt for the generation. + attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForConditionalGeneration + + >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + two cats sleeping on a couch + ``` + """ + + batch_size = pixel_values.shape[0] + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if isinstance(input_ids, list): + input_ids = torch.LongTensor(input_ids) + elif input_ids is None: + input_ids = ( + torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]]) + .repeat(batch_size, 1) + .to(image_embeds.device) + ) + + input_ids[:, 0] = self.config.text_config.bos_token_id + attention_mask = attention_mask[:, :-1] if attention_mask is not None else None + + outputs = self.text_decoder.generate( + input_ids=input_ids[:, :-1], + eos_token_id=self.config.text_config.sep_token_id, + pad_token_id=self.config.text_config.pad_token_id, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **generate_kwargs, + ) + + return outputs + + +@auto_docstring( + custom_intro=""" + BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text + decoder. The vision encoder will encode the input image, the text encoder will encode the input question together + with the encoding of the image, and the text decoder will output the answer to the question. + """ +) +class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin): + config: BlipConfig + _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False) + + self.text_decoder = BlipTextLMHeadModel(config.text_config) + + self.decoder_pad_token_id = config.text_config.pad_token_id + self.decoder_start_token_id = config.text_config.bos_token_id + + # Initialize weights and apply final processing + self.post_init() + + def set_input_embeddings(self, value): + self.text_encoder.set_input_embeddings(value) + + def get_input_embeddings(self): + # This will return shared embeddings if they are shared else specific to encoder. 
+ return self.text_encoder.get_input_embeddings() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BlipTextVisionModelOutput]: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForQuestionAnswering + + >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # training + >>> text = "How many cats are in the picture?" + >>> label = "2" + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> labels = processor(text=label, return_tensors="pt").input_ids + + >>> inputs["labels"] = labels + >>> outputs = model(**inputs) + >>> loss = outputs.loss + >>> loss.backward() + + >>> # inference + >>> text = "How many cats are in the picture?" + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + 2 + ```""" + if labels is None and decoder_input_ids is None: + raise ValueError( + "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with" + " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you" + " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`" + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + image_embeds = vision_outputs.last_hidden_state + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long) + + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **kwargs, + ) + + if labels is not None and decoder_input_ids is None: + # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153 + decoder_input_ids = labels + + question_embeds = question_embeds[0] + + answer_output = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=question_embeds, + encoder_attention_mask=attention_mask, + labels=labels, + reduction="mean", + **kwargs, + ) + + if labels is not None: + decoder_loss = answer_output.loss.mean() + else: + decoder_loss = None + + return BlipTextVisionModelOutput( + loss=decoder_loss, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **generate_kwargs, + ) -> torch.LongTensor: + r""" + Overrides *generate* function to be able to use the model as a conditional 
generator + + Parameters: + input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*): + The sequence used as a prompt for the generation. + pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*: + Input image to be processed + attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for + tokens that are NOT MASKED, `0` for MASKED tokens. + **generate_kwargs: + Additional arguments passed to the *generate* function of the decoder + + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForQuestionAnswering + + >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "How many cats are in the picture?" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + 2 + ``` + """ + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if isinstance(input_ids, list): + input_ids = torch.LongTensor(input_ids) + + question_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=False, + ) + + question_embeds = question_outputs[0] + + question_attention_mask = torch.ones( + question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device + ) + + bos_ids = torch.full( + (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device + ) + + outputs = self.text_decoder.generate( + input_ids=bos_ids, + eos_token_id=self.config.text_config.sep_token_id, + pad_token_id=self.config.text_config.pad_token_id, + encoder_hidden_states=question_embeds, + encoder_attention_mask=question_attention_mask, + **generate_kwargs, + ) + + return outputs + + +@auto_docstring( + custom_intro=""" + BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of + image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
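+    Two scoring modes are supported: with `use_itm_head=True` (the default) the image-text matching head returns
+    two-way matching logits, while with `use_itm_head=False` the score is the cosine similarity between the projected
+    image and text features.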
+ """ +) +class BlipForImageTextRetrieval(BlipPreTrainedModel): + config: BlipConfig + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False) + + # vision projection layer + self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.text_config.hidden_size, 2) + + self.decoder_pad_token_id = ( + config.text_config.pad_token_id + if not hasattr(config, "decoder_pad_token_id") + else config.decoder_pad_token_id + ) + self.decoder_start_token_id = ( + config.text_config.bos_token_id + if not hasattr(config, "decoder_start_token_id") + else config.decoder_start_token_id + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_encoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_encoder.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + use_itm_head: Optional[bool] = True, + attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BlipTextVisionModelOutput]: + r""" + use_itm_head (`bool`, *optional*, defaults to `True`): + Whether or not to use the image-text matching head. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForImageTextRetrieval + + >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model(**inputs) + ``` + """ + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + image_embeds = vision_outputs.last_hidden_state + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long) + + if use_itm_head: + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + **kwargs, + ) + question_embeds = question_embeds.last_hidden_state + + output = self.itm_head(question_embeds[:, 0, :]) + else: + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + question_embeds = question_embeds.last_hidden_state + + image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1) + + output = image_feat @ text_feat.t() + + return BlipImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) + + +__all__ = [ + "BlipModel", + "BlipPreTrainedModel", + "BlipForConditionalGeneration", + 
"BlipForQuestionAnswering", + "BlipVisionModel", + "BlipTextModel", + "BlipForImageTextRetrieval", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip_text.py b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip_text.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1f58c7533463ab9ba01f45a15ba51a294c0651 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_blip_text.py @@ -0,0 +1,968 @@ +# coding=utf-8 +# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the BSD-3-clause license (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Optional, Union + +import torch +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_blip import BlipTextConfig + + +logger = logging.get_logger(__name__) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L52 +class BlipTextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + 
past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97 +class BlipTextSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention, layer_idx=None): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.layer_idx = layer_idx + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) + self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BlipTextModel forward() function) + attention_scores = attention_scores + attention_mask.to(attention_scores.device) + + # Normalize the attention scores to probabilities. 
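+        # `attention_scores` has shape (batch_size, num_heads, query_length, key_length); masked positions were
+        # pushed to large negative values by the additive mask above, so they receive ~0 probability after the
+        # softmax below.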
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert -> BlipText +class BlipTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242 +class BlipTextAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.self = BlipTextSelfAttention(config, is_cross_attention, layer_idx=layer_idx) + self.output = BlipTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert -> BlipText +class BlipTextIntermediate(nn.Module): + def __init__(self, 
config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert -> BlipText +class BlipTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BlipTextLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BlipTextAttention(config, layer_idx=layer_num) + self.layer_num = layer_num + if self.config.is_decoder: + self.crossattention = BlipTextAttention( + config, is_cross_attention=self.config.is_decoder, layer_idx=layer_num + ) + self.intermediate = BlipTextIntermediate(config) + self.output = BlipTextOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] + + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + return (layer_output,) + outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386 +class BlipTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config 
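+        # one BlipTextLayer per hidden layer; the index is also passed down as `layer_idx` so the attention
+        # modules know which slot of the key/value cache belongs to them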
+ self.layer = nn.ModuleList([BlipTextLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache: + if isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + # The model acts as encoder decoder but is not an encoder decoder. So we cast all cache objects to + # `EncoderDecoderCache` type assuming that the incoming cache is from `self_attention` + elif isinstance(past_key_values, DynamicCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config)) + elif past_key_values is None: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and encoder_hidden_states is not None else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + output_attentions, + cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BlipText +class BlipTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->BlipText +class BlipTextPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->BlipText +class BlipTextLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BlipTextPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->BlipText +class BlipTextOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BlipTextLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548 +class BlipTextPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config: BlipTextConfig + base_model_prefix = "bert" + _no_split_modules = [] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571 +class BlipTextModel(BlipTextPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and `is_decoder` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BlipTextEmbeddings(config) + self.encoder = BlipTextEncoder(config) + self.pooler = BlipTextPooler(config) if add_pooling_layer else None + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: tuple[int], device: device, is_decoder: bool + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + is_decoder: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`Cache`, *optional*): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length)).to(device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
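+        # The self-attention mask becomes causal when `is_decoder=True`; the separate encoder attention mask
+        # prepared below only gates the cross-attention layers.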
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 +class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BlipTextModel(config, add_pooling_layer=False) + self.cls = BlipTextOnlyMLMHead(config) + self.label_smoothing = config.label_smoothing + + def get_input_embeddings(self): + return self.bert.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.bert.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + 
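+        # re-tie the exposed bias to the new decoder's bias, mirroring BlipTextLMPredictionHead._tie_weights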
self.cls.predictions.bias = new_embeddings.bias + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_logits: Optional[bool] = False, + is_decoder: Optional[bool] = True, + reduction: Optional[str] = "mean", + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is + configured as a decoder. + encoder_attention_mask (`torch.FloatTensor`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`Cache`, *optional*): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
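+
+        Example (an illustrative sketch added for this vendored copy, not part of the upstream docstring; the
+        randomly initialized config, model and token ids below are assumptions chosen only for demonstration):
+
+        >>> import torch
+        >>> from transformers.models.blip.configuration_blip import BlipTextConfig
+        >>> from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel
+
+        >>> config = BlipTextConfig(is_decoder=True)
+        >>> model = BlipTextLMHeadModel(config)  # randomly initialized, for illustration only
+        >>> input_ids = torch.tensor([[101, 2054, 2003, 1037]])  # any ids inside the vocabulary
+        >>> outputs = model(input_ids=input_ids, labels=input_ids)
+        >>> logits = outputs.logits  # (batch_size, sequence_length, vocab_size)
+        >>> loss = outputs.loss  # next-token prediction loss over the shifted sequence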
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous().to(shifted_prediction_scores.device) + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=self.label_smoothing) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + # Overwrite -- hardcoded key return (`is_decoder=True`) + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + **model_kwargs, + ) + model_inputs["is_decoder"] = True + + return model_inputs + + +__all__ = ["BlipTextModel", "BlipTextLMHeadModel", "BlipTextPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip.py b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a1f7928273c69de30b4ba1c190435d91dfd509 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip.py @@ -0,0 +1,1709 @@ +# coding=utf-8 +# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TensorFlow BLIP model.""" + +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from typing import Any + +import tensorflow as tf + +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling +from ...modeling_tf_utils import ( + TFPreTrainedModel, + get_initializer, + get_tf_activation, + keras, + keras_serializable, + shape_list, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, stable_softmax +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig +from .modeling_tf_blip_text import BLIP_TEXT_INPUTS_DOCSTRING, TFBlipTextLMHeadModel, TFBlipTextModel + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base" + + +# Copied from transformers.models.clip.modeling_tf_clip.contrastive_loss +def contrastive_loss(logits: tf.Tensor) -> tf.Tensor: + return tf.math.reduce_mean( + keras.metrics.sparse_categorical_crossentropy( + y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True + ) + ) + + +# Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->blip +def blip_loss(similarity: tf.Tensor) -> tf.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(tf.transpose(similarity)) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class TFBlipForConditionalGenerationModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + + Args: + loss (`tf.Tensor`, *optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`): + Language modeling loss from the text decoder. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head of the text decoder model. + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*): + The image embeddings obtained after applying the Vision Transformer model to the input image. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads.` + """ + + loss: tuple[tf.Tensor] | None = None + logits: tuple[tf.Tensor] | None = None + image_embeds: tf.Tensor | None = None + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] 
| None = None + + @property + def decoder_logits(self): + warnings.warn( + "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the `logits` attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.logits + + +@dataclass +class TFBlipTextVisionModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss from the text decoder. + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + image_embeds: tf.Tensor | None = None + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFBlipImageTextMatchingModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + + Args: + itm_score (`tf.Tensor`): + The image-text similarity scores. + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss from the text decoder. + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+        vision_pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*):
+            Last layer hidden-state of the vision-only branch of the model.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        question_embeds (`tf.Tensor`):
+            The question embeddings obtained by the text projection layer.
+    """
+
+    itm_score: tf.Tensor | None = None
+    loss: tf.Tensor | None = None
+    image_embeds: tf.Tensor | None = None
+    last_hidden_state: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor, ...] | None = None
+    vision_pooler_output: tf.Tensor | None = None
+    attentions: tuple[tf.Tensor, ...] | None = None
+    question_embeds: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBlipOutput(ModelOutput):
+    """
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Contrastive loss for image-text similarity.
+        logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
+            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+            similarity scores.
+        logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
+            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+            similarity scores.
+        text_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
+            The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
+        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
+            The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
+        text_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`BlipTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`BlipVisionModel`].
+ """ + + loss: tf.Tensor | None = None + logits_per_image: tf.Tensor | None = None + logits_per_text: tf.Tensor | None = None + text_embeds: tf.Tensor | None = None + image_embeds: tf.Tensor | None = None + text_model_output: TFBaseModelOutputWithPooling = None + vision_model_output: TFBaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class TFBlipVisionEmbeddings(keras.layers.Layer): + def __init__(self, config: BlipVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = keras.layers.Conv2D( + filters=self.embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + kernel_initializer=get_initializer(self.config.initializer_range), + data_format="channels_last", + name="patch_embedding", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + def build(self, input_shape=None): + self.class_embedding = self.add_weight( + shape=(1, 1, self.embed_dim), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + name="class_embedding", + ) + + self.position_embedding = self.add_weight( + shape=(1, self.num_positions, self.embed_dim), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + name="position_embedding", + ) + + if self.built: + return + self.built = True + if getattr(self, "patch_embedding", None) is not None: + with tf.name_scope(self.patch_embedding.name): + self.patch_embedding.build([None, None, None, 3]) + + def call(self, pixel_values: tf.Tensor) -> tf.Tensor: + # Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch + # likes channels-first convs. 
+ batch_size = tf.shape(pixel_values)[0] + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = tf.reshape(patch_embeds, (batch_size, self.num_patches, -1)) + + class_embeds = tf.broadcast_to(self.class_embedding, (batch_size, 1, self.embed_dim)) + embeddings = tf.concat([class_embeds, patch_embeds], axis=1) + embeddings = embeddings + self.position_embedding[:, : tf.shape(embeddings)[1], :] + return embeddings + + +# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->Blip +class TFBlipTextEmbeddings(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.embed_dim = config.hidden_size + + self.config = config + + def build(self, input_shape: tf.TensorShape = None): + with tf.name_scope("token_embedding"): + self.weight = self.add_weight( + shape=(self.config.vocab_size, self.embed_dim), + initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), + trainable=True, + name="weight", + ) + + with tf.name_scope("position_embedding"): + self.position_embedding = self.add_weight( + shape=(self.config.max_position_embeddings, self.embed_dim), + initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), + trainable=True, + name="embeddings", + ) + + super().build(input_shape) + + def call( + self, + input_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + + position_embeds = tf.gather(params=self.position_embedding, indices=position_ids) + position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) + final_embeddings = inputs_embeds + position_embeds + + return final_embeddings + + +class TFBlipAttention(keras.layers.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = keras.layers.Dropout(config.attention_dropout, name="dropout") + + self.qkv = keras.layers.Dense( + 3 * self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="qkv" + ) + + self.projection = keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="projection" + ) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = False, + training: bool | None = None, + ) -> tuple[tf.Tensor, tf.Tensor | None, tuple[tf.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + mixed_qkv = self.qkv(hidden_states) + mixed_qkv = tf.reshape(mixed_qkv, (bsz, tgt_len, 3, self.num_heads, self.head_dim)) + mixed_qkv = tf.transpose(mixed_qkv, perm=(2, 0, 3, 1, 4)) + + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = query_states @ tf.transpose(key_states, (0, 1, 3, 2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = tf.transpose(attention_probs @ value_states, perm=(0, 2, 1, 3)) + + new_context_layer_shape = shape_list(context_layer)[:-2] + [self.embed_dim] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + output = self.projection(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "qkv", None) is not None: + with tf.name_scope(self.qkv.name): + self.qkv.build([None, None, self.embed_dim]) + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, self.embed_dim]) + + +class TFBlipMLP(keras.layers.Layer): + def __init__(self, config: BlipConfig, **kwargs): + super().__init__(**kwargs) + + self.activation_fn = get_tf_activation(config.hidden_act) + + in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) + fc_std = (2 * config.hidden_size) ** -0.5 + + self.fc1 = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(fc_std), name="fc1" + ) + self.fc2 = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2" + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.fc1(inputs=hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(inputs=hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, 
self.config.hidden_size]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.intermediate_size]) + + +class TFBlipEncoderLayer(keras.layers.Layer): + def __init__(self, config: BlipConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.hidden_size + self.self_attn = TFBlipAttention(config, name="self_attn") + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.mlp = TFBlipMLP(config, name="mlp") + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + output_attentions: bool | None = False, + training: bool | None = None, + ) -> tuple[tf.Tensor]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, self.embed_dim]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, self.embed_dim]) + + +class TFBlipPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BlipConfig + base_model_prefix = "blip" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + +BLIP_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`BlipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +BLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@keras_serializable +class TFBlipEncoder(keras.layers.Layer): + config_class = BlipConfig + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`BlipEncoderLayer`]. + + Args: + config (`BlipConfig`): + The corresponding vision configuration for the `BlipEncoder`. 
+ """ + + def __init__(self, config: BlipConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layers = [TFBlipEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)] + + @unpack_inputs + def call( + self, + inputs_embeds, + attention_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBaseModelOutput: + r""" + Args: + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + training=training, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFBlipVisionModel(TFBlipPreTrainedModel): + main_input_name = "pixel_values" + config_class = BlipVisionConfig + + def __init__(self, config: BlipVisionConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + self.embeddings = TFBlipVisionEmbeddings(config, name="embeddings") + self.encoder = TFBlipEncoder(config, name="encoder") + self.post_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") + self.embed_dim = config.hidden_size + + def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + hs = tf.convert_to_tensor(output.hidden_states) if 
self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPooling( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + hidden_states=hs, + attentions=attns, + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=BlipVisionConfig) + def call( + self, + pixel_values: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + # TF gets confused if we call the layer with inputs of different ranks, so insert a singleton dimension + pooled_output = self.post_layernorm(tf.expand_dims(pooled_output, 1)) + pooled_output = tf.squeeze(pooled_output, 1) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "post_layernorm", None) is not None: + with tf.name_scope(self.post_layernorm.name): + self.post_layernorm.build([None, None, self.embed_dim]) + + +class TFBlipMainLayer(keras.layers.Layer): + config_class = BlipConfig + + def __init__(self, config: BlipConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(config.text_config, BlipTextConfig): + raise TypeError( + "config.text_config is expected to be of type BlipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, BlipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type BlipVisionConfig but is of type" + f" {type(config.vision_config)}." 
+ ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = TFBlipTextModel(text_config, name="text_model") + self.vision_model = TFBlipVisionModel(vision_config, name="vision_model") + + self.visual_projection = keras.layers.Dense( + self.projection_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="visual_projection", + ) + self.text_projection = keras.layers.Dense( + self.projection_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="text_projection", + ) + + self.config = config + + def build(self, input_shape=None): + self.logit_scale = self.add_weight( + name="logit_scale", + shape=[], + initializer=keras.initializers.Constant(self.config.logit_scale_init_value), + trainable=True, + ) + + if self.built: + return + self.built = True + if getattr(self, "text_model", None) is not None: + with tf.name_scope(self.text_model.name): + self.text_model.build(None) + if getattr(self, "vision_model", None) is not None: + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + if getattr(self, "visual_projection", None) is not None: + with tf.name_scope(self.visual_projection.name): + self.visual_projection.build([None, None, self.vision_embed_dim]) + if getattr(self, "text_projection", None) is not None: + with tf.name_scope(self.text_projection.name): + self.text_projection.build([None, None, self.text_embed_dim]) + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + pixel_values: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + return_loss: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBlipOutput: + # Use BLIP model's config for some fields (if specified) instead of those of vision & text components. 
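+        # The steps below mirror CLIP-style contrastive scoring: image and text are encoded separately,
+        # their pooled outputs are projected to `projection_dim` and L2-normalized, and the scaled cosine
+        # similarity gives `logits_per_text` (with `logits_per_image` as its transpose).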
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / tf.norm(image_embeds, ord=2, axis=-1, keepdims=True) + text_embeds = text_embeds / tf.norm(text_embeds, ord=2, axis=-1, keepdims=True) + + # cosine similarity as logits + logit_scale = tf.exp(self.logit_scale) + logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale + logits_per_image = tf.transpose(logits_per_text) + + loss = None + if return_loss: + loss = blip_loss(logits_per_text) + loss = tf.reshape(loss, (1,)) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return TFBlipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class TFBlipModel(TFBlipPreTrainedModel): + config_class = BlipConfig + _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] + main_input_name = "input_ids" + + def __init__(self, config: BlipConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.blip = TFBlipMainLayer(config, name="blip") + + def serving_output(self, output: TFBlipOutput) -> TFBlipOutput: + return TFBlipOutput( + logits_per_image=output.logits_per_image, + logits_per_text=output.logits_per_text, + text_embeds=output.text_embeds, + image_embeds=output.image_embeds, + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBlipOutput, config_class=BlipConfig) + def call( + self, + input_ids: tf.Tensor | None = None, + pixel_values: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + return_loss: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBlipOutput: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipModel + + >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( 
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = tf.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities + ```""" + outputs = self.blip( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + return_loss=return_loss, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + return_dict: bool | None = None, + ) -> tf.Tensor: + r""" + Returns: + text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying + the projection layer to the pooled output of [`TFBlipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoProcessor, TFBlipModel + + >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") + >>> text_features = model.get_text_features(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.blip.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.blip.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: tf.Tensor | None = None, + return_dict: bool | None = None, + ) -> tf.Tensor: + r""" + Returns: + image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying + the projection layer to the pooled output of [`TFBlipVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipModel + + >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="tf") + + >>> image_features = model.get_image_features(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.blip.visual_projection(pooled_output) + + return image_features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "blip", None) is not None: + with tf.name_scope(self.blip.name): + self.blip.build(None) + + +@add_start_docstrings( + """ + BLIP Model for image captioning. 
The model consists of a vision encoder and a text decoder. One can optionally pass + `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise, + the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption + from the text input. If no text input is provided, the decoder will start with the [BOS] token only. + """, + BLIP_START_DOCSTRING, +) +class TFBlipForConditionalGeneration(TFBlipPreTrainedModel): + config_class = BlipConfig + _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] + main_input_name = "pixel_values" + + def __init__(self, config: BlipConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + + self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") + + self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder") + + self.decoder_input_ids = config.text_config.bos_token_id + self.decoder_pad_token_id = config.text_config.pad_token_id + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.vision_model.embeddings.patch_embedding + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBlipForConditionalGenerationModelOutput, config_class=BlipConfig) + def call( + self, + pixel_values: tf.Tensor, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + labels: tf.Tensor | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBlipForConditionalGenerationModelOutput: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + >>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A picture of" + + >>> inputs = processor(images=image, text=text, return_tensors="tf") + + >>> outputs = model(**inputs) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + image_embeds = vision_outputs[0] + + outputs = self.text_decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + labels=labels, + return_dict=False, + training=training, + ) + + if not return_dict: + outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + if labels is not None: + loss = outputs[0] + logits = outputs[1] + else: + loss = None + logits = outputs[0] + + if loss is not None and loss.shape.rank == 0: + loss = tf.reshape(loss, (1,)) + + return TFBlipForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + def 
generate( + self, + pixel_values: tf.Tensor, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + **generate_kwargs, + ) -> tf.Tensor: + r""" + Overrides *generate* function to be able to use the model as a conditional generator + + Parameters: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`: + Input image to be processed + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + The sequence used as a prompt for the generation. + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipForConditionalGeneration + + >>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="tf") + + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + two cats sleeping on a couch + ``` + """ + + batch_size = pixel_values.shape[0] + vision_outputs = self.vision_model(pixel_values=pixel_values) + + image_embeds = vision_outputs[0] + + image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32) + + if isinstance(input_ids, list): + input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32) + elif input_ids is None: + input_ids = tf.convert_to_tensor( + [[self.decoder_input_ids, self.config.text_config.eos_token_id]], dtype=tf.int32 + ) + + input_ids = tf.tile(input_ids, (batch_size, 1)) + + # PyTorch: input_ids[:, 0] = self.config.text_config.bos_token_id + input_ids = tf.concat( + [tf.ones((batch_size, 1), dtype=tf.int32) * self.config.text_config.bos_token_id, input_ids[:, 1:]], axis=1 + ) + attention_mask = attention_mask[:, :-1] if attention_mask is not None else None + + outputs = self.text_decoder.generate( + input_ids=input_ids[:, :-1], + eos_token_id=self.config.text_config.sep_token_id, + pad_token_id=self.config.text_config.pad_token_id, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **generate_kwargs, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vision_model", None) is not None: + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + if getattr(self, "text_decoder", None) is not None: + with tf.name_scope(self.text_decoder.name): + self.text_decoder.build(None) + + +@add_start_docstrings( + """ + BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text + decoder. The vision encoder will encode the input image, the text encoder will encode the input question together + with the encoding of the image, and the text decoder will output the answer to the question. 
+ """, + BLIP_START_DOCSTRING, +) +class TFBlipForQuestionAnswering(TFBlipPreTrainedModel): + config_class = BlipConfig + _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] + + def __init__(self, config: BlipConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + + self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") + + self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False) + + self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder") + + self.decoder_pad_token_id = config.text_config.pad_token_id + self.decoder_start_token_id = config.text_config.bos_token_id + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.vision_model.embeddings.patch_embedding + + # Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right + def _shift_right(self, input_ids): + decoder_start_token_id = self.decoder_start_token_id + pad_token_id = self.decoder_pad_token_id + + if decoder_start_token_id is None or pad_token_id is None: + raise ValueError("decoder_start_token_id and pad_token_id must be defined!") + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)) + + return shifted_input_ids + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig) + def call( + self, + input_ids: tf.Tensor, + pixel_values: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + labels: tf.Tensor | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBlipTextVisionModelOutput: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipForQuestionAnswering + + >>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # training + >>> text = "How many cats are in the picture?" + >>> label = "2" + >>> inputs = processor(images=image, text=text, return_tensors="tf") + >>> labels = processor(text=label, return_tensors="tf").input_ids + + >>> inputs["labels"] = labels + >>> outputs = model(**inputs) + >>> loss = outputs.loss + + >>> # inference + >>> text = "How many cats are in the picture?" 
+ >>> inputs = processor(images=image, text=text, return_tensors="tf") + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + 2 + ```""" + if labels is None and decoder_input_ids is None: + raise ValueError( + "Either `decoder_input_ids` or `labels` should be passed when calling" + " `TFBlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you" + " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`" + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64) + + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + training=training, + ) + + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + if labels is not None and decoder_input_ids is None: + # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153 + decoder_input_ids = labels + + answer_output = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=question_embeds, + encoder_attention_mask=attention_mask, + labels=labels, + return_dict=return_dict, + training=training, + ) + + if labels is not None: + decoder_loss = tf.reduce_mean(answer_output.loss) if return_dict else tf.reduce_mean(answer_output[0]) + else: + decoder_loss = None + + if not return_dict: + outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return TFBlipTextVisionModelOutput( + loss=decoder_loss, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + def generate( + self, + input_ids: tf.Tensor, + pixel_values: tf.Tensor, + attention_mask: tf.Tensor | None = None, + **generate_kwargs, + ) -> tf.Tensor: + r""" + Overrides *generate* function to be able to use the model as a conditional generator + + Parameters: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`: + Input image to be processed + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for + tokens that are NOT MASKED, `0` for MASKED tokens. 
+            generate_kwargs (dict, *optional*):
+                Additional arguments passed to the `generate` function of the decoder
+
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, TFBlipForQuestionAnswering
+
+        >>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
+        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> text = "How many cats are in the picture?"
+
+        >>> inputs = processor(images=image, text=text, return_tensors="tf")
+
+        >>> outputs = model.generate(**inputs)
+        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
+        2
+        ```
+        """
+        vision_outputs = self.vision_model(pixel_values=pixel_values)
+
+        image_embeds = vision_outputs[0]
+
+        image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32)
+
+        if isinstance(input_ids, list):
+            input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32)
+
+        question_outputs = self.text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=False,
+        )
+
+        question_embeds = question_outputs[0]
+
+        question_attention_mask = tf.ones(shape_list(question_embeds)[:-1], dtype=tf.int32)
+
+        bos_ids = tf.fill(
+            (tf.shape(question_embeds)[0], 1), value=tf.cast(self.decoder_start_token_id, input_ids.dtype)
+        )
+
+        outputs = self.text_decoder.generate(
+            input_ids=bos_ids,
+            eos_token_id=self.config.text_config.sep_token_id,
+            pad_token_id=self.config.text_config.pad_token_id,
+            encoder_hidden_states=question_embeds,
+            encoder_attention_mask=question_attention_mask,
+            **generate_kwargs,
+        )
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "vision_model", None) is not None:
+            with tf.name_scope(self.vision_model.name):
+                self.vision_model.build(None)
+        if getattr(self, "text_encoder", None) is not None:
+            with tf.name_scope(self.text_encoder.name):
+                self.text_encoder.build(None)
+        if getattr(self, "text_decoder", None) is not None:
+            with tf.name_scope(self.text_decoder.name):
+                self.text_decoder.build(None)
+
+
+@add_start_docstrings(
+    """
+    BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
+    image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
+    the image.
+ """, + BLIP_START_DOCSTRING, +) +class TFBlipForImageTextRetrieval(TFBlipPreTrainedModel): + config_class = BlipConfig + + def __init__(self, config: BlipConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + + self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") + + self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False) + + # vision projection layer + self.vision_proj = keras.layers.Dense( + config.image_text_hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="vision_proj", + ) + + # text projection layer + self.text_proj = keras.layers.Dense( + config.image_text_hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="text_proj", + ) + + # image text matching head + self.itm_head = keras.layers.Dense( + 2, kernel_initializer=get_initializer(config.initializer_range), name="itm_head" + ) + + self.decoder_pad_token_id = ( + config.text_config.pad_token_id + if not hasattr(config, "decoder_pad_token_id") + else config.decoder_pad_token_id + ) + self.decoder_start_token_id = ( + config.text_config.bos_token_id + if not hasattr(config, "decoder_start_token_id") + else config.decoder_start_token_id + ) + self.config = config + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.vision_model.embeddings.patch_embedding + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBlipImageTextMatchingModelOutput, config_class=BlipVisionConfig) + def call( + self, + input_ids: tf.Tensor, + pixel_values: tf.Tensor | None = None, + use_itm_head: bool | None = True, + attention_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = None, + ) -> tuple | TFBlipImageTextMatchingModelOutput: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, TFBlipForImageTextRetrieval + + >>> model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="tf") + >>> outputs = model(**inputs) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + image_embeds = vision_outputs[0] + image_atts = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64) + + # Matt: In PyTorch, only one path (itm/non-itm) is taken. However, in TensorFlow this can result in + # some layers not being built! To avoid this, we always call both paths, then use an if statement to select + # which output to pass to the final output. The unnecessary nodes will be pruned from the final graph, but + # not before the layers have all been built correctly. 
+ itm_question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=return_dict, + training=training, + ) + itm_question_embeds = itm_question_embeds[0] if not return_dict else itm_question_embeds.last_hidden_state + + itm_output = self.itm_head(itm_question_embeds[:, 0, :]) + + no_itm_question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + training=training, + ) + no_itm_question_embeds = ( + no_itm_question_embeds[0] if not return_dict else no_itm_question_embeds.last_hidden_state + ) + + image_feat, _ = tf.linalg.normalize(self.vision_proj(image_embeds[:, 0, :]), ord=2, axis=-1) + text_feat, _ = tf.linalg.normalize(self.text_proj(no_itm_question_embeds[:, 0, :]), ord=2, axis=-1) + + no_itm_output = tf.matmul(image_feat, text_feat, transpose_b=True) + + if use_itm_head: + output = itm_output + question_embeds = itm_question_embeds + else: + output = no_itm_output + question_embeds = no_itm_question_embeds + + if not return_dict: + outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) + return tuple(output for output in outputs if output is not None) + + return TFBlipImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vision_model", None) is not None: + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + if getattr(self, "text_encoder", None) is not None: + with tf.name_scope(self.text_encoder.name): + self.text_encoder.build(None) + if getattr(self, "vision_proj", None) is not None: + with tf.name_scope(self.vision_proj.name): + self.vision_proj.build([None, None, self.config.vision_config.hidden_size]) + if getattr(self, "text_proj", None) is not None: + with tf.name_scope(self.text_proj.name): + self.text_proj.build([None, None, self.config.text_config.hidden_size]) + if getattr(self, "itm_head", None) is not None: + with tf.name_scope(self.itm_head.name): + self.itm_head.build([None, None, self.config.text_config.hidden_size]) + + +__all__ = [ + "TFBlipModel", + "TFBlipPreTrainedModel", + "TFBlipForConditionalGeneration", + "TFBlipForQuestionAnswering", + "TFBlipVisionModel", + "TFBlipTextModel", + "TFBlipForImageTextRetrieval", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip_text.py b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip_text.py new file mode 100644 index 0000000000000000000000000000000000000000..7dae1126e03b12b39b17b8fdc00b6050f9681fcd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip/modeling_tf_blip_text.py @@ -0,0 +1,1122 @@ +# coding=utf-8 +# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the BSD-3-clause license (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import math + +import tensorflow as tf + +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, +) +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + get_initializer, + get_tf_activation, + keras, + keras_serializable, + shape_list, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, invert_attention_mask, stable_softmax +from ...utils import add_start_docstrings_to_model_forward, logging +from .configuration_blip import BlipTextConfig + + +logger = logging.get_logger(__name__) + +BLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L52 +class TFBlipTextEmbeddings(keras.layers.Layer): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.word_embeddings = keras.layers.Embedding( + config.vocab_size, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="word_embeddings", + ) + self.position_embeddings = keras.layers.Embedding( + config.max_position_embeddings, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="position_embeddings", + ) + + # self.LayerNorm is not snake-cased to stick with PyTorch model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") + + self.position_ids = tf.expand_dims(tf.range(config.max_position_embeddings), 0) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def call(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0, training=None): + if input_ids is not None: + input_shape = tf.shape(input_ids) + else: + input_shape = tf.shape(inputs_embeds)[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, training=training) + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "word_embeddings", None) is not None: + with tf.name_scope(self.word_embeddings.name): + self.word_embeddings.build(None) + if getattr(self, "position_embeddings", None) is not None: + with tf.name_scope(self.position_embeddings.name): + self.position_embeddings.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97 +class TFBlipTextSelfAttention(keras.layers.Layer): + def __init__(self, config, is_cross_attention, **kwargs): + super().__init__(**kwargs) + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + 
self.key = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = keras.layers.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.is_cross_attention = is_cross_attention + + def transpose_for_scores(self, x): + new_x_shape = tf.concat( + [tf.shape(x)[:-1], tf.constant([self.num_attention_heads, self.attention_head_size], dtype=tf.int32)], + axis=0, + ) + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, perm=(0, 2, 1, 3)) + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=None, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
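+        # query_layer / key_layer / value_layer have shape (batch, num_heads, seq_len, head_size) after
+        # transpose_for_scores, so the matmul below yields scores of shape (batch, num_heads, query_len, key_len).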
+        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = shape_list(hidden_states)[1]
+            position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 1)
+            position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = tf.cast(positional_embedding, query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the TFBlipTextModel call() function)
+            attention_scores = attention_scores + tf.cast(attention_mask, attention_scores.dtype)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = stable_softmax(attention_scores, axis=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
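+        # Note that `head_mask`, when provided, is multiplied into these probabilities below, which zeroes
+        # out entire attention heads rather than individual positions.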
+ attention_probs_dropped = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = attention_probs_dropped @ value_layer + + context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if self.is_cross_attention: + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.encoder_hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.encoder_hidden_size]) + else: + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +class TFBlipTextSelfOutput(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool | None = None) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242 +class TFBlipTextAttention(keras.layers.Layer): + def __init__(self, config, is_cross_attention=False, **kwargs): + super().__init__(**kwargs) + self.self = TFBlipTextSelfAttention(config, is_cross_attention, name="self") + # "output" is a protected attribute on TF models + self.self_output = TFBlipTextSelfOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + output_attentions: bool | None = False, + training: bool | None = None, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, 
+ encoder_attention_mask, + past_key_value, + output_attentions, + training=training, + ) + attention_output = self.self_output(self_outputs[0], hidden_states, training=training) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "self_output", None) is not None: + with tf.name_scope(self.self_output.name): + self.self_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->BlipText +class TFBlipTextIntermediate(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFBlipTextOutput(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFBlipTextLayer(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.attention = TFBlipTextAttention(config, name="attention") + if self.config.is_decoder: + self.crossattention = TFBlipTextAttention( + config, is_cross_attention=self.config.is_decoder, name="crossattention" + ) + self.intermediate = TFBlipTextIntermediate(config, name="intermediate") + self.self_output = TFBlipTextOutput(config, name="output") + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=None, + ): + # decoder uni-directional self-attention cached key/values 
tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + training=training, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + intermediate_output = self.intermediate(attention_output) + layer_output = self.self_output(intermediate_output, attention_output, training=training) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "self_output", None) is not None: + with tf.name_scope(self.self_output.name): + self.self_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386 +@keras_serializable +class TFBlipTextEncoder(keras.layers.Layer): + config_class = BlipTextConfig + + def __init__(self, config, name=None, **kwargs): + super().__init__(name=name, **kwargs) + self.config = config + self.layer = [TFBlipTextLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + @unpack_inputs + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + training=None, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.is_decoder else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training=training, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not 
return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->BlipText +class TFBlipTextPooler(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->BlipText +class TFBlipTextPredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: BlipTextConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(inputs=hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFBlipTextLMPredictionHead(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.transform = TFBlipTextPredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
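+        # The Dense projection below is created without a bias; a standalone `bias` weight is added in
+        # build() and applied manually in call(), so it can be handled independently of the kernel.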
+        self.decoder = keras.layers.Dense(
+            config.vocab_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="decoder",
+            use_bias=False,
+        )
+        self.config = config
+
+    def build(self, input_shape=None):
+        self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "transform", None) is not None:
+            with tf.name_scope(self.transform.name):
+                self.transform.build(None)
+        if getattr(self, "decoder", None) is not None:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build([None, None, self.config.hidden_size])
+
+    def call(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states) + self.bias
+        return hidden_states
+
+
+class TFBlipTextOnlyMLMHead(keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.predictions = TFBlipTextLMPredictionHead(config, name="predictions")
+
+    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "predictions", None) is not None:
+            with tf.name_scope(self.predictions.name):
+                self.predictions.build(None)
+
+
+# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548
+class TFBlipTextPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BlipTextConfig
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+
+# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
+class TFBlipTextModel(TFBlipTextPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder, the model needs to be
+    initialized with the `is_decoder` argument set to `True`; an `encoder_hidden_states` tensor is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
+        super().__init__(config, name=name, **kwargs)
+        self.config = config
+
+        self.embeddings = TFBlipTextEmbeddings(config, name="embeddings")
+        self.encoder = TFBlipTextEncoder(config, name="encoder")
+        self.pooler = TFBlipTextPooler(config, name="pooler") if add_pooling_layer else None
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    @tf.function
+    def get_extended_attention_mask(
+        self, attention_mask: tf.Tensor, input_shape: tuple[int], is_decoder: bool
+    ) -> tf.Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (`tf.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (`tuple[int]`):
+                The shape of the input to the model.
+            is_decoder (`bool`):
+                Whether the model is used as a decoder.
+ + Returns: + `tf.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if not isinstance(attention_mask, tf.Tensor): + attention_mask = tf.convert_to_tensor(attention_mask) # Catches NumPy inputs that haven't been cast yet + if attention_mask.shape.rank == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.shape.rank == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = tf.range(seq_length, dtype=attention_mask.dtype) + causal_mask = tf.broadcast_to(seq_ids, (batch_size, seq_length, seq_length)) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + + if shape_list(causal_mask)[1] < shape_list(attention_mask)[1]: + prefix_seq_len = tf.shape(attention_mask)[1] - tf.shape(causal_mask)[1] + causal_mask = tf.concat( + [ + tf.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + tf.cast(causal_mask[:, None, :, :], attention_mask.dtype) * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + encoder_embeds: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + past_key_values: tuple[tuple[tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + is_decoder: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPoolingAndCrossAttentions: + r""" + encoder_hidden_states (`tf.Tensor`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + batch_size, seq_length = input_shape + elif encoder_embeds is not None: + input_shape = shape_list(encoder_embeds)[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = tf.ones((batch_size, seq_length + past_key_values_length)) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
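+        # The result broadcasts to [batch_size, num_heads, seq_length, seq_length]:
+        # [batch_size, 1, 1, seq_length] for plain padding masks, or [batch_size, 1, seq_length, seq_length]
+        # once the causal component is applied in decoder mode. For seq_length=3 with no padding, the causal
+        # component keeps [[1, 0, 0], [1, 1, 0], [1, 1, 1]] before conversion to the additive 0 / -10000 form.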
+ extended_attention_mask: tf.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states[0]) + else: + encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states) + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = tf.ones(encoder_hidden_shape) + encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 +class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.bert = TFBlipTextModel(config, add_pooling_layer=False, name="bert") + self.cls = 
TFBlipTextOnlyMLMHead(config, name="cls")
+        self.label_smoothing = config.label_smoothing
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
+    @unpack_inputs
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,
+        is_decoder=True,
+        training=None,
+    ):
+        r"""
+        encoder_hidden_states (`tf.Tensor`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`tf.Tensor`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (`tf.Tensor`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
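+
+        Example (an illustrative sketch only; it builds the decoder from a fresh config rather than a pretrained
+        checkpoint, and assumes this module is importable as `transformers.models.blip.modeling_tf_blip_text`):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import BlipTextConfig
+        >>> from transformers.models.blip.modeling_tf_blip_text import TFBlipTextLMHeadModel
+
+        >>> config = BlipTextConfig()
+        >>> model = TFBlipTextLMHeadModel(config)
+        >>> input_ids = tf.constant([[101, 2023, 2003, 1037, 3231, 102]])  # any ids below config.vocab_size
+        >>> outputs = model(input_ids, labels=input_ids)
+        >>> outputs.loss  # scalar next-token prediction loss
+        >>> outputs.logits.shape  # (1, 6, config.vocab_size)
+        ```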
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + training=training, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :] + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :] + shifted_prediction_scores = tf.reshape(shifted_prediction_scores, (-1, self.config.vocab_size)) + labels = labels[:, 1:] + labels = tf.reshape(labels, (-1,)) + # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here + # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway) + one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32) + loss_fct = keras.losses.CategoricalCrossentropy( + from_logits=True, label_smoothing=self.label_smoothing, reduction="none" + ) + masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32) + lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores) + lm_loss *= masked_positions + lm_loss = tf.reduce_sum(lm_loss, axis=0) / tf.math.count_nonzero(masked_positions, dtype=tf.float32) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states"), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask"), + "is_decoder": True, + } + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "cls", None) is not None: + with tf.name_scope(self.cls.name): + self.cls.build(None) + + +__all__ = ["TFBlipTextLMHeadModel", "TFBlipTextModel", "TFBlipTextPreTrainedModel"] diff --git 
a/venv/lib/python3.13/site-packages/transformers/models/blip/processing_blip.py b/venv/lib/python3.13/site-packages/transformers/models/blip/processing_blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc4334a974cb8e47f1b94b180e1f5cf8f553513
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/blip/processing_blip.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Blip.
+"""
+
+from typing import Optional, Union
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class BlipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
+
+class BlipProcessor(ProcessorMixin):
+    r"""
+    Constructs a BLIP processor which wraps a BERT tokenizer and BLIP image processor into a single processor.
+
+    [`BlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`BertTokenizerFast`]. See the
+    docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.
+
+    Args:
+        image_processor (`BlipImageProcessor`):
+            An instance of [`BlipImageProcessor`]. The image processor is a required input.
+        tokenizer (`BertTokenizerFast`):
+            An instance of [`BertTokenizerFast`]. The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
+    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+
+    def __init__(self, image_processor, tokenizer, **kwargs):
+        tokenizer.return_token_type_ids = False
+        super().__init__(image_processor, tokenizer)
+        self.current_processor = self.image_processor
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[BlipProcessorKwargs],
+    ) -> BatchEncoding:
+        """
+        This method uses [`BlipImageProcessor.__call__`] to prepare image(s) for the model, and
+        [`BertTokenizerFast.__call__`] to prepare text for the model.
+
+        Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            images (`ImageInput`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
+                The sequence or batch of sequences to be encoded.
Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + """ + if images is None and text is None: + raise ValueError("You have to specify either images or text.") + + text_encoding = None + + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. + output_kwargs = self._merge_kwargs( + BlipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if text is not None: + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if images is not None: + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text_encoding is not None: + encoding_image_processor.update(text_encoding) + return encoding_image_processor + + return text_encoding + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + tokenizer_input_names = [name for name in tokenizer_input_names if name != "token_type_ids"] + return tokenizer_input_names + image_processor_input_names + + +__all__ = ["BlipProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip_2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/blip_2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0717e81fca606edde774ad19092472bf59b4701a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip_2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
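+# Note: this package relies on transformers' lazy-import machinery below (`_LazyModule` together with
+# `define_import_structure`), so submodules are only imported when one of their attributes is first accessed.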
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_blip_2 import *
+    from .modeling_blip_2 import *
+    from .processing_blip_2 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip_2/configuration_blip_2.py b/venv/lib/python3.13/site-packages/transformers/models/blip_2/configuration_blip_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..23145ffc543ff39f21a356b0fa5172eecdaa2e90
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/blip_2/configuration_blip_2.py
@@ -0,0 +1,351 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLIP-2 model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Blip2VisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Blip2VisionModel`]. It is used to instantiate a
+    BLIP-2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the BLIP-2
+    [Salesforce/blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1408):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 6144):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 39):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries and values in the self-attention layers. + + Example: + + ```python + >>> from transformers import Blip2VisionConfig, Blip2VisionModel + + >>> # Initializing a Blip2VisionConfig with Salesforce/blip2-opt-2.7b style configuration + >>> configuration = Blip2VisionConfig() + + >>> # Initializing a Blip2VisionModel (with random weights) from the Salesforce/blip2-opt-2.7b style configuration + >>> model = Blip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blip_2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1408, + intermediate_size=6144, + num_hidden_layers=39, + num_attention_heads=16, + image_size=224, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.qkv_bias = qkv_bias + + +class Blip2QFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Blip2QFormerModel`]. It is used to instantiate a + BLIP-2 Querying Transformer (Q-Former) model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the BLIP-2 + [Salesforce/blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b) architecture. Configuration objects + inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + Note that [`Blip2QFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling the model. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Index to be used for padding token. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + cross_attention_frequency (`int`, *optional*, defaults to 2): + The frequency of adding cross-attention to the Transformer layers. + encoder_hidden_size (`int`, *optional*, defaults to 1408): + The hidden size of the hidden states for cross-attention. + use_qformer_text_input (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style embeddings. 
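+
+    With the default `cross_attention_frequency=2`, cross-attention to the image features is inserted into every
+    other Q-Former layer (layers 0, 2, 4, ...); the remaining layers use self-attention only.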
+ + Examples: + + ```python + >>> from transformers import Blip2QFormerConfig, Blip2QFormerModel + + >>> # Initializing a BLIP-2 Salesforce/blip2-opt-2.7b style configuration + >>> configuration = Blip2QFormerConfig() + + >>> # Initializing a model (with random weights) from the Salesforce/blip2-opt-2.7b style configuration + >>> model = Blip2QFormerModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blip_2_qformer" + base_config_key = "qformer_config" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + cross_attention_frequency=2, + encoder_hidden_size=1408, + use_qformer_text_input=False, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.cross_attention_frequency = cross_attention_frequency + self.encoder_hidden_size = encoder_hidden_size + self.use_qformer_text_input = use_qformer_text_input + + +class Blip2Config(PretrainedConfig): + r""" + [`Blip2Config`] is the configuration class to store the configuration of a [`Blip2ForConditionalGeneration`]. It is + used to instantiate a BLIP-2 model according to the specified arguments, defining the vision model, Q-Former model + and language model configs. Instantiating a configuration with the defaults will yield a similar configuration to + that of the BLIP-2 [Salesforce/blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Blip2VisionConfig`]. + qformer_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Blip2QFormerConfig`]. + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize any [`PretrainedConfig`]. + num_query_tokens (`int`, *optional*, defaults to 32): + The number of query tokens passed through the Transformer. + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden state of the image-text fusion layer. + + image_token_index (`int`, *optional*): + Token index of special image token. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... Blip2VisionConfig, + ... Blip2QFormerConfig, + ... OPTConfig, + ... Blip2Config, + ... Blip2ForConditionalGeneration, + ... 
)
+
+    >>> # Initializing a Blip2Config with Salesforce/blip2-opt-2.7b style configuration
+    >>> configuration = Blip2Config()
+
+    >>> # Initializing a Blip2ForConditionalGeneration (with random weights) from the Salesforce/blip2-opt-2.7b style configuration
+    >>> model = Blip2ForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a Blip2Config from a Blip2VisionConfig, Blip2QFormerConfig and any PretrainedConfig
+
+    >>> # Initializing BLIP-2 vision, BLIP-2 Q-Former and language model configurations
+    >>> vision_config = Blip2VisionConfig()
+    >>> qformer_config = Blip2QFormerConfig()
+    >>> text_config = OPTConfig()
+
+    >>> config = Blip2Config.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
+    ```"""
+
+    model_type = "blip-2"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+    }
+    sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig}
+
+    def __init__(
+        self,
+        vision_config=None,
+        qformer_config=None,
+        text_config=None,
+        num_query_tokens=32,
+        image_text_hidden_size=256,
+        image_token_index=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info("vision_config is None. Initializing the Blip2VisionConfig with default values.")
+
+        if qformer_config is None:
+            qformer_config = {}
+            logger.info("qformer_config is None. Initializing the Blip2QFormerConfig with default values.")
+
+        if text_config is None:
+            text_config = {}
+            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
+
+        self.vision_config = Blip2VisionConfig(**vision_config)
+        self.qformer_config = Blip2QFormerConfig(**qformer_config)
+        text_model_type = text_config.get("model_type", "opt")
+        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+        self.num_query_tokens = num_query_tokens
+        self.image_text_hidden_size = image_text_hidden_size
+        self.image_token_index = image_token_index
+        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
+        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+        self.is_encoder_decoder = self.text_config.is_encoder_decoder
+        self.initializer_factor = 1.0
+        self.initializer_range = 0.02
+
+    @classmethod
+    def from_vision_qformer_text_configs(
+        cls,
+        vision_config: Blip2VisionConfig,
+        qformer_config: Blip2QFormerConfig,
+        text_config: Optional[PretrainedConfig] = None,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model
+        configurations.
+
+        Args:
+            vision_config (`dict`):
+                Dictionary of configuration options used to initialize [`Blip2VisionConfig`].
+            qformer_config (`dict`):
+                Dictionary of configuration options used to initialize [`Blip2QFormerConfig`].
+            text_config (`dict`, *optional*):
+                Dictionary of configuration options used to initialize any [`PretrainedConfig`].
+ + Returns: + [`Blip2Config`]: An instance of a configuration object + """ + + return cls( + vision_config=vision_config.to_dict(), + qformer_config=qformer_config.to_dict(), + text_config=text_config.to_dict() if text_config is not None else None, + **kwargs, + ) + + +__all__ = ["Blip2Config", "Blip2QFormerConfig", "Blip2VisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip_2/modeling_blip_2.py b/venv/lib/python3.13/site-packages/transformers/models/blip_2/modeling_blip_2.py new file mode 100644 index 0000000000000000000000000000000000000000..1246b70e1370e3c85af7bcd999a33e44da319836 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip_2/modeling_blip_2.py @@ -0,0 +1,2239 @@ +# coding=utf-8 +# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BLIP-2 model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + Seq2SeqLMOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + filter_out_non_signature_kwargs, + logging, + torch_int, +) +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class defining the outputs of [`Blip2ForConditionalGeneration`]. + """ +) +class Blip2ForConditionalGenerationModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Language modeling loss from the language model. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head of the language model. + vision_outputs (`BaseModelOutputWithPooling`): + Outputs of the vision encoder. + qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`): + Outputs of the Q-Former (Querying Transformer). + language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`): + Outputs of the language model. 
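+
+    Calling `to_tuple()` on this output also flattens the nested `vision_outputs`, `qformer_outputs` and
+    `language_model_outputs` into plain tuples.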
+ """ + + loss: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + vision_outputs: Optional[torch.FloatTensor] = None + qformer_outputs: Optional[tuple[torch.FloatTensor]] = None + language_model_outputs: Optional[tuple[torch.FloatTensor]] = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] + if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"] + else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring +class Blip2ImageTextMatchingModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`Blip2QFormerModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Blip2VisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Blip2 +class Blip2TextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Blip2 +class Blip2VisionModelOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> BLIP doesn't cast attn weights to fp32 +def 
eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Blip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.is_causal = False + self.attention_dropout = config.attention_dropout + + # small tweak here compared to CLIP, no bias here + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) + + if config.qkv_bias: + q_bias = nn.Parameter(torch.zeros(self.embed_dim)) + v_bias = nn.Parameter(torch.zeros(self.embed_dim)) + else: + q_bias = None + v_bias = None + + if q_bias is not None: + qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) + self.qkv.bias = nn.Parameter(qkv_bias) + + self.projection = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( + 2, 0, 3, 1, 4 + ) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.blip.modeling_blip.BlipMLP +class Blip2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) 
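+        # fc1 expands hidden_size to intermediate_size; activation_fn and fc2 below map back to hidden_size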
+ hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2 +class Blip2EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Blip2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Blip2Attention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Blip2MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + **kwargs, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + return hidden_states + + +@auto_docstring +class Blip2PreTrainedModel(PreTrainedModel): + config: Blip2Config + base_model_prefix = "blip" + supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _no_split_modules = [ + "Blip2Attention", + "Blip2QFormerMultiHeadAttention", + "Blip2EncoderLayer", + "Blip2TextEmbeddings", + "T5Block", + "OPTDecoderLayer", + ] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=factor) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Blip2VisionEmbeddings): + nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance( + module, + ( + Blip2Model, + Blip2TextModelWithProjection, + Blip2VisionModelWithProjection, + Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, + ), + ): + module.query_tokens.data.zero_() + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2 +class Blip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Blip2EncoderLayer`]. + + Args: + config (`Blip2Config`): + The corresponding vision configuration for the `Blip2Encoder`. 
+ """ + + def __init__(self, config: Blip2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Blip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +@auto_docstring +# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config: Blip2VisionConfig + _can_record_outputs = { + "hidden_states": Blip2EncoderLayer, + "attentions": Blip2Attention, + } + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Blip2VisionEmbeddings(config) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPooling]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + def get_input_embeddings(self): + return self.embeddings + + +class Blip2QFormerMultiHeadAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) + self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = 
config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return ( + context_layer, + attention_probs, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Blip2QFormer +class Blip2QFormerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.attention = Blip2QFormerMultiHeadAttention(config, is_cross_attention) + self.output = Blip2QFormerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + attn_output, _ = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + attention_output = self.output(attn_output, hidden_states) + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Blip2QFormer +class Blip2QFormerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 
config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Blip2QFormer +class Blip2QFormerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + if config.use_qformer_text_input: + self.intermediate = Blip2QFormerIntermediate(config) + self.output = Blip2QFormerOutput(config) + + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + attention_output = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + **kwargs, + ) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + query_attention_output = self.crossattention( + hidden_states=query_attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + return layer_output + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = 
self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + layer_head_mask = head_mask[i] if head_mask is not None else None + + hidden_states = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + query_length=query_length, + **kwargs, + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + ) + + +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if input_ids is not None: + input_ids = input_ids.to(self.word_embeddings.weight.device) + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if query_embeds is not None: + # `query_embeds` are kept in fp32 when we use it with Qformer + if query_embeds.dtype != embeddings.dtype: + query_embeds = query_embeds.to(embeddings.dtype) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + return embeddings + + +@auto_docstring( + custom_intro=""" + BLIP-2 Querying Transformer (Q-Former). 
+ """ +) +class Blip2QFormerModel(Blip2PreTrainedModel): + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + + _can_record_outputs = { + "hidden_states": Blip2QFormerLayer, + "attentions": [ + OutputRecorder(Blip2QFormerMultiHeadAttention, index=1, layer_name=".attention"), + ], + "cross_attentions": [ + OutputRecorder(Blip2QFormerMultiHeadAttention, index=1, layer_name=".crossattention"), + ], + } + + def __init__(self, config: Blip2QFormerConfig): + super().__init__(config) + self.config = config + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = Blip2QFormerEncoder(config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + @check_model_inputs() + @auto_docstring + def forward( + self, + query_embeds: torch.FloatTensor, + query_length: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states to be used in the attention computation. If cross-attention, + will be used for the query (i.e., key and value will use the encoder_hidden_states). + query_length (`int`, *optional*): + Length of the query, usually based on the number of query tokens. + If no value is provided, query_length will be inferred by the query_embeds. + """ + query_length = ( + query_length if query_length is not None else query_embeds.shape[1] if query_embeds is not None else 0 + ) + + # `Blip2QFormerModel` is kept as fp32 + query_embeds = query_embeds.to(self.layernorm.weight.dtype) + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + # Qformer and latent query tokens are kept in fp32. 
We cast `encoder_hidden_states` if not fp32 already + if encoder_hidden_states.dtype != query_embeds.dtype: + encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype) + + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs: BaseModelOutput = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + query_length=query_length, + **kwargs, + ) + sequence_output = encoder_outputs.last_hidden_state + pooled_output = sequence_output[:, 0, :] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +@auto_docstring( + custom_intro=""" + BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer + (Q-Former) and a language model. + """ +) +class Blip2Model(Blip2PreTrainedModel): + config: Blip2Config + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn = False # because self.qformer does not support FA2 + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + + # Update _tied_weights_keys using the base model used. 
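+ # (keys reported by the wrapped language model are re-exposed under the `language_model.` prefix so weight tying still resolves)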
+ if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + + self.language_model = language_model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + legacy_output: bool = True, + ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + legacy_output (`bool`, *optional*, defaults to `True`): + Whether to return a model output object or a tensor of features. + + Returns: + text_outputs (`CausalLMOutputWithPast` or `torch.FloatTensor`): + The language model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`. + + Examples: + ```python + >>> import torch + >>> from transformers import AutoTokenizer, Blip2Model + + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b") + + >>> inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt") + >>> with torch.inference_mode(): + ... text_features = model.get_text_features(**inputs) + ```""" + + if legacy_output: + warnings.warn( + "Deprecation notice: In Transformers v4.59, the default return value of `get_text_features` will change. " + "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. " + "To opt in to the new behavior now, set `legacy_output=False`." 
+ ) + + if self.config.use_decoder_only_language_model: + text_outputs: CausalLMOutputWithPast = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + else: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + text_outputs: Seq2SeqLMOutput = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + return_dict=True, + ) + + return text_outputs if legacy_output else text_outputs.logits + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + legacy_output: bool = True, + ) -> Union[torch.FloatTensor, BaseModelOutputWithPooling]: + r""" + legacy_output (`bool`, *optional*, defaults to `True`): + Whether to return a model output object or a tensor of features. + + Returns: + vision_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`): + The vision model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`. + + Examples: + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2Model + >>> from transformers.image_utils import load_image + + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b") + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.inference_mode(): + ...     image_outputs = model.get_image_features(**inputs) + ```""" + if legacy_output: + warnings.warn( + "Deprecation notice: In Transformers v4.59, the default return value of `get_image_features` will change. " + "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. " + "To opt in to the new behavior now, set `legacy_output=False`." + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + return vision_outputs if legacy_output else vision_outputs.pooler_output + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_qformer_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + legacy_output: bool = True, + ) -> Union[torch.FloatTensor, BaseModelOutputWithPooling]: + r""" + legacy_output (`bool`, *optional*, defaults to `True`): + Whether to return a model output object or a tensor of features. + + Returns: + qformer_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`): + The Q-Former outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2Model + >>> from transformers.image_utils import load_image + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.inference_mode(): + ...
qformer_outputs = model.get_qformer_features(**inputs) + ```""" + + if legacy_output: + warnings.warn( + "Deprecation notice: In Transformers v4.59, the default return value of `get_qformer_features` will change. " + "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. " + "To opt in to the new behavior now, set `legacy_output=False`." + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + image_embeds = vision_outputs.last_hidden_state + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + + return query_outputs if legacy_output else query_outputs.last_hidden_state + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Blip2ForConditionalGenerationModelOutput]: + r""" + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import Blip2Processor, Blip2Model + >>> import torch + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", dtype=torch.float16) + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> prompt = "Question: how many cats are there? 
Answer:" + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) + + >>> outputs = model(**inputs) + ```""" + + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **kwargs, + ) + query_output = query_outputs[0] + + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( + special_image_mask, language_model_inputs + ) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + logits = outputs[0] + loss = None + # we compute the loss here since we need to take into account the sequence length of the query embeds + if labels is not None: + labels = labels.to(logits.device) + logits = logits[:, -labels.size(1) :, :] + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(logits.device) + + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction="mean") + + loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + return_dict=True, + **kwargs, + ) + loss = outputs.loss + logits = outputs.logits + + return Blip2ForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + +@auto_docstring +class Blip2TextModelWithProjection(Blip2PreTrainedModel): + supports_gradient_checkpointing = False + _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn = False # because self.qformer does not support FA2 + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # text projection layer + self.text_projection = 
nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Blip2TextModelOutput]: + r""" + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2TextModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2TextModelWithProjection.from_pretrained( + ... "Salesforce/blip2-itm-vit-g", dtype=torch.float16 + ... ) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + >>> print(text_embeds.shape) + torch.Size([2, 7, 256]) + ```""" + + query_embeds = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ) + + text_outputs = self.qformer( + query_embeds=query_embeds, + query_length=0, + attention_mask=attention_mask, + **kwargs, + ) + + pooled_output = text_outputs[0] + pooled_output = pooled_output.to(dtype=self.text_projection.weight.dtype) + + text_embeds = self.text_projection(pooled_output) + text_embeds = nn.functional.normalize(text_embeds, dim=-1) + + return Blip2TextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +@auto_docstring +class Blip2VisionModelWithProjection(Blip2PreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn = False # because self.qformer does not support FA2 + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) + + # vision projection layer + self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Blip2VisionModelOutput]: + r""" + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection + >>> from transformers.image_utils import load_image + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") + >>> model = Blip2VisionModelWithProjection.from_pretrained( + ... "Salesforce/blip2-itm-vit-g", dtype=torch.float16 + ... 
) + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + + >>> with torch.inference_mode(): + ... outputs = model(**inputs) + >>> image_embeds = outputs.image_embeds + >>> print(image_embeds.shape) + torch.Size([1, 32, 256]) + ```""" + vision_outputs = self.vision_model( + pixel_values=pixel_values, + **kwargs, + ) + + pooled_output = vision_outputs[0] + image_attention_mask = torch.ones(pooled_output.size()[:-1], dtype=torch.long, device=pooled_output.device) + query_tokens = self.query_tokens.expand(pooled_output.shape[0], -1, -1) + + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=pooled_output, + encoder_attention_mask=image_attention_mask, + **kwargs, + ) + + embeds = query_outputs[0] + embeds = embeds.to(dtype=self.vision_projection.weight.dtype) + image_embeds = self.vision_projection(embeds) + image_embeds = nn.functional.normalize(image_embeds, dim=-1) + + return Blip2VisionModelOutput( + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision + encoder, Querying Transformer (Q-Former) and a language model. + + One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue + the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token. + + + + Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16. + + + """ +) +class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): + config: Blip2Config + main_input_name = "pixel_values" + + _can_compile_fullgraph = True + _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn = False # because self.qformer does not support FA2 + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + + # Update _tied_weights_keys using the base model used. 
+ if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + + self.language_model = language_model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." + " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. 
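+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. + return_dict (`bool`, *optional*, defaults to `False`): + Whether to also return the vision encoder and Q-Former outputs together with the projected query embeddings.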
+ """ + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0] + + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Blip2ForConditionalGenerationModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. 
+ + Examples: + + Prepare processor, model and image input + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration + >>> import torch + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> model = Blip2ForConditionalGeneration.from_pretrained( + ... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, dtype=torch.float16 + ... ) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + ``` + + Image captioning (without providing a text prompt): + + ```python + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + + >>> generated_ids = model.generate(**inputs) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + >>> print(generated_text) + two cats laying on a couch + ``` + + Visual question answering (prompt = question): + + ```python + >>> prompt = "Question: how many cats are there? Answer:" + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16) + + >>> generated_ids = model.generate(**inputs) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + >>> print(generated_text) + two + ``` + + Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). + This greatly reduces the amount of memory used by the model while maintaining the same performance. + + ```python + >>> model = Blip2ForConditionalGeneration.from_pretrained( + ... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, dtype=torch.bfloat16 + ... 
) # doctest: +IGNORE_RESULT + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) + + >>> generated_ids = model.generate(**inputs) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + >>> print(generated_text) + two + ```""" + + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( + special_image_mask, language_model_inputs + ) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + logits = outputs[0] + loss = None + # we compute the loss here since we need to take into account the sequence length of the query embeds + if labels is not None: + labels = labels.to(logits.device) + logits = logits[:, -labels.size(1) :, :] + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(logits.device) + + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction="mean") + + loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) + else: + kwargs["return_dict"] = True + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + **kwargs, + ) + loss = outputs.loss + logits = outputs.logits + + return Blip2ForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **generate_kwargs, + ) -> torch.LongTensor: + """ + Overrides `generate` function to be able to use the model as a conditional generator. + + Args: + pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): + Input images to be processed. + input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + The sequence used as a prompt for the generation. + attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + Mask to avoid performing attention on padding token indices + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. + + Returns: + captions (list): A list of strings of length batch_size * num_captions. 
+ """ + if hasattr(self, "hf_device_map"): + # preprocess for `accelerate` + self._preprocess_accelerate() + + batch_size = pixel_values.shape[0] + image_embeds = self.vision_model( + pixel_values, + return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, + ).last_hidden_state + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs.last_hidden_state + + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + + language_model_inputs = self.language_projection(query_output) + + if inputs_embeds is None: + if input_ids is None: + image_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens = image_tokens + [self.config.text_config.bos_token_id] + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} + if not self.language_model.config.is_encoder_decoder: + inputs["input_ids"] = input_ids + + outputs = self.language_model.generate(**inputs, **generate_kwargs) + + return outputs + + +@auto_docstring( + custom_intro=""" + BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context + of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
+ """ +) +class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn = False # because self.qformer does not support FA2 + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) + + # vision projection layer + self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + use_image_text_matching_head: Optional[bool] = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Blip2ImageTextMatchingModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + use_image_text_matching_head (`bool`, *optional*): + Whether to return the Image-Text Matching or Contrastive scores. 
+ + Examples: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", dtype=torch.float16) + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "two cats laying on a pink blanket" + + >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16) + >>> itm_out = model(**inputs, use_image_text_matching_head=True) + >>> logits_per_image = torch.nn.functional.softmax(itm_out.logits_per_image, dim=1) + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + + >>> print(f"{probs[0][0]:.1%} that image 0 is not '{text}'") + 26.9% that image 0 is not 'two cats laying on a pink blanket' + + >>> print(f"{probs[0][1]:.1%} that image 0 is '{text}'") + 73.0% that image 0 is 'two cats laying on a pink blanket' + + >>> texts = ["a photo of a cat", "a photo of a dog"] + + >>> inputs = processor(images=image, text=texts, return_tensors="pt").to(device, torch.float16) + >>> itc_out = model(**inputs, use_image_text_matching_head=False) + >>> logits_per_image = itc_out.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 55.3% that image 0 is 'a photo of a cat' + + >>> print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'") + 44.7% that image 0 is 'a photo of a dog' + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if use_image_text_matching_head: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + if self.config.image_token_index is not None: + input_ids = input_ids[:, self.config.num_query_tokens :] + else: + query_attention_mask = torch.ones( + query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device + ) + attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) + + query_embeds = self.embeddings( + input_ids=input_ids, + query_embeds=query_tokens, + ) + + text_outputs = self.qformer( + query_embeds=query_embeds, + query_length=query_tokens.shape[1], + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + text_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + text_embeds = text_embeds.to(dtype=self.itm_head.weight.dtype) + + output = self.itm_head(text_embeds[:, : query_tokens.size(1), :]) + 
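+ # The ITM head produces a (no-match, match) logit pair for every query token;
+ # averaging over the query dimension below collapses these into a single logit pair per image-text pair.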
logits_per_image = output.mean(dim=1) + logits_per_text = logits_per_image.t() + else: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype) + + if self.config.image_token_index is not None: + input_ids = input_ids[:, self.config.num_query_tokens :] + attention_mask = attention_mask[:, self.config.num_query_tokens :] + + query_embeds = self.embeddings( + input_ids=input_ids, + ) + text_outputs = self.qformer( + query_embeds=query_embeds, + query_length=0, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + question_embeds = question_embeds.to(dtype=self.text_projection.weight.dtype) + + # normalized features + image_embeds = nn.functional.normalize(self.vision_projection(image_embeds), dim=-1) + text_embeds = nn.functional.normalize(self.text_projection(question_embeds[:, 0, :]), dim=-1) + + # cosine similarity as logits + logits_per_image = torch.matmul(image_embeds, text_embeds.t()) + logits_per_image, _ = logits_per_image.max(dim=1) + + logits_per_text = logits_per_image.t() + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return output + + return Blip2ImageTextMatchingModelOutput( + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = [ + "Blip2Model", + "Blip2VisionModelWithProjection", + "Blip2QFormerModel", + "Blip2PreTrainedModel", + "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", + "Blip2VisionModel", + "Blip2TextModelWithProjection", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blip_2/processing_blip_2.py b/venv/lib/python3.13/site-packages/transformers/models/blip_2/processing_blip_2.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c89f7f460a2954f4b8638cc51bf3bc6f248e84 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blip_2/processing_blip_2.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for BLIP-2. 
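+
+Example (a minimal sketch; the checkpoint name and image URL are illustrative):
+
+```python
+>>> import requests
+>>> from PIL import Image
+>>> from transformers import Blip2Processor
+
+>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+>>> inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
+```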
+""" + +from typing import Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Blip2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + +class Blip2Processor(ProcessorMixin): + r""" + Constructs a BLIP-2 processor which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor. + + [`BlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the docstring + of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information. + + Args: + image_processor (`BlipImageProcessor`): + An instance of [`BlipImageProcessor`]. The image processor is a required input. + tokenizer (`AutoTokenizer`): + An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. + num_query_tokens (`int`, *optional*): + Number of tokens used by the Qformer as queries, should be same as in model's config. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast") + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs): + tokenizer.return_token_type_ids = False + self.current_processor = image_processor + if not hasattr(tokenizer, "image_token"): + self.image_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([self.image_token], special_tokens=True) + else: + self.image_token = tokenizer.image_token + self.num_query_tokens = num_query_tokens + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Blip2ProcessorKwargs], + ) -> BatchEncoding: + """ + This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and + [`BertTokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. 
+ - `'jax'`: Return JAX `jnp.ndarray` objects. + """ + if images is None and text is None: + raise ValueError("You have to specify either images or text.") + output_kwargs = self._merge_kwargs( + Blip2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # BC for explicit return_tensors + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + max_length = output_kwargs["text_kwargs"].pop("max_length", None) + if max_length is not None: + output_kwargs["text_kwargs"]["max_length"] = max_length - self.num_query_tokens + + encoding = BatchFeature(tensor_type=return_tensors) + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if images is not None and self.num_query_tokens is not None: + # Image tokens should not be padded/truncated or prepended with special BOS token + image_tokens = self.image_token.content * self.num_query_tokens + output_kwargs["text_kwargs"]["add_special_tokens"] = False + output_kwargs["text_kwargs"]["padding"] = False + output_kwargs["text_kwargs"]["truncation"] = False + image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"]) + for k in text_encoding: + text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]] + encoding.update(text_encoding) + + # Now add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + encoding.update(image_encoding) + + # Cast to desired return tensors type + encoding = BatchFeature(encoding, tensor_type=return_tensors) + return encoding + + +__all__ = ["Blip2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/blt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..703b81ecdd09dda47a97c641f7e440bcb5e81119 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blt/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_blt import * + from .modeling_blt import * + from .tokenization_blt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/blt/configuration_blt.py b/venv/lib/python3.13/site-packages/transformers/models/blt/configuration_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc6718e5bd151e902dc755e21aa07cccbd6e73f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blt/configuration_blt.py @@ -0,0 +1,423 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Blt model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BltLocalEncoderConfig(PretrainedConfig): + """ + Configuration class for the Blt Local Encoder component. + """ + + model_type = "blt_local_encoder" + + def __init__( + self, + vocab_size=260, + cross_attn_all_layers=False, + cross_attn_k=2, + hidden_size_global=2048, + hidden_size=1024, + num_attention_heads=16, + num_key_value_heads=None, + num_hidden_layers=1, + rms_norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=24576, + rope_theta=500000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=2816, + initializer_range=0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.hidden_size_global = hidden_size_global + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.hidden_act = hidden_act + self.initializer_range = initializer_range + + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) + + +class BltLocalDecoderConfig(PretrainedConfig): + """ + Configuration class for the Blt Local Decoder component. 
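+
+ Example (an illustrative sketch; the values below are arbitrary and do not correspond to a released checkpoint):
+
+ ```python
+ >>> from transformers.models.blt.configuration_blt import BltLocalDecoderConfig
+
+ >>> # A small decoder configuration; num_key_value_heads defaults to num_attention_heads when unset
+ >>> decoder_config = BltLocalDecoderConfig(hidden_size=512, num_attention_heads=8, num_hidden_layers=4)
+ ```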
+ """ + + model_type = "blt_local_decoder" + + def __init__( + self, + vocab_size=260, + cross_attn_all_layers=True, + cross_attn_k=2, + hidden_size_global=2048, + hidden_size=1024, + num_attention_heads=16, + num_key_value_heads=None, + num_hidden_layers=9, + rms_norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=24576, + rope_theta=500000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=2816, + initializer_range=0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.hidden_size_global = hidden_size_global + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.hidden_act = hidden_act + self.initializer_range = initializer_range + + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) + + +class BltGlobalTransformerConfig(PretrainedConfig): + """ + Configuration class for the Blt Global Transformer component. + """ + + model_type = "blt_global_transformer" + + def __init__( + self, + hidden_size=2048, + num_attention_heads=16, + num_key_value_heads=None, + num_hidden_layers=25, + rms_norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=4096, + rope_theta=500000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=5632, + initializer_range=0.02, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.hidden_act = hidden_act + self.initializer_range = initializer_range + + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) + + +class BltPatcherConfig(PretrainedConfig): + r""" + Configuration class for the Blt Patcher/Entropy model component. + + Args: + vocab_size (`int`, *optional*, defaults to 260): + Vocabulary size of the Blt patcher model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling the patcher model. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 14): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + rope_scaling (`dict`, *optional*): + Dictionary containing the RoPE scaling configuration. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "blt_patcher" + + def __init__( + self, + vocab_size=260, + hidden_size=768, + num_hidden_layers=14, + num_attention_heads=12, + num_key_value_heads=None, + max_position_embeddings=8192, + rms_norm_eps=1e-5, + dropout=0.0, + rope_theta=10000.0, + intermediate_size=2048, + rope_scaling=None, + initializer_range=0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.dropout = dropout + self.rope_theta = rope_theta + self.hidden_act = "silu" # Blt uses silu activation + self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) + self.rope_scaling = rope_scaling + self.initializer_range = initializer_range + + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) + + +class BltConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BltModel`]. It is used to instantiate a + Blt model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 260): + Vocabulary size of the Blt model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BltModel`]. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + patch_in_forward (`bool`, *optional*, defaults to `True`): + Whether to perform patching during the forward pass. + patch_size (`int`, *optional*, defaults to 4): + Size of the patches used in the patching mechanism. 
+ patching_mode (`str`, *optional*, defaults to `"entropy"`): + The mode used for patching, such as entropy-based patching. + patching_threshold (`float`, *optional*, defaults to 1.34): + Threshold value used for determining when to apply patches. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size used during the patching process. + max_patch_length (`int`, *optional*): + Maximum length of patches that can be generated. + cross_attn_k (`int`, *optional*, defaults to 2): + Number of cross-attention heads used in the model. + encoder_hash_byte_group_size (`list`, *optional*): + List of byte group sizes used in the encoder hash function. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 500002): + Vocabulary size for the encoder hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 1): + Number of hash functions used in the encoder byte grouping. + patcher_config (`BltPatcherConfig`, *optional*): + Configuration for the patcher component of the model. + encoder_config (`BltLocalEncoderConfig`, *optional*): + Configuration for the local encoder component of the model. + decoder_config (`BltLocalDecoderConfig`, *optional*): + Configuration for the local decoder component of the model. + global_config (`BltGlobalTransformerConfig`, *optional*): + Configuration for the global transformer component of the model. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rope_theta (`float`, *optional*, defaults to 500000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the RoPE scaling configuration. 
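+
+ Sub-configurations may be given either as config objects or as plain dicts; a dict is converted to the
+ corresponding config class during initialization. For example (illustrative values):
+
+ ```python
+ >>> from transformers import BltConfig
+
+ >>> config = BltConfig(encoder_config={"hidden_size": 512, "num_attention_heads": 8})
+ ```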
+ + ```python + >>> from transformers import BltModel, BltConfig + + >>> # Initializing a Blt configuration + >>> configuration = BltConfig() + + >>> # Initializing a model from the configuration + >>> model = BltModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + + Checkpoint: [facebook/blt](https://huggingface.co/facebook/blt) + """ + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = { + "patcher_config": BltPatcherConfig, + "encoder_config": BltLocalEncoderConfig, + "decoder_config": BltLocalDecoderConfig, + "global_config": BltGlobalTransformerConfig, + } + + def __init__( + self, + vocab_size=260, + max_position_embeddings=4096, + patch_in_forward=True, + patch_size=4, + patching_mode="entropy", + patching_threshold=1.335442066192627, + patching_batch_size=1, + max_patch_length=None, + cross_attn_k=2, + encoder_hash_byte_group_size=None, + encoder_hash_byte_group_vocab=500002, + encoder_hash_byte_group_nb_functions=1, + patcher_config=None, + encoder_config=None, + decoder_config=None, + global_config=None, + tie_word_embeddings=False, + initializer_range=0.02, + rope_theta=500000.0, + rope_scaling=None, + **kwargs, + ): + # Basic model configuration + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + # Patching configuration + self.patch_in_forward = patch_in_forward + self.patch_size = patch_size + self.patching_mode = patching_mode + self.patching_threshold = patching_threshold + self.patching_batch_size = patching_batch_size + self.max_patch_length = max_patch_length + self.patching_device = kwargs.get("patching_device", "cuda") + self.realtime_patching = kwargs.get("realtime_patching", True) + self.patching_threshold_add = kwargs.get("patching_threshold_add") + self.monotonicity = kwargs.get("monotonicity", False) + + # Cross attention configurations + self.cross_attn_k = cross_attn_k + + # Encoder configurations + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [3, 4, 5, 6, 7, 8] + self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions + + # Initialize component configurations + if patcher_config is None: + self.patcher_config = BltPatcherConfig(initializer_range=initializer_range) + logger.info("patcher_config is None, using default Blt patcher config") + elif isinstance(patcher_config, dict): + patcher_config.setdefault("initializer_range", initializer_range) + self.patcher_config = BltPatcherConfig(**patcher_config) + elif isinstance(patcher_config, BltPatcherConfig): + self.patcher_config = patcher_config + + if encoder_config is None: + self.encoder_config = BltLocalEncoderConfig(initializer_range=initializer_range) + logger.info("encoder_config is None, using default Blt encoder config") + elif isinstance(encoder_config, dict): + encoder_config.setdefault("initializer_range", initializer_range) + self.encoder_config = BltLocalEncoderConfig(**encoder_config) + elif isinstance(encoder_config, BltLocalEncoderConfig): + self.encoder_config = encoder_config + + if decoder_config is None: + self.decoder_config = BltLocalDecoderConfig(initializer_range=initializer_range) + logger.info("decoder_config is None, using default Blt decoder config") + elif isinstance(decoder_config, dict): + 
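+ # A plain dict is promoted to a BltLocalDecoderConfig, inheriting the top-level
+ # initializer_range when the dict does not set one.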
decoder_config.setdefault("initializer_range", initializer_range) + self.decoder_config = BltLocalDecoderConfig(**decoder_config) + elif isinstance(decoder_config, BltLocalDecoderConfig): + self.decoder_config = decoder_config + + if global_config is None: + self.global_config = BltGlobalTransformerConfig(initializer_range=initializer_range) + logger.info("global_config is None, using default Blt global config") + elif isinstance(global_config, dict): + global_config.setdefault("initializer_range", initializer_range) + self.global_config = BltGlobalTransformerConfig(**global_config) + elif isinstance(global_config, BltGlobalTransformerConfig): + self.global_config = global_config + + # Determine if token embedding projection is needed based on dimension mismatch (7b) + encoder_cross_output_size = self.encoder_config.hidden_size * self.cross_attn_k + self.global_config.encoder_cross_output_size = ( + encoder_cross_output_size if encoder_cross_output_size != self.global_config.hidden_size else None + ) + + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = [ + "BltConfig", + "BltPatcherConfig", + "BltLocalEncoderConfig", + "BltLocalDecoderConfig", + "BltGlobalTransformerConfig", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blt/modeling_blt.py b/venv/lib/python3.13/site-packages/transformers/models/blt/modeling_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5459ccbff02cc132b330076c94d68d9fdd0efc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blt/modeling_blt.py @@ -0,0 +1,1311 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/blt/modular_blt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_blt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Optional, Union + +import torch +import torch.distributions +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_blt import ( + BltConfig, + BltGlobalTransformerConfig, + BltLocalDecoderConfig, + BltLocalEncoderConfig, + BltPatcherConfig, +) + + +class BltMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # Ignore copy + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BltRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BltRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BltRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: BltConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer +class BltTransformerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BltMLP(config) + self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class BltSelfAttention(nn.Module): + def __init__(self, config: BltConfig, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.is_causal = True + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_values=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class BltCrossAttention(nn.Module): + """Cross-attention module for Blt, following transformers style""" + + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = 
config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_causal = False + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(hidden_states) + query_states = self.q_proj(query_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if past_key_values is not None: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
+ ) + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states + return attn_output, attn_weights + + +@auto_docstring +class BltPreTrainedModel(PreTrainedModel): + config: BltConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["BltTransformerLayer"] + _can_compile_fullgraph = False # static cache cannot have different shapes for each layer + _supports_sdpa = True + _supports_flash_attn = False + _supports_flex_attn = False + _supports_attention_backend = False + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + } + + +class BltLocalEncoder(BltPreTrainedModel): + config: BltLocalEncoderConfig + _can_record_outputs = { + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + } + + def __init__(self, config: BltLocalEncoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + self.layers = nn.ModuleList( + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = BltRotaryEmbedding(config=config) + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.cross_attn_layers = nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size = inputs_embeds.shape[0] + hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) + + if position_ids is None: + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + 
past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) + layer_idx = idx if self.config.cross_attn_all_layers else 0 + cross_attention_output, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=encoder_attention_mask, + **kwargs, + ) + patch_embeds = patch_embeds + cross_attention_output + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states + + def patch_reduce(self, hidden_states, max_num_patches, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + batch_size = hidden_states.shape[0] + embedding_dim = hidden_states.shape[-1] + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros( + (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce="amax", + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BltLocalDecoder(BltPreTrainedModel): + config: BltLocalDecoderConfig + + def __init__(self, config: BltLocalDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + self.cross_attn_decoder = True + self.layers = nn.ModuleList( + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = BltRotaryEmbedding(config=config) + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_attn_layers = nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + self.post_init() + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + batch_size = inputs_embeds.shape[0] + 
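+ # The local decoder runs over byte-level embeddings and cross-attends to the global patch
+ # representations, which are projected from hidden_size_global into cross_attn_k slots of the
+ # local hidden size per patch.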
hidden_states = inputs_embeds + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + if position_ids is None: + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + if i == 0 or self.config.cross_attn_all_layers: + cross_attention_output, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=encoder_attention_mask, + **kwargs, + ) + hidden_states = hidden_states + cross_attention_output + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + logits = self.norm(hidden_states) + return logits + + +class BltGlobalTransformer(BltPreTrainedModel): + config: BltGlobalTransformerConfig + _can_record_outputs = { + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + def __init__(self, config: BltGlobalTransformerConfig): + super().__init__(config) + self.config = config + self.layers = nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + self.layers.append(BltTransformerLayer(config, layer_idx)) + self.rotary_emb = BltRotaryEmbedding(config=config) + + # Create token embedding projection (use nn.Identity() when no projection needed) + if getattr(config, "encoder_cross_output_size", None) is not None: + self.token_embedding_projection = nn.Linear( + config.encoder_cross_output_size, config.hidden_size, bias=False + ) + else: + self.token_embedding_projection = nn.Identity() + + self.post_init() + + def forward( + self, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + batch_size, seq_len, _ = input_embeds.shape + hidden_states = self.token_embedding_projection(input_embeds) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for i, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + return hidden_states + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. 
+ + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + + def __init__(self, config: BltPatcherConfig): + super().__init__(config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) + self.layers = nn.ModuleList() + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BltTransformerLayer(self.config, layer_idx)) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + + logits = self.lm_head(self.norm(hidden_states)) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + + batch_size, sequence_length = inputs_embeds.shape[:2] + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + 
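+ # Note (descriptive): a new patch starts wherever the patcher's next-byte entropy exceeds
+ # `threshold`; byte positions 0 and 1 always open patches (see patch_lengths_from_entropies).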
entropies=prediction_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return prediction_entropies, patch_lengths, logits + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + +def rolling_polynomial_hash(token_tensor, prime: int = 1000000007): + """ + A polynomial rolling hash algorithm that converts sequences + of tokens into hash values. The hash is computed as: + hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n) + + The rolling hash allows the model to efficiently + identify and encode recurring byte-level patterns in the input text. + + Args: + token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash + prime (int): Prime number used as the base for the polynomial hash. 
+ + Returns: + torch.Tensor: Hash values of shape [batch_size, seq_len] where each value + represents the hash of the corresponding token group + + Example: + >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> hashes = rolling_polynomial_hash(tokens, prime=31) + >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2 + >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2 + """ + prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime_tensor**powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function( + token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000 +): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, prime) + hash_values = hashes % max_hash + + return hash_values + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.Embedding, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + # Available primes for hash functions + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) + # Apply offset to get the correct slice of the fused embedding + offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab + embeddings += encoder_hash_tok_embedding(offset_hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = ( + torch.arange(num_patches, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, num_patches, seq_len) + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = ( + torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError( + f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" + ) + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + return cross_attention_mask + + +class BltModel(BltPreTrainedModel): + def __init__(self, config: BltConfig): + super().__init__(config) + self.gradient_checkpointing = False + + self.config = config + self.local_encoder = BltLocalEncoder(config.encoder_config) + self.global_transformer = BltGlobalTransformer(config.global_config) + self.local_decoder = BltLocalDecoder(config.decoder_config) + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) + total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings + self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size) + if self.config.patch_in_forward: + self.patcher = BltPatcher(config.patcher_config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + self.post_init() + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + patch_lengths: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of 
input_ids or inputs_embeds") + + # Extract input embeddings as early as possible + if inputs_embeds is not None: + encoder_embeds = inputs_embeds + batch_size, sequence_length, _ = inputs_embeds.shape + else: + batch_size, sequence_length = input_ids.shape + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + + if patch_lengths is None: + if self.config.patching_mode == "entropy" and self.patcher is not None: + if input_ids is None: + raise ValueError("input_ids is required for entropy-based patching") + _, patch_lengths, _ = self.patcher( + input_ids, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=input_ids.device, + ) + else: + device = input_ids.device if input_ids is not None else inputs_embeds.device + dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device), + self.config.max_patch_length, + ) + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=encoder_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + cross_attn_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids=patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + dtype=encoder_embeds.dtype, + ) + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=input_ids, + inputs_embeds=encoder_embeds, + attention_mask=causal_mask, + position_ids=position_ids, + encoder_attention_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + **kwargs, + ) + encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device) + global_position_ids = global_cache_position.unsqueeze(0) + global_causal_mask = create_causal_mask( + config=self.config, + input_embeds=encoder_cross_states, + attention_mask=None, + cache_position=global_cache_position, + past_key_values=None, + position_ids=None, + ) + + global_hidden_states = self.global_transformer( + input_embeds=encoder_cross_states, + attention_mask=global_causal_mask, + position_ids=global_position_ids, + **kwargs, + ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec = _prepare_patch_cross_attention_mask( + patch_ids=decoder_patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + dtype=encoder_embeds.dtype, + ) + output = self.local_decoder( + input_ids=input_ids, + 
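+ # Note (descriptive): the byte-level decoder takes the encoder hidden states as its input
+ # embeddings and attends to the global transformer output (patch_embeds) via cross-attention.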
inputs_embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + encoder_attention_mask=cross_attn_mask_dec, + **kwargs, + ) + return BaseModelOutputWithPast( + last_hidden_state=output, + past_key_values=past_key_values, + ) + + def get_input_embeddings(self): + return self.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.local_encoder.embed_tokens = value + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], + ], + dim=-1, + ) + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 + + +@auto_docstring( + custom_intro=""" + The Blt Text Model with a language modeling head on top. + """ +) +class BltForCausalLM(BltPreTrainedModel, GenerationMixin): + config: BltConfig + _can_compile_fullgraph = False + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: BltConfig): + super().__init__(config.get_text_config()) + self.text_config = config.get_text_config() + self.vocab_size = config.vocab_size + self.model = BltModel(config) + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. + cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): + Cross-attention mask to control the interaction between text tokens and image tiles. + This 4D tensor defines which image tiles each text token should attend to. + + For each text token (in seq_length): + - 1 indicates the token **should attend** to the corresponding image tile + - 0 indicates the token **should not attend** to the corresponding image tile + full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*): + A tuple containing two tensors that mask out rows in the cross-attention mechanism: + - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. 
+ A value of 0 indicates that the corresponding text token's entire row in the cross-attention + matrix should be masked out (all image tokens ignored). + - The second tensor has the same shape and is used internally to apply the masking during + the forward pass of cross-attention layers. + This mask is derived from the cross_attention_mask and is used to handle cases where a text token + should not attend to any image token. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BltForCausalLM + + >>> model = BltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` + """ + # Call parent forward but exclude cross_attention_states from model call + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/blt/modular_blt.py b/venv/lib/python3.13/site-packages/transformers/models/blt/modular_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..b9047bc852d5363a8cb0a24b1ef91dfc5cb8df8f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/blt/modular_blt.py @@ -0,0 +1,1015 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Blt modular model, inheriting from Mllama where appropriate.""" + +from typing import Callable, Optional, Union + +import torch +import torch.distributions +import torch.nn as nn +import torch.nn.functional as F + +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import OutputRecorder, check_model_inputs +from ..cohere2.modeling_cohere2 import ( + Cohere2RotaryEmbedding, + rotate_half, # noqa: F401 +) +from ..mllama.modeling_mllama import ( + MllamaForCausalLM, + MllamaPreTrainedModel, + MllamaSelfAttentionDecoderLayer, + MllamaTextCrossAttention, + MllamaTextMLP, + MllamaTextRMSNorm, + MllamaTextSelfAttention, + eager_attention_forward, +) +from .configuration_blt import ( + BltConfig, + BltGlobalTransformerConfig, + BltLocalDecoderConfig, + BltLocalEncoderConfig, + BltPatcherConfig, +) + + +logger = logging.get_logger(__name__) + + +def rolling_polynomial_hash(token_tensor, prime: int = 1000000007): + """ + A polynomial rolling hash algorithm that converts sequences + of tokens into hash values. The hash is computed as: + hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n) + + The rolling hash allows the model to efficiently + identify and encode recurring byte-level patterns in the input text. + + Args: + token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash + prime (int): Prime number used as the base for the polynomial hash. 
+ + Returns: + torch.Tensor: Hash values of shape [batch_size, seq_len] where each value + represents the hash of the corresponding token group + + Example: + >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> hashes = rolling_polynomial_hash(tokens, prime=31) + >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2 + >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2 + """ + prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime_tensor**powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function( + token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000 +): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, prime) + hash_values = hashes % max_hash + + return hash_values + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.Embedding, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + # Available primes for hash functions + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) + # Apply offset to get the correct slice of the fused embedding + offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab + embeddings += encoder_hash_tok_embedding(offset_hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = ( + torch.arange(num_patches, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, num_patches, seq_len) + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = ( + torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError( + f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" + ) + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + return cross_attention_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. 
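+
+ Example (illustrative):
+ >>> # a 5-byte patch and a 3-byte patch, re-split with max_patch_length=2
+ >>> process_patch_lengths(torch.tensor([[5, 3]]), max_patch_length=2)
+ tensor([[2, 2, 1, 2, 1]])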
+ """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BltMLP(MllamaTextMLP): + pass + + +class BltRMSNorm(MllamaTextRMSNorm): + pass + + +class BltRotaryEmbedding(Cohere2RotaryEmbedding): + pass + + +class BltTransformerLayer(MllamaSelfAttentionDecoderLayer): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BltMLP(config) + self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class BltSelfAttention(MllamaTextSelfAttention): + def __init__(self, config: BltConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.is_causal = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_values=None, + cache_position=None, + **kwargs, + ): + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + use_cache=use_cache, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + + +class BltCrossAttention(MllamaTextCrossAttention): + """Cross-attention module for Blt, following transformers style""" + + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.is_causal = False + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(hidden_states) + query_states = self.q_proj(query_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if past_key_values is not None: + key_states, value_states = past_key_values.update( 
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states + return attn_output, attn_weights + + +@auto_docstring +class BltPreTrainedModel(MllamaPreTrainedModel): + config: BltConfig + _supports_attention_backend = False + _supports_flash_attn = False + _supports_flex_attn = False + _no_split_modules = ["BltTransformerLayer"] + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + } + + def _init_weights(self, module): + raise AttributeError("No need to inherit it!") + + def _update_causal_mask(self, module): + raise AttributeError("No need to inherit it!") + + def _prepare_4d_causal_attention_mask_with_cache_position(self, module): + raise AttributeError("No need to inherit it!") + + +class BltLocalEncoder(BltPreTrainedModel): + config: BltLocalEncoderConfig + _can_record_outputs = { + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + } + + def __init__(self, config: BltLocalEncoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + self.layers = nn.ModuleList( + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = BltRotaryEmbedding(config=config) + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.cross_attn_layers = nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size = inputs_embeds.shape[0] + hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) + + if 
position_ids is None: + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) + layer_idx = idx if self.config.cross_attn_all_layers else 0 + cross_attention_output, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=encoder_attention_mask, + **kwargs, + ) + patch_embeds = patch_embeds + cross_attention_output + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states + + def patch_reduce(self, hidden_states, max_num_patches, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. 
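+ Pooling uses an element-wise max (`scatter_reduce` with reduce="amax") over the byte
+ positions assigned to each patch.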
+ """ + batch_size = hidden_states.shape[0] + embedding_dim = hidden_states.shape[-1] + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros( + (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce="amax", + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BltLocalDecoder(BltPreTrainedModel): + config: BltLocalDecoderConfig + + def __init__(self, config: BltLocalDecoderConfig): + super().__init__(config) + self.gradient_checkpointing = False + self.config = config + self.cross_attn_decoder = True + self.layers = nn.ModuleList( + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = BltRotaryEmbedding(config=config) + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_attn_layers = nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + self.post_init() + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + batch_size = inputs_embeds.shape[0] + hidden_states = inputs_embeds + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + if position_ids is None: + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + if i == 0 or self.config.cross_attn_all_layers: + cross_attention_output, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=encoder_attention_mask, + **kwargs, + ) + hidden_states = hidden_states + cross_attention_output + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + logits = self.norm(hidden_states) + return logits + + +class BltGlobalTransformer(BltPreTrainedModel): + config: BltGlobalTransformerConfig + _can_record_outputs = { + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + def __init__(self, 
config: BltGlobalTransformerConfig): + super().__init__(config) + self.config = config + self.layers = nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + self.layers.append(BltTransformerLayer(config, layer_idx)) + self.rotary_emb = BltRotaryEmbedding(config=config) + + # Create token embedding projection (use nn.Identity() when no projection needed) + if getattr(config, "encoder_cross_output_size", None) is not None: + self.token_embedding_projection = nn.Linear( + config.encoder_cross_output_size, config.hidden_size, bias=False + ) + else: + self.token_embedding_projection = nn.Identity() + + self.post_init() + + def forward( + self, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + batch_size, seq_len, _ = input_embeds.shape + hidden_states = self.token_embedding_projection(input_embeds) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for i, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + return hidden_states + + +class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + + def __init__(self, config: BltPatcherConfig): + super().__init__(config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) + self.layers = nn.ModuleList() + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BltTransformerLayer(self.config, layer_idx)) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + 
past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + + logits = self.lm_head(self.norm(hidden_states)) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + + batch_size, sequence_length = inputs_embeds.shape[:2] + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=prediction_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return prediction_entropies, patch_lengths, logits + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + +class BltModel(BltPreTrainedModel): + def __init__(self, config: BltConfig): + super().__init__(config) + self.gradient_checkpointing = False + + self.config = config + self.local_encoder = BltLocalEncoder(config.encoder_config) + self.global_transformer = BltGlobalTransformer(config.global_config) + self.local_decoder = BltLocalDecoder(config.decoder_config) + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) + total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings + self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size) + if self.config.patch_in_forward: + self.patcher = BltPatcher(config.patcher_config) + self.patcher.eval() + 
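+ # The entropy patcher is an auxiliary model used only to propose patch boundaries;
+ # it is kept in eval mode and its parameters are frozen below.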
for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + self.post_init() + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + patch_lengths: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # Extract input embeddings as early as possible + if inputs_embeds is not None: + encoder_embeds = inputs_embeds + batch_size, sequence_length, _ = inputs_embeds.shape + else: + batch_size, sequence_length = input_ids.shape + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + + if patch_lengths is None: + if self.config.patching_mode == "entropy" and self.patcher is not None: + if input_ids is None: + raise ValueError("input_ids is required for entropy-based patching") + _, patch_lengths, _ = self.patcher( + input_ids, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=input_ids.device, + ) + else: + device = input_ids.device if input_ids is not None else inputs_embeds.device + dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device), + self.config.max_patch_length, + ) + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=encoder_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + cross_attn_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids=patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + dtype=encoder_embeds.dtype, + ) + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=input_ids, + inputs_embeds=encoder_embeds, + attention_mask=causal_mask, + position_ids=position_ids, + encoder_attention_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + **kwargs, + ) + encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device) + global_position_ids = global_cache_position.unsqueeze(0) + global_causal_mask = create_causal_mask( + config=self.config, + 
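+ # The global transformer sees one embedding per patch, so this causal mask is built
+ # over the patch dimension rather than over individual bytes.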
input_embeds=encoder_cross_states, + attention_mask=None, + cache_position=global_cache_position, + past_key_values=None, + position_ids=None, + ) + + global_hidden_states = self.global_transformer( + input_embeds=encoder_cross_states, + attention_mask=global_causal_mask, + position_ids=global_position_ids, + **kwargs, + ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec = _prepare_patch_cross_attention_mask( + patch_ids=decoder_patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + dtype=encoder_embeds.dtype, + ) + output = self.local_decoder( + input_ids=input_ids, + inputs_embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + encoder_attention_mask=cross_attn_mask_dec, + **kwargs, + ) + return BaseModelOutputWithPast( + last_hidden_state=output, + past_key_values=past_key_values, + ) + + def get_input_embeddings(self): + return self.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.local_encoder.embed_tokens = value + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], + ], + dim=-1, + ) + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 + + +class BltForCausalLM(MllamaForCausalLM): + config: BltConfig + _can_compile_fullgraph = False + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: BltConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.model = BltModel(config) + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + # Call parent forward but exclude cross_attention_states from model call + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + 
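The `_patch_ids_from_lengths` helper defined above maps every byte position to the index of the patch it belongs to by comparing positions against the cumulative patch start offsets. A toy, stand-alone version of the same logic with hypothetical inputs and a worked output:

```python
import torch

def patch_ids_from_lengths(patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    batch_size = patch_lengths.shape[0]
    zeros = torch.zeros(batch_size, 1, dtype=patch_lengths.dtype)
    patch_starts = torch.cat([zeros, patch_lengths.cumsum(dim=-1)[:, :-1]], dim=-1)
    positions = torch.arange(seq_len)
    # a position's patch id is the number of patch starts at or before it, minus one
    return (patch_starts.unsqueeze(1) <= positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1

lengths = torch.tensor([[2, 3, 1]])
print(patch_ids_from_lengths(lengths, seq_len=6))  # tensor([[0, 0, 1, 1, 1, 2]])
```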
logits = self.lm_head(hidden_states[:, slice_indices, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "BltPreTrainedModel", + "BltModel", + "BltPatcher", + "BltForCausalLM", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bros/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bros/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2178cfd03a40593bc9acb97111204aab0423860c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bros/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bros import * + from .modeling_bros import * + from .processing_bros import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/bros/configuration_bros.py b/venv/lib/python3.13/site-packages/transformers/models/bros/configuration_bros.py new file mode 100644 index 0000000000000000000000000000000000000000..84c9989f309fb782b47aed37f9728dd7a26ba6b8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bros/configuration_bros.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bros model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BrosConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BrosModel`] or a [`TFBrosModel`]. It is used to + instantiate a Bros model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the Bros + [jinho8345/bros-base-uncased](https://huggingface.co/jinho8345/bros-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Bros model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BrosModel`] or [`TFBrosModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`BrosModel`] or [`TFBrosModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + The index of the padding token in the token vocabulary. + dim_bbox (`int`, *optional*, defaults to 8): + The dimension of the bounding box coordinates. (x0, y1, x1, y0, x1, y1, x0, y1) + bbox_scale (`float`, *optional*, defaults to 100.0): + The scale factor of the bounding box coordinates. + n_relations (`int`, *optional*, defaults to 1): + The number of relations for SpadeEE(entity extraction), SpadeEL(entity linking) head. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the classifier head. 
+ + + Examples: + + ```python + >>> from transformers import BrosConfig, BrosModel + + >>> # Initializing a BROS jinho8345/bros-base-uncased style configuration + >>> configuration = BrosConfig() + + >>> # Initializing a model from the jinho8345/bros-base-uncased style configuration + >>> model = BrosModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "bros" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + dim_bbox=8, + bbox_scale=100.0, + n_relations=1, + classifier_dropout_prob=0.1, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + initializer_range=initializer_range, + layer_norm_eps=layer_norm_eps, + pad_token_id=pad_token_id, + **kwargs, + ) + + self.dim_bbox = dim_bbox + self.bbox_scale = bbox_scale + self.n_relations = n_relations + self.dim_bbox_sinusoid_emb_2d = self.hidden_size // 4 + self.dim_bbox_sinusoid_emb_1d = self.dim_bbox_sinusoid_emb_2d // self.dim_bbox + self.dim_bbox_projection = self.hidden_size // self.num_attention_heads + self.classifier_dropout_prob = classifier_dropout_prob + + +__all__ = ["BrosConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bros/modeling_bros.py b/venv/lib/python3.13/site-packages/transformers/models/bros/modeling_bros.py new file mode 100644 index 0000000000000000000000000000000000000000..d01a4c5a1c6da7644e6b3433d1657f5fe4f93849 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bros/modeling_bros.py @@ -0,0 +1,1136 @@ +# coding=utf-8 +# Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
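As a quick sanity check of the derived attributes set in `BrosConfig.__init__` above, the defaults (`hidden_size=768`, `num_attention_heads=12`, `dim_bbox=8`) yield the following bounding-box embedding dimensions; a small usage sketch:

```python
from transformers import BrosConfig

config = BrosConfig()
print(config.dim_bbox_sinusoid_emb_2d)  # 768 // 4  -> 192
print(config.dim_bbox_sinusoid_emb_1d)  # 192 // 8  -> 24
print(config.dim_bbox_projection)       # 768 // 12 -> 64
```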
+"""PyTorch Bros model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging +from .configuration_bros import BrosConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of token classification models. + """ +) +class BrosSpadeOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores for entity initial tokens (before SoftMax). + subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`): + Classification scores for entity sequence tokens (before SoftMax). + """ + + loss: Optional[torch.FloatTensor] = None + initial_token_logits: Optional[torch.FloatTensor] = None + subsequent_token_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +class BrosPositionalEmbedding1D(nn.Module): + # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15 + + def __init__(self, config): + super().__init__() + + self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d + + inv_freq = 1 / ( + 10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d) + ) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, pos_seq: torch.Tensor) -> torch.Tensor: + seq_size = pos_seq.size() + b1, b2, b3 = seq_size + sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.dim_bbox_sinusoid_emb_1d // 2) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + return pos_emb + + +class BrosPositionalEmbedding2D(nn.Module): + def __init__(self, config): + super().__init__() + + self.dim_bbox = config.dim_bbox + self.x_pos_emb = BrosPositionalEmbedding1D(config) + self.y_pos_emb = BrosPositionalEmbedding1D(config) + + def forward(self, bbox: torch.Tensor) -> torch.Tensor: + stack = [] + for i in range(self.dim_bbox): + if i % 2 == 0: + stack.append(self.x_pos_emb(bbox[..., i])) + else: + stack.append(self.y_pos_emb(bbox[..., i])) + bbox_pos_emb = torch.cat(stack, dim=-1) + return bbox_pos_emb + + +class BrosBboxEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config) + self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False) + + def forward(self, bbox: torch.Tensor): + bbox_t = bbox.transpose(0, 1) + bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :] + bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos) + bbox_pos_emb = self.bbox_projection(bbox_pos_emb) + + return bbox_pos_emb + + +class BrosTextEmbeddings(nn.Module): + """Construct the embeddings from 
word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.register_buffer( + "token_type_ids", + torch.zeros( + self.position_ids.size(), + dtype=torch.long, + device=self.position_ids.device, + ), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BrosSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * 
config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def forward( + self, + hidden_states: torch.Tensor, + bbox_pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[torch.Tensor] = False, + ) -> tuple[torch.Tensor]: + hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + attention_mask = encoder_attention_mask + else: + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + # bbox positional encoding + batch_size, n_head, seq_length, d_head = query_layer.shape + bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head) + bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3]) + bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb)) + + attention_scores = attention_scores + bbox_pos_scores + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BrosModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
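The `relative_key` branch above scores each query against a learned embedding of the pairwise token distance. A minimal sketch of that lookup and einsum with toy sizes (names and dimensions here are illustrative, not the Bros defaults):

```python
import torch
from torch import nn

seq_length, head_size, max_pos = 4, 8, 16
distance_embedding = nn.Embedding(2 * max_pos - 1, head_size)

position_ids_l = torch.arange(seq_length).view(-1, 1)
position_ids_r = torch.arange(seq_length).view(1, -1)
distance = position_ids_l - position_ids_r                         # (seq, seq), values in [-(seq-1), seq-1]
positional_embedding = distance_embedding(distance + max_pos - 1)  # shift so indices are non-negative

query_layer = torch.randn(1, 2, seq_length, head_size)             # (batch, heads, seq, head_size)
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(relative_position_scores.shape)                              # torch.Size([1, 2, 4, 4])
```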
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (None,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros +class BrosSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BrosAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = BrosSelfAttention(config) + self.output = BrosSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + bbox_pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states=hidden_states, + bbox_pos_emb=bbox_pos_emb, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros +class BrosIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = 
self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BrosOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BrosLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BrosAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise Exception(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BrosAttention(config) + self.intermediate = BrosIntermediate(config) + self.output = BrosOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + bbox_pos_emb: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + bbox_pos_emb=bbox_pos_emb, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if hasattr(self, "crossattention"): + raise Exception( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (None,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BrosEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)]) + + @can_return_tuple + def forward( + self, + hidden_states: 
torch.Tensor, + bbox_pos_emb: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + bbox_pos_emb=bbox_pos_emb, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Bros +class BrosPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BrosRelationExtractor(nn.Module): + def __init__(self, config): + super().__init__() + self.n_relations = config.n_relations + self.backbone_hidden_size = config.hidden_size + self.head_hidden_size = config.hidden_size + self.classifier_dropout_prob = config.classifier_dropout_prob + + self.drop = nn.Dropout(self.classifier_dropout_prob) + self.query = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size) + + self.key = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size) + + self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size)) + + def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor): + query_layer = self.query(self.drop(query_layer)) + + dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1) + key_layer = torch.cat([key_layer, dummy_vec], axis=0) + key_layer = self.key(self.drop(key_layer)) + + query_layer = query_layer.view( + query_layer.size(0), query_layer.size(1), self.n_relations, self.head_hidden_size + ) + key_layer = key_layer.view(key_layer.size(0), key_layer.size(1), self.n_relations, self.head_hidden_size) + + relation_score = torch.matmul( + query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0) + ) # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer)) + + return relation_score + + +@auto_docstring +class BrosPreTrainedModel(PreTrainedModel): + config: BrosConfig + base_model_prefix = "bros" + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BrosRelationExtractor): + nn.init.normal_(module.dummy_node, std=std) + + +@auto_docstring +class BrosModel(BrosPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = BrosTextEmbeddings(config) + self.bbox_embeddings = BrosBboxEmbeddings(config) + self.encoder = BrosEncoder(config) + + self.pooler = BrosPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
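The comment in `BrosRelationExtractor` above notes that the permuted matmul is equivalent to an einsum over query/key tensors laid out as (seq, batch, n_relations, head_size). A small sketch verifying that equivalence with toy shapes (the extra key position stands in for the dummy node):

```python
import torch

seq_len, batch, n_relations, head_size = 5, 2, 1, 16
query_layer = torch.randn(seq_len, batch, n_relations, head_size)
key_layer = torch.randn(seq_len + 1, batch, n_relations, head_size)  # +1 for the dummy node

score_matmul = torch.matmul(query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0))
score_einsum = torch.einsum("ibnd,jbnd->nbij", query_layer, key_layer)
print(score_matmul.shape)                          # torch.Size([1, 2, 5, 6])
print(torch.allclose(score_matmul, score_einsum))  # True
```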
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + bbox: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'): + Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values + (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the + bounding box. + + Examples: + + ```python + >>> import torch + >>> from transformers import BrosProcessor, BrosModel + + >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased") + + >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased") + + >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt") + >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1) + >>> encoding["bbox"] = bbox + + >>> outputs = model(**encoding) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if bbox is None: + raise ValueError("You have to specify bbox") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
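Further down in `BrosModel.forward`, a bbox given as two points per token is expanded to the four corner points before being scaled by `config.bbox_scale`. A tiny illustration of that index trick with a made-up box:

```python
import torch

bbox = torch.tensor([[[0.1, 0.2, 0.5, 0.6]]])    # (batch, seq, 4): (x0, y0, x1, y1)
if bbox.shape[-1] == 4:
    bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]  # (batch, seq, 8): the four corner points
print(bbox)
# tensor([[[0.1000, 0.2000, 0.5000, 0.2000, 0.5000, 0.6000, 0.1000, 0.6000]]])
```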
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token + if bbox.shape[-1] == 4: + bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]] + scaled_bbox = bbox * self.config.bbox_scale + bbox_position_embeddings = self.bbox_embeddings(scaled_bbox) + + encoder_outputs = self.encoder( + embedding_output, + bbox_pos_emb=bbox_position_embeddings, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring +class BrosForTokenClassification(BrosPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bros = BrosModel(config) + classifier_dropout = ( + config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + bbox: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + bbox_first_token_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + bbox 
('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'): + Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values + (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the + bounding box. + bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Examples: + + ```python + >>> import torch + >>> from transformers import BrosProcessor, BrosForTokenClassification + + >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased") + + >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased") + + >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt") + >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1) + >>> encoding["bbox"] = bbox + + >>> outputs = model(**encoding) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bros( + input_ids, + bbox=bbox, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + if bbox_first_token_mask is not None: + bbox_first_token_mask = bbox_first_token_mask.view(-1) + loss = loss_fct( + logits.view(-1, self.num_labels)[bbox_first_token_mask], labels.view(-1)[bbox_first_token_mask] + ) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the + hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to + predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent + tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors + since it predicts next token from one token. 
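The loss in `BrosForTokenClassification.forward` above restricts the cross-entropy to the first token of each bounding box whenever `bbox_first_token_mask` is given. A toy sketch of that masked loss with hypothetical shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(1, 4, num_labels)  # (batch, seq, num_labels)
labels = torch.tensor([[0, 2, 1, 0]])
bbox_first_token_mask = torch.tensor([[True, False, True, False]])

loss_fct = CrossEntropyLoss()
mask = bbox_first_token_mask.view(-1)
loss = loss_fct(logits.view(-1, num_labels)[mask], labels.view(-1)[mask])
print(loss)  # scalar loss computed over the two first-of-box tokens only
```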
+ """ +) +class BrosSpadeEEForTokenClassification(BrosPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.config = config + self.num_labels = config.num_labels + self.n_relations = config.n_relations + self.backbone_hidden_size = config.hidden_size + + self.bros = BrosModel(config) + classifier_dropout = ( + config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob + ) + + # Initial token classification for Entity Extraction (NER) + self.initial_token_classifier = nn.Sequential( + nn.Dropout(classifier_dropout), + nn.Linear(config.hidden_size, config.hidden_size), + nn.Dropout(classifier_dropout), + nn.Linear(config.hidden_size, config.num_labels), + ) + + # Subsequent token classification for Entity Extraction (NER) + self.subsequent_token_classifier = BrosRelationExtractor(config) + + self.init_weights() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + bbox: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + bbox_first_token_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + initial_token_labels: Optional[torch.Tensor] = None, + subsequent_token_labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BrosSpadeOutput]: + r""" + bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'): + Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values + (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the + bounding box. + bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + initial_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for the initial token classification. + subsequent_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for the subsequent token classification. 
+ + Examples: + + ```python + >>> import torch + >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification + + >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased") + + >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased") + + >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt") + >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1) + >>> encoding["bbox"] = bbox + + >>> outputs = model(**encoding) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bros( + input_ids=input_ids, + bbox=bbox, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + last_hidden_states = outputs[0] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + initial_token_logits = self.initial_token_classifier(last_hidden_states).transpose(0, 1).contiguous() + subsequent_token_logits = self.subsequent_token_classifier(last_hidden_states, last_hidden_states).squeeze(0) + + # make subsequent token (sequence token classification) mask + inv_attention_mask = 1 - attention_mask + batch_size, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([batch_size, 1]).to(device)], axis=1).bool() + subsequent_token_logits = subsequent_token_logits.masked_fill( + invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min + ) + self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool) + subsequent_token_logits = subsequent_token_logits.masked_fill( + self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min + ) + subsequent_token_mask = attention_mask.view(-1).bool() + + loss = None + if initial_token_labels is not None and subsequent_token_labels is not None: + loss_fct = CrossEntropyLoss() + + # get initial token loss + initial_token_labels = initial_token_labels.view(-1) + if bbox_first_token_mask is not None: + bbox_first_token_mask = bbox_first_token_mask.view(-1) + initial_token_loss = loss_fct( + initial_token_logits.view(-1, self.num_labels)[bbox_first_token_mask], + initial_token_labels[bbox_first_token_mask], + ) + else: + initial_token_loss = loss_fct(initial_token_logits.view(-1, self.num_labels), initial_token_labels) + + subsequent_token_labels = subsequent_token_labels.view(-1) + subsequent_token_loss = loss_fct( + subsequent_token_logits.view(-1, max_seq_length + 1)[subsequent_token_mask], + subsequent_token_labels[subsequent_token_mask], + ) + + loss = initial_token_loss + subsequent_token_loss + + return BrosSpadeOutput( + loss=loss, + initial_token_logits=initial_token_logits, + subsequent_token_logits=subsequent_token_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bros Model with a token classification head on top (a entity_linker layer on top of the hidden-states output) e.g. + for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity). 
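The SPADE entity-extraction head above forbids a token from linking to itself by filling the diagonal of the subsequent-token logits with the dtype minimum. A compact sketch of that masking step with toy sizes:

```python
import torch

batch_size, max_seq_length = 1, 4
subsequent_token_logits = torch.randn(batch_size, max_seq_length, max_seq_length + 1)
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).bool()
subsequent_token_logits = subsequent_token_logits.masked_fill(
    self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
)
print(subsequent_token_logits[0, 0, 0])  # tensor(-3.4028e+38): a token can never point to itself
```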
+ """ +) +class BrosSpadeELForTokenClassification(BrosPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.config = config + self.num_labels = config.num_labels + self.n_relations = config.n_relations + self.backbone_hidden_size = config.hidden_size + + self.bros = BrosModel(config) + (config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob) + + self.entity_linker = BrosRelationExtractor(config) + + self.init_weights() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + bbox: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + bbox_first_token_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'): + Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values + (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the + bounding box. + bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Examples: + + ```python + >>> import torch + >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification + + >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased") + + >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased") + + >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt") + >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1) + >>> encoding["bbox"] = bbox + + >>> outputs = model(**encoding) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bros( + input_ids=input_ids, + bbox=bbox, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + last_hidden_states = outputs[0] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + batch_size, max_seq_length = attention_mask.shape + device = attention_mask.device + + self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool) + + mask = bbox_first_token_mask.view(-1) + bbox_first_token_mask = torch.cat( + [ + ~bbox_first_token_mask, + torch.zeros([batch_size, 1], dtype=torch.bool, device=device), + ], + axis=1, + ) + logits = logits.masked_fill(bbox_first_token_mask[:, None, :], torch.finfo(logits.dtype).min) + logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min) + + loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask]) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "BrosPreTrainedModel", + "BrosModel", + "BrosForTokenClassification", + "BrosSpadeEEForTokenClassification", + "BrosSpadeELForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/bros/processing_bros.py b/venv/lib/python3.13/site-packages/transformers/models/bros/processing_bros.py new file mode 100644 index 0000000000000000000000000000000000000000..8de0a1c49b0d0adcb3650bb8fba55582d4fa9588 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/bros/processing_bros.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Bros. 
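A hedged usage sketch for the processor defined below: `BrosProcessor` simply wraps a BERT tokenizer, so calling it behaves like the tokenizer, with the defaults declared in `BrosProcessorKwargs` (e.g. `add_special_tokens=True`) applied unless overridden:

```python
from transformers import BrosProcessor

processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
print(encoding["input_ids"].shape)  # (1, num_tokens)
```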
+""" + +from ...processing_utils import ProcessingKwargs, ProcessorMixin + + +class BrosProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + } + + +class BrosProcessor(ProcessorMixin): + r""" + Constructs a Bros processor which wraps a BERT tokenizer. + + [`BrosProcessor`] offers all the functionalities of [`BertTokenizerFast`]. See the docstring of + [`~BrosProcessor.__call__`] and [`~BrosProcessor.decode`] for more information. + + Args: + tokenizer (`BertTokenizerFast`, *optional*): + An instance of ['BertTokenizerFast`]. The tokenizer is a required input. + """ + + attributes = ["tokenizer"] + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + valid_processor_kwargs = BrosProcessorKwargs + + def __init__(self, tokenizer=None, **kwargs): + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(tokenizer) + + +__all__ = ["BrosProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d90f64de97f78ccaf1592c3c5cad40a9b2d5dcb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_camembert import * + from .modeling_camembert import * + from .modeling_tf_camembert import * + from .tokenization_camembert import * + from .tokenization_camembert_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/configuration_camembert.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/configuration_camembert.py new file mode 100644 index 0000000000000000000000000000000000000000..3979e54874439aa41feb6abe3815df5c7a997419 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/configuration_camembert.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CamemBERT configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class CamembertConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`CamembertModel`] or a [`TFCamembertModel`]. It is + used to instantiate a Camembert model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Camembert + [almanach/camembert-base](https://huggingface.co/almanach/camembert-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`CamembertModel`] or [`TFCamembertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`CamembertModel`] or [`TFCamembertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Example: + + ```python + >>> from transformers import CamembertConfig, CamembertModel + + >>> # Initializing a Camembert almanach/camembert-base style configuration + >>> configuration = CamembertConfig() + + >>> # Initializing a model (with random weights) from the almanach/camembert-base style configuration + >>> model = CamembertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "camembert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class CamembertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["CamembertConfig", "CamembertOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_camembert.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_camembert.py new file mode 100644 index 0000000000000000000000000000000000000000..3a07402f739a603ad3b8a50a197b7784bfc1d2a7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_camembert.py @@ -0,0 +1,1576 @@ +# coding=utf-8 +# Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch CamemBERT model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_camembert import CamembertConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Camembert +class CamembertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
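+                # Non-padding tokens receive positions padding_idx + 1, padding_idx + 2, ... (shifted by any cached past length), while padding tokens keep position padding_idx.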
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Camembert +class CamembertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + 
attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
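+        # query_layer has shape (batch, num_heads, q_len, head_dim); key_layer.transpose(-1, -2) has shape (batch, num_heads, head_dim, k_len), so the scores are (batch, num_heads, q_len, k_len).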
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in CamembertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert +class CamembertSdpaSelfAttention(CamembertSelfAttention): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) + self.dropout_prob = config.attention_probs_dropout_prob + + # Adapted from CamembertSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. 
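+            # None of these options is supported by torch.nn.functional.scaled_dot_product_attention, so warn and fall back to the eager (manual) attention implementation.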
+ logger.warning_once( + "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + past_key_values, + output_attentions, + cache_position, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
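+        # In short: only decoder self-attention with no explicit attention mask and more than one query token is dispatched as causal.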
+ is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + return attn_output, None + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert +class CamembertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +CAMEMBERT_SELF_ATTENTION_CLASSES = { + "eager": CamembertSelfAttention, + "sdpa": CamembertSdpaSelfAttention, +} + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT +class CamembertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = CamembertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Roberta->Camembert +class CamembertIntermediate(nn.Module): + 
def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Roberta->Camembert +class CamembertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert +class CamembertLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = CamembertAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = CamembertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) + self.intermediate = CamembertIntermediate(config) + self.output = CamembertOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention 
weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->Camembert +class CamembertEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList([CamembertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class CamembertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class CamembertPreTrainedModel(PreTrainedModel): + config: CamembertConfig + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _supports_sdpa = True + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->CamembertLMHead + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, CamembertLMHead): + module.bias.data.zero_() + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Camembert +class CamembertClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = 
nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Camembert +class CamembertLMHead(nn.Module): + """Camembert Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@auto_docstring +class CamembertModel(CamembertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to + `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762 + + """ + + _no_split_modules = [] + + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = CamembertEmbeddings(config) + self.encoder = CamembertEncoder(config) + + self.pooler = CamembertPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
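+            # The expanded mask is additive: masked positions are filled with the dtype's minimum value so that softmax assigns them (near-)zero attention weight.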
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring +# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT +class CamembertForMaskedLM(CamembertPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `CamembertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.roberta = CamembertModel(config, add_pooling_layer=False) + self.lm_head = CamembertLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """ +) +# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT +class CamembertForSequenceClassification(CamembertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta = CamembertModel(config, add_pooling_layer=False) + self.classifier = CamembertClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
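+
+        Example (a minimal, illustrative sketch of single-sentence classification; the `almanach/camembert-base` checkpoint and `num_labels=2` are assumptions, and the classification head is randomly initialized when loading from the base checkpoint, so the predicted class is not meaningful until fine-tuning):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, CamembertForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
+        >>> model = CamembertForSequenceClassification.from_pretrained("almanach/camembert-base", num_labels=2)
+
+        >>> inputs = tokenizer("J'aime le camembert !", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> predicted_class_id = outputs.logits.argmax(dim=-1).item()
+        ```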
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT +class CamembertForMultipleChoice(CamembertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roberta = CamembertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. 
All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(reshaped_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT +class CamembertForTokenClassification(CamembertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = CamembertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + 
+ @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT +class CamembertForQuestionAnswering(CamembertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = CamembertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + token_type_ids (`torch.LongTensor` 
of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + CamemBERT Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, FacebookAI/roberta-base->almanach/camembert-base +class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `CamembertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = CamembertModel(config, add_pooling_layer=False) + self.lm_head = CamembertLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base") + >>> config = AutoConfig.from_pretrained("almanach/camembert-base") + >>> config.is_decoder = True + >>> model = CamembertForCausalLM.from_pretrained("almanach/camembert-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
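+    # Worked example (added for illustration), with padding_idx = 1 and past_key_values_length = 0:
+    #   input_ids                  = [[5, 30, 4, 1, 1]]
+    #   mask                       = [[1,  1, 1, 0, 0]]   # 1 for real tokens, 0 for padding
+    #   cumsum(mask, dim=1) * mask = [[1,  2, 3, 0, 0]]   # running count, re-zeroed on padding
+    #   result (+ padding_idx)     = [[2,  3, 4, 1, 1]]   # positions start at padding_idx + 1; pads keep padding_idx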
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "CamembertForCausalLM", + "CamembertForMaskedLM", + "CamembertForMultipleChoice", + "CamembertForQuestionAnswering", + "CamembertForSequenceClassification", + "CamembertForTokenClassification", + "CamembertModel", + "CamembertPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_tf_camembert.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_tf_camembert.py new file mode 100644 index 0000000000000000000000000000000000000000..0869902aa96245ead1bb22f2e818517288a35534 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/modeling_tf_camembert.py @@ -0,0 +1,1800 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 CamemBERT model.""" + +from __future__ import annotations + +import math +import warnings + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_camembert import CamembertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "almanach/camembert-base" +_CONFIG_FOR_DOC = "CamembertConfig" + + +CAMEMBERT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. 
+ + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`CamembertConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CAMEMBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings +class TFCamembertEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. 
+ + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Camembert +class TFCamembertPooler(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
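+        # (Added shape note) hidden_states is (batch_size, seq_len, hidden_size); taking index 0 keeps
+        # only the leading <s> token, giving (batch_size, hidden_size), which is then passed through a
+        # tanh-activated dense layer of the same width.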
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Camembert +class TFCamembertSelfAttention(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
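+        # (Added note) The branches below cover four cases: cross-attention with cached encoder K/V,
+        # cross-attention computed from `encoder_hidden_states`, cached decoder self-attention where the
+        # new K/V are concatenated with the past along the sequence axis, and plain self-attention.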
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFCamembertModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
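+        # (Added note) attention_probs has shape (batch_size, num_heads, seq_len_q, seq_len_k) and each
+        # row sums to 1 after the softmax above; zeroing an entry here removes that key token's entire
+        # value vector from the corresponding query's weighted sum below.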
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Camembert +class TFCamembertSelfOutput(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Camembert +class TFCamembertAttention(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFCamembertSelfAttention(config, name="self") + self.dense_output = TFCamembertSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + 
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Camembert +class TFCamembertIntermediate(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Camembert +class TFCamembertOutput(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Camembert +class TFCamembertLayer(keras.layers.Layer): + def __init__(self, config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFCamembertAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFCamembertAttention(config, name="crossattention") + self.intermediate = TFCamembertIntermediate(config, name="intermediate") + 
self.bert_output = TFCamembertOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Camembert +class TFCamembertEncoder(keras.layers.Layer): + def __init__(self, 
config: CamembertConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFCamembertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: tuple[tuple[tf.Tensor]] | None, + use_cache: bool | None, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->Camembert +class TFCamembertMainLayer(keras.layers.Layer): + config_class = CamembertConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFCamembertEncoder(config, name="encoder") + self.pooler = TFCamembertPooler(config, name="pooler") if add_pooling_layer else None + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFCamembertEmbeddings(config, name="embeddings") + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings + def get_input_embeddings(self) -> keras.layers.Layer: + return 
self.embeddings + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
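+        # Worked example (added for illustration), encoder case: a padding mask [[1, 1, 0]] is reshaped
+        # to (batch_size, 1, 1, seq_len) and mapped through (1.0 - mask) * -10000.0, giving
+        # [[[[0.0, 0.0, -10000.0]]]]; once added to the raw attention scores, the padded position is
+        # effectively excluded from the softmax.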
+ attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + + +class TFCamembertPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = CamembertConfig + base_model_prefix = "roberta" + + +@add_start_docstrings( + "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.", + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertModel(TFCamembertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.roberta = TFCamembertMainLayer(config, name="roberta") + + @unpack_inputs + @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
Set to `False` during training, `True` during generation + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Camembert +class TFCamembertLMHead(keras.layers.Layer): + """Camembert Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings( + """CamemBERT Model with a `language modeling` head on top.""", + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForMaskedLM(TFCamembertPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' 
represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.lm_head = TFCamembertLMHead(config, self.roberta.embeddings, name="lm_head")
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+        return self.name + "/" + self.lm_head.name
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFMaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.1,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFMaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "roberta", None) is not None:
+            with tf.name_scope(self.roberta.name):
+                self.roberta.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
+
+
+# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead
+class TFCamembertClassificationHead(keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        classifier_dropout = (
+            config.classifier_dropout if
config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.out_proj = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" + ) + self.config = config + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForSequenceClassification(TFCamembertPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") + self.classifier = TFCamembertClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="cardiffnlp/twitter-roberta-base-emotion", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'optimism'", + expected_loss=0.08, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
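+
+        A minimal usage sketch (illustrative only; the checkpoint name and the label value below are
+        assumptions, not part of the original documentation):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFCamembertForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("camembert-base")
+        >>> model = TFCamembertForSequenceClassification.from_pretrained("camembert-base", num_labels=2)
+        >>> inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
+        >>> outputs = model(**inputs, labels=tf.constant([1]))  # loss computed against the provided label
+        >>> loss, logits = outputs.loss, outputs.logits
+        ```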
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + CamemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForTokenClassification(TFCamembertPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="ydshieh/roberta-large-ner-english", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + CamemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForMultipleChoice(TFCamembertPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFCamembertMainLayer(config, name="roberta") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward( + CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + outputs = self.roberta( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + CAMEMBERT_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForQuestionAnswering(TFCamembertPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' 
represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.qa_outputs = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="ydshieh/roberta-base-squad2",
+        output_type=TFQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="' puppet'",
+        expected_loss=0.86,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        start_positions: np.ndarray | tf.Tensor | None = None,
+        end_positions: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+        r"""
+        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
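+
+        A short usage sketch (illustrative; the checkpoint name is an assumption, and the predicted span is
+        only meaningful with a QA-finetuned checkpoint):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFCamembertForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("camembert-base")
+        >>> model = TFCamembertForQuestionAnswering.from_pretrained("camembert-base")
+        >>> question, context = "Où vit le chat ?", "Le chat vit à Paris."
+        >>> inputs = tokenizer(question, context, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> start = int(tf.argmax(outputs.start_logits, axis=-1)[0])  # index of the span start
+        >>> end = int(tf.argmax(outputs.end_logits, axis=-1)[0])  # index of the span end
+        ```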
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT +class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config: CamembertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFCamembertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFCamembertLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. 
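+
+        A minimal usage sketch (illustrative; the checkpoint name is an assumption, and `is_decoder=True`
+        mirrors the warning emitted in `__init__`):
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFCamembertForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("camembert-base")
+        >>> model = TFCamembertForCausalLM.from_pretrained("camembert-base", is_decoder=True)
+        >>> inputs = tokenizer("Le camembert est", return_tensors="tf")
+        >>> outputs = model(**inputs, labels=inputs["input_ids"])  # labels are shifted left inside `call`
+        >>> loss, logits = outputs.loss, outputs.logits
+        ```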
+ """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + logits = self.lm_head(hidden_states=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +__all__ = [ + "TFCamembertForCausalLM", + "TFCamembertForMaskedLM", + "TFCamembertForMultipleChoice", + "TFCamembertForQuestionAnswering", + "TFCamembertForSequenceClassification", + "TFCamembertForTokenClassification", + "TFCamembertModel", + "TFCamembertPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6e399f208df858d647dce23fe0ad74248d80b8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert.py @@ -0,0 +1,323 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for Camembert model.""" + +import os +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class CamembertTokenizer(PreTrainedTokenizer): + """ + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Construct a CamemBERT tokenizer. 
Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        additional_special_tokens (`list[str]`, *optional*, defaults to `['<s>NOTUSED', '</s>NOTUSED', '<unk>NOTUSED']`):
+            Additional special tokens used by the tokenizer.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
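+
+    A short usage sketch (illustrative; assumes the `camembert-base` checkpoint and its
+    `sentencepiece.bpe.model` are available):
+
+    ```python
+    >>> from transformers import CamembertTokenizer
+
+    >>> tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
+    >>> tokens = tokenizer.tokenize("J'aime le camembert !")
+    >>> ids = tokenizer.convert_tokens_to_ids(tokens)
+    >>> text = tokenizer.decode(tokenizer.encode("J'aime le camembert !"))  # adds <s> ... </s> around the ids
+    ```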
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + additional_special_tokens=["NOTUSED", "NOTUSED", "NOTUSED"], + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. include the space before it + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False, special=True) + if isinstance(mask_token, str) + else mask_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # HACK: These tokens were added by the author for an obscure reason as they were already part of the + # sentencepiece vocabulary (this is the case for and and ). + # In this case it is recommended to properly set the tokens by hand. + self._added_tokens_decoder = { + 0: AddedToken("NOTUSED", special=True), + 1: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token, + 2: AddedToken("NOTUSED", special=True), + 3: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token, + 4: AddedToken("NOTUSED", special=True), + } + + self.fairseq_offset = 4 # 3 tokens are newly added, but the offset starts from 4 + + # legacy: camemebert is a particular case were we have to make sure `"NOTUSED"` is here + if "added_tokens_decoder" in kwargs: + # this is the only class that requires this unfortunately..... + # the reason is that the fast version has a whole. + kwargs["added_tokens_decoder"].update(self._added_tokens_decoder) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self): + # The length of the vocabulary without added tokens is len(self.sp_model) but the added tokens are added at the beginning. + return len(self.sp_model) + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.fairseq_offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + # specific to camembert, both 3 and 4 point to the unk token. 
+        if self.sp_model.PieceToId(token) == 0:
+            # Convert sentence piece unk token to fairseq unk token index
+            return self.unk_token_id
+        return self.fairseq_offset + self.sp_model.PieceToId(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        # TODO decode outputs do not match between fast and slow
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A CamemBERT sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. CamemBERT, like + RoBERTa, does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["CamembertTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert_fast.py b/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..423058ed959aeb3e7122fb3730a6d0f1f57b5982 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/camembert/tokenization_camembert_fast.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Fast tokenization classes for Camembert model.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_camembert import CamembertTokenizer +else: + CamembertTokenizer = None + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + + +SPIECE_UNDERLINE = "▁" + + +class CamembertTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" CamemBERT tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from + [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. 
Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+            Additional special tokens used by the tokenizer.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = CamembertTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED", "<unk>NOTUSED"],
+        **kwargs,
+    ):
+        # The mask token behaves like a normal word, i.e. it includes the space before it. Will have normalized = False
+        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens.
A CamemBERT sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. CamemBERT, like
+        RoBERTa, does not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["CamembertTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/canine/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/canine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb00d8fd8827035ff7b351db49d7128a54429c53
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/canine/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
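+# The TYPE_CHECKING branch below gives static analyzers real imports, while at runtime the module is
+# swapped for a _LazyModule so the heavy submodules are only imported on first attribute access.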
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_canine import * + from .modeling_canine import * + from .tokenization_canine import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/canine/configuration_canine.py b/venv/lib/python3.13/site-packages/transformers/models/canine/configuration_canine.py new file mode 100644 index 0000000000000000000000000000000000000000..29e90327d08f02d490006a464cb566adbca52007 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/canine/configuration_canine.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CANINE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class CanineConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CanineModel`]. It is used to instantiate an + CANINE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the CANINE + [google/canine-s](https://huggingface.co/google/canine-s) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the deep Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoders. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoders. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoders, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 16384): + The maximum sequence length that this model might ever be used with. 
+ type_vocab_size (`int`, *optional*, defaults to 16): + The vocabulary size of the `token_type_ids` passed when calling [`CanineModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 57344): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 57345): + End of stream token id. + downsampling_rate (`int`, *optional*, defaults to 4): + The rate at which to downsample the original character sequence length before applying the deep Transformer + encoder. + upsampling_kernel_size (`int`, *optional*, defaults to 4): + The kernel size (i.e. the number of characters in each window) of the convolutional projection layer when + projecting back from `hidden_size`*2 to `hidden_size`. + num_hash_functions (`int`, *optional*, defaults to 8): + The number of hash functions to use. Each hash function has its own embedding matrix. + num_hash_buckets (`int`, *optional*, defaults to 16384): + The number of hash buckets to use. + local_transformer_stride (`int`, *optional*, defaults to 128): + The stride of the local attention of the first shallow Transformer encoder. Defaults to 128 for good + TPU/XLA memory alignment. + + Example: + + ```python + >>> from transformers import CanineConfig, CanineModel + + >>> # Initializing a CANINE google/canine-s style configuration + >>> configuration = CanineConfig() + + >>> # Initializing a model (with random weights) from the google/canine-s style configuration + >>> model = CanineModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "canine" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=16384, + type_vocab_size=16, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + bos_token_id=0xE000, + eos_token_id=0xE001, + downsampling_rate=4, + upsampling_kernel_size=4, + num_hash_functions=8, + num_hash_buckets=16384, + local_transformer_stride=128, # Good TPU/XLA memory alignment. 
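+        # Note: the bos/eos defaults above are the private-use codepoints U+E000/U+E001, i.e. the
+        # decimal values 57344/57345 quoted in the docstring.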
+ **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + + # Character config: + self.downsampling_rate = downsampling_rate + self.upsampling_kernel_size = upsampling_kernel_size + self.num_hash_functions = num_hash_functions + self.num_hash_buckets = num_hash_buckets + self.local_transformer_stride = local_transformer_stride + + +__all__ = ["CanineConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/canine/modeling_canine.py b/venv/lib/python3.13/site-packages/transformers/models/canine/modeling_canine.py new file mode 100644 index 0000000000000000000000000000000000000000..585961180f9e914bc1a0670306ab94d9f805fb6e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/canine/modeling_canine.py @@ -0,0 +1,1549 @@ +# coding=utf-8 +# Copyright 2021 Google AI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch CANINE model.""" + +import copy +import math +import os +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + ModelOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from .configuration_canine import CanineConfig + + +logger = logging.get_logger(__name__) + + +# Support up to 16 hash functions. +_PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223] + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly + different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow + Transformer encoders. + """ +) +class CanineModelOutputWithPooling(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final + shallow Transformer encoder). 
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Hidden-state of the first token of the sequence (classification token) at the last layer of the deep + Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer + weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each + encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length // + config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the + initial input to each Transformer encoder. The hidden states of the shallow encoders have length + `sequence_length`, but the hidden states of the deep encoder have length `sequence_length` // + `config.downsampling_rate`. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size, + num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length // + config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attentions weights after the + attention softmax, used to compute the weighted average in the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +def load_tf_weights_in_canine(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + # also discard the cls weights (which were used for the next sentence prediction pre-training task) + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + "cls", + "autoregressive_decoder", + "char_output_weights", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + # if first scope name starts with "bert", change it to "encoder" + if name[0] == "bert": + name[0] = "encoder" + # remove "embeddings" middle name of HashBucketCodepointEmbedders + elif name[1] == "embeddings": + name.remove(name[1]) + # rename segment_embeddings to token_type_embeddings + elif name[1] == "segment_embeddings": + name[1] = "token_type_embeddings" + # rename initial convolutional projection layer + elif name[1] == "initial_char_encoder": + name = ["chars_to_molecules"] + name[-2:] + # rename final convolutional projection layer + elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]: + name = ["projection"] + name[1:] + pointer = model + for m_name in name: + if (re.fullmatch(r"[A-Za-z]+_\d+", m_name)) and "Embedder" not in m_name: + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]: + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class CanineEmbeddings(nn.Module): + """Construct the character, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + + self.config = config + + # character embeddings + shard_embedding_size = config.hidden_size // config.num_hash_functions + for i in range(config.num_hash_functions): + name = f"HashBucketCodepointEmbedder_{i}" + setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size)) + self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable 
name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int): + """ + Converts ids to hash bucket ids via multiple hashing. + + Args: + input_ids: The codepoints or other IDs to be hashed. + num_hashes: The number of hash functions to use. + num_buckets: The number of hash buckets (i.e. embeddings in each table). + + Returns: + A list of tensors, each of which is the hash bucket IDs from one hash function. + """ + if num_hashes > len(_PRIMES): + raise ValueError(f"`num_hashes` must be <= {len(_PRIMES)}") + + primes = _PRIMES[:num_hashes] + + result_tensors = [] + for prime in primes: + hashed = ((input_ids + 1) * prime) % num_buckets + result_tensors.append(hashed) + return result_tensors + + def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, num_buckets: int): + """Converts IDs (e.g. codepoints) into embeddings via multiple hashing.""" + if embedding_size % num_hashes != 0: + raise ValueError(f"Expected `embedding_size` ({embedding_size}) % `num_hashes` ({num_hashes}) == 0") + + hash_bucket_tensors = self._hash_bucket_tensors(input_ids, num_hashes=num_hashes, num_buckets=num_buckets) + embedding_shards = [] + for i, hash_bucket_ids in enumerate(hash_bucket_tensors): + name = f"HashBucketCodepointEmbedder_{i}" + shard_embeddings = getattr(self, name)(hash_bucket_ids) + embedding_shards.append(shard_embeddings) + + return torch.cat(embedding_shards, dim=-1) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self._embed_hash_buckets( + input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets + ) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.char_position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class CharactersToMolecules(nn.Module): + """Convert character sequence to initial molecule sequence (i.e. 
downsample) using strided convolutions.""" + + def __init__(self, config): + super().__init__() + + self.conv = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=config.downsampling_rate, + stride=config.downsampling_rate, + ) + self.activation = ACT2FN[config.hidden_act] + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, char_encoding: torch.Tensor) -> torch.Tensor: + # `cls_encoding`: [batch, 1, hidden_size] + cls_encoding = char_encoding[:, 0:1, :] + + # char_encoding has shape [batch, char_seq, hidden_size] + # We transpose it to be [batch, hidden_size, char_seq] + char_encoding = torch.transpose(char_encoding, 1, 2) + downsampled = self.conv(char_encoding) + downsampled = torch.transpose(downsampled, 1, 2) + downsampled = self.activation(downsampled) + + # Truncate the last molecule in order to reserve a position for [CLS]. + # Often, the last position is never used (unless we completely fill the + # text buffer). This is important in order to maintain alignment on TPUs + # (i.e. a multiple of 128). + downsampled_truncated = downsampled[:, 0:-1, :] + + # We also keep [CLS] as a separate sequence position since we always + # want to reserve a position (and the model capacity that goes along + # with that) in the deep BERT stack. + # `result`: [batch, molecule_seq, molecule_dim] + result = torch.cat([cls_encoding, downsampled_truncated], dim=1) + + result = self.LayerNorm(result) + + return result + + +class ConvProjection(nn.Module): + """ + Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size + characters. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.conv = nn.Conv1d( + in_channels=config.hidden_size * 2, + out_channels=config.hidden_size, + kernel_size=config.upsampling_kernel_size, + stride=1, + ) + self.activation = ACT2FN[config.hidden_act] + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + inputs: torch.Tensor, + final_seq_char_positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final] + # we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq] + inputs = torch.transpose(inputs, 1, 2) + + # PyTorch < 1.9 does not support padding="same" (which is used in the original implementation), + # so we pad the tensor manually before passing it to the conv layer + # based on https://github.com/google-research/big_transfer/blob/49afe42338b62af9fbe18f0258197a33ee578a6b/bit_tf2/models.py#L36-L38 + pad_total = self.config.upsampling_kernel_size - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + + pad = nn.ConstantPad1d((pad_beg, pad_end), 0) + # `result`: shape (batch_size, char_seq_len, hidden_size) + result = self.conv(pad(inputs)) + result = torch.transpose(result, 1, 2) + result = self.activation(result) + result = self.LayerNorm(result) + result = self.dropout(result) + final_char_seq = result + + if final_seq_char_positions is not None: + # Limit transformer query seq and attention mask to these character + # positions to greatly reduce the compute cost. Typically, this is just + # done for the MLM training task. 
+ # TODO add support for MLM + raise NotImplementedError("CanineForMaskedLM is currently not supported") + else: + query_seq = final_char_seq + + return query_seq + + +class CanineSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + def forward( + self, + from_tensor: torch.Tensor, + to_tensor: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size, seq_length, _ = from_tensor.shape + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + + key_layer = ( + self.key(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(from_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
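+        # At this point `query_layer` has shape (batch_size, num_attention_heads, from_seq_len, attention_head_size)
+        # and `key_layer`/`value_layer` have shape (batch_size, num_attention_heads, to_seq_len, attention_head_size),
+        # so the matmul below yields raw scores of shape (batch_size, num_attention_heads, from_seq_len, to_seq_len).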
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = from_tensor.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + if attention_mask.ndim == 3: + # if attention_mask is 3D, do the following: + attention_mask = torch.unsqueeze(attention_mask, dim=1) + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min + # Apply the attention mask (precomputed for all layers in CanineModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class CanineSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(
+        self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class CanineAttention(nn.Module):
+    """
+    Additional arguments related to local attention:
+
+    - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
+    - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able to
+      attend to the `to_tensor`'s first position (e.g. a [CLS] position)?
+    - **first_position_attends_to_all** (`bool`, *optional*, defaults to `False`) -- Should the *from_tensor*'s
+      first position be able to attend to all positions within the *from_tensor*?
+    - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
+      `from_tensor`.
+    - **attend_from_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
+      moving to the next block in `from_tensor`.
+    - **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
+      *to_tensor*.
+    - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when moving
+      to the next block in `to_tensor`.
+    """
+
+    def __init__(
+        self,
+        config,
+        local=False,
+        always_attend_to_first_position: bool = False,
+        first_position_attends_to_all: bool = False,
+        attend_from_chunk_width: int = 128,
+        attend_from_chunk_stride: int = 128,
+        attend_to_chunk_width: int = 128,
+        attend_to_chunk_stride: int = 128,
+    ):
+        super().__init__()
+        self.self = CanineSelfAttention(config)
+        self.output = CanineSelfOutput(config)
+        self.pruned_heads = set()
+
+        # additional arguments related to local attention
+        self.local = local
+        if attend_from_chunk_width < attend_from_chunk_stride:
+            raise ValueError(
+                "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
+            )
+        if attend_to_chunk_width < attend_to_chunk_stride:
+            raise ValueError(
+                "`attend_to_chunk_width` < `attend_to_chunk_stride` would cause sequence positions to get skipped."
+            )
+        self.always_attend_to_first_position = always_attend_to_first_position
+        self.first_position_attends_to_all = first_position_attends_to_all
+        self.attend_from_chunk_width = attend_from_chunk_width
+        self.attend_from_chunk_stride = attend_from_chunk_stride
+        self.attend_to_chunk_width = attend_to_chunk_width
+        self.attend_to_chunk_stride = attend_to_chunk_stride
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: tuple[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+        if not self.local:
+            self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
+            attention_output = self_outputs[0]
+        else:
+            from_seq_length = to_seq_length = hidden_states.shape[1]
+            from_tensor = to_tensor = hidden_states
+
+            # Create chunks (windows) that we will attend *from* and then concatenate them.
+            from_chunks = []
+            if self.first_position_attends_to_all:
+                from_chunks.append((0, 1))
+                # We must skip this first position so that our output sequence is the
+                # correct length (this matters in the *from* sequence only).
+                from_start = 1
+            else:
+                from_start = 0
+            for chunk_start in range(from_start, from_seq_length, self.attend_from_chunk_stride):
+                chunk_end = min(from_seq_length, chunk_start + self.attend_from_chunk_width)
+                from_chunks.append((chunk_start, chunk_end))
+
+            # Determine the chunks (windows) that will attend *to*.
+            to_chunks = []
+            if self.first_position_attends_to_all:
+                to_chunks.append((0, to_seq_length))
+            for chunk_start in range(0, to_seq_length, self.attend_to_chunk_stride):
+                chunk_end = min(to_seq_length, chunk_start + self.attend_to_chunk_width)
+                to_chunks.append((chunk_start, chunk_end))
+
+            if len(from_chunks) != len(to_chunks):
+                raise ValueError(
+                    f"Expected to have same number of `from_chunks` ({from_chunks}) and "
+                    f"`to_chunks` ({to_chunks}). Check strides."
+ ) + + # next, compute attention scores for each pair of windows and concatenate + attention_output_chunks = [] + attention_probs_chunks = [] + for (from_start, from_end), (to_start, to_end) in zip(from_chunks, to_chunks): + from_tensor_chunk = from_tensor[:, from_start:from_end, :] + to_tensor_chunk = to_tensor[:, to_start:to_end, :] + # `attention_mask`: [batch_size, from_seq, to_seq] + # `attention_mask_chunk`: [batch_size, from_seq_chunk, to_seq_chunk] + attention_mask_chunk = attention_mask[:, from_start:from_end, to_start:to_end] + if self.always_attend_to_first_position: + cls_attention_mask = attention_mask[:, from_start:from_end, 0:1] + attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2) + + cls_position = to_tensor[:, 0:1, :] + to_tensor_chunk = torch.cat([cls_position, to_tensor_chunk], dim=1) + + attention_outputs_chunk = self.self( + from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, head_mask, output_attentions + ) + attention_output_chunks.append(attention_outputs_chunk[0]) + if output_attentions: + attention_probs_chunks.append(attention_outputs_chunk[1]) + + attention_output = torch.cat(attention_output_chunks, dim=1) + + attention_output = self.output(attention_output, hidden_states) + outputs = (attention_output,) + if not self.local: + outputs = outputs + self_outputs[1:] # add attentions if we output them + else: + outputs = outputs + tuple(attention_probs_chunks) # add attentions if we output them + return outputs + + +class CanineIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class CanineOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class CanineLayer(GradientCheckpointingLayer): + def __init__( + self, + config, + local, + always_attend_to_first_position, + first_position_attends_to_all, + attend_from_chunk_width, + attend_from_chunk_stride, + attend_to_chunk_width, + attend_to_chunk_stride, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = CanineAttention( + config, + local, + always_attend_to_first_position, + first_position_attends_to_all, + attend_from_chunk_width, + attend_from_chunk_stride, + attend_to_chunk_width, + attend_to_chunk_stride, + ) + self.intermediate = CanineIntermediate(config) + self.output = CanineOutput(config) + + def forward( + self, + hidden_states: tuple[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, 
Optional[torch.FloatTensor]]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class CanineEncoder(nn.Module): + def __init__( + self, + config, + local=False, + always_attend_to_first_position=False, + first_position_attends_to_all=False, + attend_from_chunk_width=128, + attend_from_chunk_stride=128, + attend_to_chunk_width=128, + attend_to_chunk_stride=128, + ): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [ + CanineLayer( + config, + local, + always_attend_to_first_position, + first_position_attends_to_all, + attend_from_chunk_width, + attend_from_chunk_stride, + attend_to_chunk_width, + attend_to_chunk_stride, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: tuple[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class CaninePooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
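+        # In `CanineModel`, `hidden_states` here is the molecule sequence output, so the first position is the
+        # [CLS] molecule; `pooled_output` therefore has shape (batch_size, hidden_size).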
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class CaninePredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class CanineLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = CaninePredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor: + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class CanineOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = CanineLMPredictionHead(config) + + def forward( + self, + sequence_output: tuple[torch.Tensor], + ) -> tuple[torch.Tensor]: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@auto_docstring +class CaninePreTrainedModel(PreTrainedModel): + config: CanineConfig + load_tf_weights = load_tf_weights_in_canine + base_model_prefix = "canine" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class CanineModel(CaninePreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + shallow_config = copy.deepcopy(config) + shallow_config.num_hidden_layers = 1 + + self.char_embeddings = CanineEmbeddings(config) + # shallow/low-dim transformer encoder to get a initial character encoding + self.initial_char_encoder = CanineEncoder( + shallow_config, + local=True, + always_attend_to_first_position=False, + first_position_attends_to_all=False, + attend_from_chunk_width=config.local_transformer_stride, + attend_from_chunk_stride=config.local_transformer_stride, + 
attend_to_chunk_width=config.local_transformer_stride, + attend_to_chunk_stride=config.local_transformer_stride, + ) + self.chars_to_molecules = CharactersToMolecules(config) + # deep transformer encoder + self.encoder = CanineEncoder(config) + self.projection = ConvProjection(config) + # shallow/low-dim transformer encoder to get a final character encoding + self.final_char_encoder = CanineEncoder(shallow_config) + + self.pooler = CaninePooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. + to_mask: int32 Tensor of shape [batch_size, to_seq_length]. + + Returns: + float Tensor of shape [batch_size, from_seq_length, to_seq_length]. + """ + batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1] + + to_seq_length = to_mask.shape[1] + + to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float() + + # We don't assume that `from_tensor` is a mask (although it could be). We + # don't actually care if we attend *from* padding tokens (only *to* padding) + # tokens so we create a tensor of all ones. + broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device) + + # Here we broadcast along two dimensions to create the mask. + mask = broadcast_ones * to_mask + + return mask + + def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int): + """Downsample 2D character attention mask to 2D molecule attention mask using MaxPool1d layer.""" + + # first, make char_attention_mask 3D by adding a channel dim + batch_size, char_seq_len = char_attention_mask.shape + poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len)) + + # next, apply MaxPool1d to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len) + pooled_molecule_mask = torch.nn.MaxPool1d(kernel_size=downsampling_rate, stride=downsampling_rate)( + poolable_char_mask.float() + ) + + # finally, squeeze to get tensor of shape (batch_size, mol_seq_len) + molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=-1) + + return molecule_attention_mask + + def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: int) -> torch.Tensor: + """Repeats molecules to make them the same length as the char sequence.""" + + rate = self.config.downsampling_rate + + molecules_without_extra_cls = molecules[:, 1:, :] + # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size] + repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2) + + # So far, we've repeated the elements sufficient for any `char_seq_length` + # that's a multiple of `downsampling_rate`. Now we account for the last + # n elements (n < `downsampling_rate`), i.e. the remainder of floor + # division. We do this by repeating the last molecule a few extra times. 
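+        # For example, with `downsampling_rate=4` and `char_seq_length=2048`, `molecules` has 512 positions
+        # (1 [CLS] + 511 downsampled, since `CharactersToMolecules` truncated one molecule to make room for [CLS]).
+        # `repeated` then covers 511 * 4 = 2044 characters, and the block below repeats the last molecule
+        # (2048 % 4) + 4 = 4 extra times so that the total length is 2048 again.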
+ last_molecule = molecules[:, -1:, :] + remainder_length = char_seq_length % rate + remainder_repeated = torch.repeat_interleave( + last_molecule, + # +1 molecule to compensate for truncation. + repeats=remainder_length + rate, + dim=-2, + ) + + # `repeated`: [batch_size, char_seq_len, molecule_hidden_size] + return torch.cat([repeated, remainder_repeated], dim=-2) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, CanineModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + molecule_attention_mask = self._downsample_attention_mask( + attention_mask, downsampling_rate=self.config.downsampling_rate + ) + extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask( + molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]) + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # `input_char_embeddings`: shape (batch_size, char_seq, char_dim) + input_char_embeddings = self.char_embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + # Contextualize character embeddings using shallow Transformer. + # We use a 3D attention mask for the local attention. 
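+        # `char_attention_mask` below is a float mask of shape (batch_size, char_seq_len, char_seq_len) built from
+        # the 2D `attention_mask`; the local attention then slices it into per-chunk windows.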
+ # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim) + char_attention_mask = self._create_3d_attention_mask_from_input_mask( + input_ids if input_ids is not None else inputs_embeds, attention_mask + ) + init_chars_encoder_outputs = self.initial_char_encoder( + input_char_embeddings, + attention_mask=char_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + input_char_encoding = init_chars_encoder_outputs.last_hidden_state + + # Downsample chars to molecules. + # The following lines have dimensions: [batch, molecule_seq, molecule_dim]. + # In this transformation, we change the dimensionality from `char_dim` to + # `molecule_dim`, but do *NOT* add a resnet connection. Instead, we rely on + # the resnet connections (a) from the final char transformer stack back into + # the original char transformer stack and (b) the resnet connections from + # the final char transformer stack back into the deep BERT stack of + # molecules. + # + # Empirically, it is critical to use a powerful enough transformation here: + # mean pooling causes training to diverge with huge gradient norms in this + # region of the model; using a convolution here resolves this issue. From + # this, it seems that molecules and characters require a very different + # feature space; intuitively, this makes sense. + init_molecule_encoding = self.chars_to_molecules(input_char_encoding) + + # Deep BERT encoder + # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim) + encoder_outputs = self.encoder( + init_molecule_encoding, + attention_mask=extended_molecule_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + molecule_sequence_output = encoder_outputs[0] + pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None + + # Upsample molecules back to characters. 
+ # `repeated_molecules`: shape (batch_size, char_seq_len, mol_hidden_size) + repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1]) + + # Concatenate representations (contextualized char embeddings and repeated molecules): + # `concat`: shape [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final] + concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1) + + # Project representation dimension back to hidden_size + # `sequence_output`: shape (batch_size, char_seq_len, hidden_size]) + sequence_output = self.projection(concat) + + # Apply final shallow Transformer + # `sequence_output`: shape (batch_size, char_seq_len, hidden_size]) + final_chars_encoder_outputs = self.final_char_encoder( + sequence_output, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = final_chars_encoder_outputs.last_hidden_state + + if output_hidden_states: + deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1] + all_hidden_states = ( + all_hidden_states + + init_chars_encoder_outputs.hidden_states + + deep_encoder_hidden_states + + final_chars_encoder_outputs.hidden_states + ) + + if output_attentions: + deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1] + all_self_attentions = ( + all_self_attentions + + init_chars_encoder_outputs.attentions + + deep_encoder_self_attentions + + final_chars_encoder_outputs.attentions + ) + + if not return_dict: + output = (sequence_output, pooled_output) + output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None) + return output + + return CanineModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring( + custom_intro=""" + CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """ +) +class CanineForSequenceClassification(CaninePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.canine = CanineModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.canine( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class CanineForMultipleChoice(CaninePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.canine = CanineModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.canine( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class CanineForTokenClassification(CaninePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.canine = CanineModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, CanineForTokenClassification + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s") + >>> model = CanineForTokenClassification.from_pretrained("google/canine-s") + + >>> inputs = tokenizer( + ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt" + ... ) + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_token_class_ids = logits.argmax(-1) + + >>> # Note that tokens are classified rather then input words which means that + >>> # there might be more predicted token classes than words. + >>> # Multiple token classes might account for the same word + >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]] + >>> predicted_tokens_classes # doctest: +SKIP + ``` + + ```python + >>> labels = predicted_token_class_ids + >>> loss = model(**inputs, labels=labels).loss + >>> round(loss.item(), 2) # doctest: +SKIP + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.canine( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class CanineForQuestionAnswering(CaninePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.canine = CanineModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.canine( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = 
end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "CanineForMultipleChoice", + "CanineForQuestionAnswering", + "CanineForSequenceClassification", + "CanineForTokenClassification", + "CanineLayer", + "CanineModel", + "CaninePreTrainedModel", + "load_tf_weights_in_canine", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/canine/tokenization_canine.py b/venv/lib/python3.13/site-packages/transformers/models/canine/tokenization_canine.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b2a8bfd96efdf6ed87aa1c82571612cfd89a4e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/canine/tokenization_canine.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for CANINE.""" + +from typing import Optional + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +# Unicode defines 1,114,112 total “codepoints” +UNICODE_VOCAB_SIZE = 1114112 + +# Below: Constants defining canonical codepoints for special, pseudo-characters. +# Copied from https://github.com/google-research/language/blob/master/language/canine/special_codepoints.py +PAD = 0 +CLS = 0xE000 +SEP = 0xE001 +BOS = 0xE002 +MASK = 0xE003 +RESERVED = 0xE004 + +# Maps special codepoints to human-readable names. +SPECIAL_CODEPOINTS: dict[int, str] = { + # Special symbols are represented using codepoints values that are valid, + # but designated as "Private Use", meaning that they will never be assigned + # characters by the Unicode Consortium, and are thus safe for use here. + # + # NOTE: Do *NOT* add any sort of [UNK_CHAR] here. They are explicitly + # excluded and should fail with a hard error. 
+ CLS: "[CLS]", + SEP: "[SEP]", + BOS: "[BOS]", + MASK: "[MASK]", + PAD: "[PAD]", + RESERVED: "[RESERVED]", +} + +# Maps special codepoint human-readable names to their codepoint values. +SPECIAL_CODEPOINTS_BY_NAME: dict[str, int] = {name: codepoint for codepoint, name in SPECIAL_CODEPOINTS.items()} + + +class CanineTokenizer(PreTrainedTokenizer): + r""" + Construct a CANINE tokenizer (i.e. a character splitter). It turns text into a sequence of characters, and then + converts each character into its Unicode code point. + + [`CanineTokenizer`] inherits from [`PreTrainedTokenizer`]. + + Refer to superclass [`PreTrainedTokenizer`] for usage examples and documentation concerning parameters. + + Args: + model_max_length (`int`, *optional*, defaults to 2048): + The maximum sentence length the model accepts. + """ + + def __init__( + self, + bos_token=chr(CLS), + eos_token=chr(SEP), + sep_token=chr(SEP), + cls_token=chr(CLS), + pad_token=chr(PAD), + mask_token=chr(MASK), + add_prefix_space=False, + model_max_length=2048, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # The mask token behaves like a normal word, i.e. it includes the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + # Creates a mapping for looking up the IDs of special symbols. + self._special_codepoints: dict[str, int] = {} + for codepoint, name in SPECIAL_CODEPOINTS.items(): + self._special_codepoints[name] = codepoint + + # Creates a mapping for looking up the string forms of special symbol IDs. + self._special_codepoint_strings: dict[int, str] = { + codepoint: name for name, codepoint in self._special_codepoints.items() + } + + self._unicode_vocab_size = UNICODE_VOCAB_SIZE + self._num_special_tokens = len(self._special_codepoints) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + model_max_length=model_max_length, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return self._unicode_vocab_size + + def get_vocab(self): + vocab = {chr(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + """Tokenize a string (i.e. perform character splitting).""" + return list(text) + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (i.e. a Unicode character) to an id (i.e. its integer Unicode code point value).""" + try: + return ord(token) + except TypeError: + raise ValueError(f"invalid token: '{token}'") + + def _convert_id_to_token(self, index: int) -> str: + """ + Converts a Unicode code point (integer) to a token (str). In case it's a special code point, convert to + human-readable format.
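+ + Example (an illustrative doctest; `CanineTokenizer` takes no vocab file, so it can be constructed directly): + + ```python + >>> tokenizer = CanineTokenizer() + >>> tokenizer._convert_id_to_token(0xE000) + '[CLS]' + >>> tokenizer._convert_id_to_token(97) + 'a' + ```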
+ """ + try: + if index in SPECIAL_CODEPOINTS: + return SPECIAL_CODEPOINTS[index] + return chr(index) + except TypeError: + raise ValueError(f"invalid id: {index}") + + def convert_tokens_to_string(self, tokens): + return "".join(tokens) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A CANINE sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + result = cls + token_ids_0 + sep + if token_ids_1 is not None: + result += token_ids_1 + sep + return result + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + result = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + result += ([0] * len(token_ids_1)) + [1] + return result + + # CanineTokenizer has no vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None): + return () + + +__all__ = ["CanineTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4476b9080ff514fe273fde33d96b5c0edaba2b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_chinese_clip import * + from .feature_extraction_chinese_clip import * + from .image_processing_chinese_clip import * + from .image_processing_chinese_clip_fast import * + from .modeling_chinese_clip import * + from .processing_chinese_clip import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/configuration_chinese_clip.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/configuration_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c628107048b9b7651edf1bd3ac28ccffc6e29dbc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/configuration_chinese_clip.py @@ -0,0 +1,423 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Chinese-CLIP model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional + + +if TYPE_CHECKING: + from ...processing_utils import ProcessorMixin + from ...utils import TensorType + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ChineseCLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate a + Chinese CLIP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Chinese CLIP + [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Chinese-CLIP model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`ChineseCLIPModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ChineseCLIPModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ + Example: + + ```python + >>> from transformers import ChineseCLIPTextConfig, ChineseCLIPTextModel + + >>> # Initializing a ChineseCLIPTextConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPTextConfig() + + >>> # Initializing a ChineseCLIPTextModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "chinese_clip_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + initializer_factor=1.0, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + + +class ChineseCLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate a + ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ChineseCLIP + [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler.
If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import ChineseCLIPVisionConfig, ChineseCLIPVisionModel + + >>> # Initializing a ChineseCLIPVisionConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPVisionConfig() + + >>> # Initializing a ChineseCLIPVisionModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "chinese_clip_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + +class ChineseCLIPConfig(PretrainedConfig): + r""" + [`ChineseCLIPConfig`] is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used + to instantiate a Chinese-CLIP model according to the specified arguments, defining the text model and vision model + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + Chinese-CLIP [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ChineseCLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ChineseCLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original ChineseCLIP + implementation. + kwargs (*optional*): + Dictionary of keyword arguments.
+ + Example: + + ```python + >>> from transformers import ChineseCLIPConfig, ChineseCLIPModel + + >>> # Initializing a ChineseCLIPConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPConfig() + + >>> # Initializing a ChineseCLIPModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a ChineseCLIPConfig from a ChineseCLIPTextConfig and a ChineseCLIPVisionConfig + + >>> # Initializing a ChineseCLIPTextConfig and ChineseCLIPVisionConfig configuration + >>> config_text = ChineseCLIPTextConfig() + >>> config_vision = ChineseCLIPVisionConfig() + + >>> config = ChineseCLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "chinese_clip" + sub_configs = {"text_config": ChineseCLIPTextConfig, "vision_config": ChineseCLIPVisionConfig} + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exists, we use it for backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be the same in + # most cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = ChineseCLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but are different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key != "transformers_version": + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `ChineseCLIPTextConfig`. " + f'The value `text_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = ChineseCLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but are different.
+ for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key != "transformers_version": + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize " + f'`ChineseCLIPVisionConfig`. The value `vision_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `ChineseCLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. Initializing the `ChineseCLIPVisionConfig` with default values.") + + self.text_config = ChineseCLIPTextConfig(**text_config) + self.vision_config = ChineseCLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + + +class ChineseCLIPOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("logits_per_image", {0: "batch"}), + ("logits_per_text", {0: "batch"}), + ("text_embeds", {0: "batch"}), + ("image_embeds", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + processor: "ProcessorMixin", + batch_size: int = -1, + seq_length: int = -1, + framework: Optional["TensorType"] = None, + ) -> Mapping[str, Any]: + text_input_dict = super().generate_dummy_inputs( + processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + ) + image_input_dict = super().generate_dummy_inputs( + processor.image_processor, batch_size=batch_size, framework=framework + ) + return {**text_input_dict, **image_input_dict} + + @property + def default_onnx_opset(self) -> int: + return 14 + + +__all__ = ["ChineseCLIPConfig", "ChineseCLIPOnnxConfig", "ChineseCLIPTextConfig", "ChineseCLIPVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/feature_extraction_chinese_clip.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/feature_extraction_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c4895bb06b510cfeb64294759c31bcc8d0e3d098 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/feature_extraction_chinese_clip.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Chinese-CLIP.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_chinese_clip import ChineseCLIPImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ChineseCLIPFeatureExtractor(ChineseCLIPImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class ChineseCLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use ChineseCLIPImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["ChineseCLIPFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c55805f28913dca050f48ee24f8fa7050f974513 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -0,0 +1,314 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Chinese-CLIP.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ChineseCLIPImageProcessor(BaseImageProcessor): + r""" + Constructs a Chinese-CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing.
The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB.
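+ + Example (an illustrative sketch, not part of the original docstring; it passes an explicit height/width `size` and a random dummy image): + + ```python + >>> from transformers import ChineseCLIPImageProcessor + >>> import numpy as np + + >>> image_processor = ChineseCLIPImageProcessor(size={"height": 224, "width": 224}) + >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) + >>> inputs = image_processor(images=image, return_tensors="np") + >>> inputs["pixel_values"].shape + (1, 3, 224, 224) + ```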
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input + image. + """ + size = get_size_dict(size, default_to_square=False) + output_size = get_resize_output_image_size( + image, size=(size["height"], size["width"]), default_to_square=False, input_data_format=input_data_format + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. 
Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
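+ + Returns: + [`BatchFeature`]: A [`BatchFeature`] whose `pixel_values` field holds the processed image(s), batched and converted to `return_tensors` if requested.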
+ """ + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["ChineseCLIPImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..6ced07e0c29e523ce3be8775d16a8e1858887604 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py @@ -0,0 +1,37 @@ +# coding=utf-8 +# Copyright 2025 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Chinese-CLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import auto_docstring + + +@auto_docstring +class ChineseCLIPImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["ChineseCLIPImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b461ab3ed37445c0360b82cd20f9d76a7e48bb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -0,0 +1,1213 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Chinese-CLIP model.""" + +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int +from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig + + +logger = logging.get_logger(__name__) + + +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def chinese_clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +@auto_docstring +class ChineseCLIPOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of + [`ChineseCLIPTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`ChineseCLIPVisionModel`]. + text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`ChineseCLIPTextModel`]. + vision_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`ChineseCLIPVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP +class ChineseCLIPTextEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->ChineseCLIP +class 
ChineseCLIPVisionEmbeddings(nn.Module): + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method makes it possible to interpolate the pre-trained position encodings, so that the model can be used on + higher-resolution images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
+ ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP +class ChineseCLIPTextSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else 
self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->ChineseCLIPText +class ChineseCLIPTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP +class ChineseCLIPTextAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = ChineseCLIPTextSelfAttention(config) + self.output = ChineseCLIPTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ChineseCLIPVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + None, + dropout=0.0 if not self.training else self.dropout, + scaling=1.0, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText +class ChineseCLIPTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->ChineseCLIPText +class ChineseCLIPTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->ChineseCLIPVision +class ChineseCLIPVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP +class ChineseCLIPTextLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + 
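+        # chunk_size_feed_forward (read from the config below) is consumed by apply_chunking_to_forward()
+        # in forward(): a value > 0 splits the sequence dimension (seq_len_dim = 1) into chunks of that
+        # size for the feed-forward block, trading extra compute for lower peak activation memory;
+        # 0 disables chunking.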
self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ChineseCLIPTextAttention(config) + self.intermediate = ChineseCLIPTextIntermediate(config) + self.output = ChineseCLIPTextOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class ChineseCLIPVisionLayer(GradientCheckpointingLayer): + def __init__(self, config: ChineseCLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ChineseCLIPVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = ChineseCLIPVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->ChineseCLIPText +class ChineseCLIPTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
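+        # hidden_states has shape (batch_size, seq_len, hidden_size); position 0 is the [CLS] token,
+        # so first_token_tensor below has shape (batch_size, hidden_size).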
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class ChineseCLIPPreTrainedModel(PreTrainedModel): + config: ChineseCLIPConfig + base_model_prefix = "chinese_clip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, ChineseCLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, ChineseCLIPTextEmbeddings): + nn.init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range) + for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]: + if embedding.padding_idx is not None: + embedding.weight.data[embedding.padding_idx].zero_() + elif isinstance(module, ChineseCLIPVisionAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, ChineseCLIPVisionMLP): + factor = self.config.initializer_factor + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, ChineseCLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP +class ChineseCLIPTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: + all_hidden_states = () if 
output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChineseCLIPVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`ChineseCLIPVisionEncoderLayer`]. + + Args: + config: ChineseCLIPConfig + """ + + def __init__(self, config: ChineseCLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + inputs_embeds, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class ChineseCLIPVisionTransformer(nn.Module): + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = ChineseCLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = ChineseCLIPVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The text model from CHINESE_CLIP without any head or projection on top. + """ +) +class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
+ + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + config: ChineseCLIPTextConfig + _no_split_modules = ["ChineseCLIPTextEmbeddings"] + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = ChineseCLIPTextEmbeddings(config) + self.encoder = ChineseCLIPTextEncoder(config) + + self.pooler = ChineseCLIPTextPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, 
to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The vision model from CHINESE_CLIP without any head or projection on top. + """ +) +class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel): + config: ChineseCLIPVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"] + + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__(config) + self.vision_model = ChineseCLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import CLIPProcessor, ChineseCLIPVisionModel + + >>> model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = CLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + +@auto_docstring +class ChineseCLIPModel(ChineseCLIPPreTrainedModel): + config: ChineseCLIPConfig + + def __init__(self, config: ChineseCLIPConfig): + super().__init__(config) + + if not isinstance(config.text_config, ChineseCLIPTextConfig): + raise TypeError( + "config.text_config is expected to be of type ChineseCLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, ChineseCLIPVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type ChineseCLIPVisionConfig but is of type" + f" {type(config.vision_config)}." 
+ ) + + text_config = config.text_config + vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = ChineseCLIPTextModel(text_config, add_pooling_layer=False) + self.vision_model = ChineseCLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the final [CLS] hidden state of Text-Transformer. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, ChineseCLIPModel + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> inputs = tokenizer(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], padding=True, return_tensors="pt") + >>> with torch.inference_mode(): + ... text_features = model.get_text_features(**inputs) + >>> text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True) + ```""" + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + pooled_output = text_outputs.pooler_output + text_features = self.text_projection(pooled_output) + + return text_features + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the final [CLS] hidden state of Vision-Transformer. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, ChineseCLIPModel + >>> from transformers.image_utils import load_image + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.inference_mode(): + ... 
image_features = model.get_image_features(**inputs) + >>> image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True) + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs.pooler_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ChineseCLIPOutput]: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, ChineseCLIPModel + >>> from transformers.image_utils import load_image + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = load_image(url) + + >>> inputs = processor(text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, return_tensors="pt", padding=True) + + >>> with torch.inference_mode(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components. 
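+        # Note on the two embedding branches below: the text model is built with add_pooling_layer=False,
+        # so text_embeds is taken from the [CLS] (first) token of its last_hidden_state, whereas
+        # image_embeds comes from the vision transformer's pooled, post-layernorm [CLS] output.
+        # Both are then projected, L2-normalized, and compared under a learned temperature (logit_scale)
+        # to produce the contrastive image-text logits.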
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[0][:, 0, :] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = chinese_clip_loss(logits_per_text) + + return ChineseCLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["ChineseCLIPModel", "ChineseCLIPPreTrainedModel", "ChineseCLIPTextModel", "ChineseCLIPVisionModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/processing_chinese_clip.py b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/processing_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c6fb9f9df247555b2946cf471c650db2974450df --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/chinese_clip/processing_chinese_clip.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for Chinese-CLIP +""" + +import warnings + +from ...processing_utils import ProcessorMixin + + +class ChineseCLIPProcessor(ProcessorMixin): + r""" + Constructs a Chinese-CLIP processor which wraps a Chinese-CLIP image processor and a Chinese-CLIP tokenizer into a + single processor. + + [`ChineseCLIPProcessor`] offers all the functionalities of [`ChineseCLIPImageProcessor`] and [`BertTokenizerFast`]. + See the [`~ChineseCLIPProcessor.__call__`] and [`~ChineseCLIPProcessor.decode`] for more information. + + Args: + image_processor ([`ChineseCLIPImageProcessor`], *optional*): + The image processor is a required input. 
+ tokenizer ([`BertTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast") + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + +__all__ = ["ChineseCLIPProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/clvp/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..986e185ff7771e5909db2b0e2ab91ba95cfa3a27 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_clvp import * + from .feature_extraction_clvp import * + from .modeling_clvp import * + from .processing_clvp import * + from .tokenization_clvp import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/clvp/configuration_clvp.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/configuration_clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d9957cbd4009a49096be686e3936d1b427e102 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/configuration_clvp.py @@ -0,0 +1,439 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""CLVP model configuration""" + +import os +from typing import Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ClvpEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClvpEncoder`]. It is used to instantiate a CLVP + text or CLVP speech encoder according to the specified arguments. Instantiating a configuration with the defaults + will yield a similar configuration to that of the encoder of the CLVP + [susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the CLVP Encoder model. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 768): + Dimensionality of the projection vector. + num_hidden_layers (`int`, *optional*, defaults to 20): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the feed-forward layers in [`ClvpEncoderMLP`]. + use_rotary_embedding (`bool`, *optional*, defaults to `True`): + Whether to use rotary_embedding or not. + use_attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in Query, Key and Value layers during self attention. + summary_type (`str`, *optional*, defaults to `"mean"`): + What strategy to use to get pooler_output from the last_hidden_state. `"last"`, `"first"`, `"mean"` and + `"cls_index"` are supported. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + bos_token_id (`int`, *optional*, defaults to 255): + Beginning of sequence token id. + eos_token_id (`int`, *optional*, defaults to 0): + End of sequence token id. 
+ + Example: + + ```python + >>> from transformers import ClvpEncoderConfig, ClvpEncoder + + >>> # Initializing a ClvpEncoderConfig with susnato/clvp_dev style configuration + >>> encoder_configuration = ClvpEncoderConfig() + + >>> # Initializing a ClvpEncoder (with random weights) from the susnato/clvp_dev style configuration + >>> model = ClvpEncoder(encoder_configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clvp_encoder" + base_config_key = ["text_config", "speech_config"] + + def __init__( + self, + vocab_size=256, + hidden_size=768, + intermediate_size=1536, + projection_dim=768, + num_hidden_layers=20, + num_attention_heads=12, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.1, + dropout=0.1, + use_rotary_embedding=True, + use_attention_bias=False, + summary_type="mean", + initializer_factor=1.0, + bos_token_id=255, + eos_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.dropout = dropout + self.use_rotary_embedding = use_rotary_embedding + self.use_attention_bias = use_attention_bias + self.summary_type = summary_type + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_type: str = "text_config", **kwargs + ): + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # make sure to have the config_type be either "text_config" or "speech_config" + # this is to make sure that we can load only text or speech configs from the nested ClvpConfig. + if config_type not in cls.base_config_key: + raise ValueError( + f"We can only load either 'text_config' or 'speech_config' but you are trying to load{config_type}" + ) + + # get the text config dict if we are loading from ClvpConfig + if config_dict.get("model_type") == "clvp": + config_dict = config_dict[config_type] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class ClvpDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClvpDecoder`]. It is used to instantiate a CLVP + Decoder Model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Decoder part of the CLVP + [susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + The architecture is similar to GPT2. 
+ + Args: + vocab_size (`int`, *optional*, defaults to 8194): + Vocabulary size of the model. + max_position_embeddings (`int`, *optional*, defaults to 608): + The maximum sequence length of mel tokens that this model might ever be used with. Similar to `n_positions` + in `GPT2Config`. + max_text_tokens (`int`, *optional*, defaults to 404): + The maximum sequence length of text tokens that this model might ever be used with. Similar to + `n_positions` in `GPT2Config`. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the embeddings and hidden states. + num_hidden_layers (`int`, *optional*, defaults to 30): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `hidden_size`. + num_mel_attn_blocks (`int`, *optional*, defaults to 6): + Denotes the number of self attention layers in [`ClvpConditioningEncoder`]. + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio to be used after the projection and activation. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 8192): + Beginning of sequence token id, used at the start of the generation. + eos_token_id (`int`, *optional*, defaults to 8193): + End of sequence token id, used in the method + [`ClvpModelForConditionalGeneration.fix_speech_decoder_output()`] to correct decoder outputs. + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted mel features. 
This value is used in [`ClvpConditioningEncoder`]. + use_attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in Query, Key and Value layers during self attention. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + decoder_fixing_codes (`list`, *optional*, defaults to `[83, 45, 45, 248]`): + These values are used in the method `fix_speech_decoder_output` to fix decoder generated outputs. + + Example: + + ```python + >>> from transformers import ClvpDecoderConfig, ClvpDecoder + + >>> # Initializing a ClvpDecoderConfig with susnato/clvp_dev style configuration + >>> decoder_configuration = ClvpDecoderConfig() + + >>> # Initializing a ClvpDecoder (with random weights) from the susnato/clvp_dev style configuration + >>> model = ClvpDecoder(decoder_configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clvp_decoder" + base_config_key = "decoder_config" + + def __init__( + self, + vocab_size=8194, + max_position_embeddings=608, + max_text_tokens=404, + hidden_size=1024, + num_hidden_layers=30, + num_attention_heads=16, + n_inner=None, + num_mel_attn_blocks=6, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attention_dropout=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + use_cache=True, + bos_token_id=8192, + eos_token_id=8193, + feature_size=80, + use_attention_bias=True, + initializer_factor=1.0, + decoder_fixing_codes=[83, 45, 45, 248], + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.max_text_tokens = max_text_tokens + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_inner = n_inner + self.num_mel_attn_blocks = num_mel_attn_blocks + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.use_cache = use_cache + self.feature_size = feature_size + self.use_attention_bias = use_attention_bias + self.initializer_factor = initializer_factor + self.decoder_fixing_codes = decoder_fixing_codes + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class ClvpConfig(PretrainedConfig): + r""" + [`ClvpConfig`] is the configuration class to store the configuration of a [`ClvpModelForConditionalGeneration`]. It + is used to instantiate a CLVP model according to the specified arguments, defining the text model, speech model and + decoder model configs. Instantiating a configuration with the defaults will yield a similar configuration to that + of the CLVP [susnato/clvp_dev](https://huggingface.co/susnato/clvp_dev) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize the CLVP text encoder. + speech_config (`dict`, *optional*): + Dictionary of configuration options used to initialize CLVP speech encoder. + decoder_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ClvpDecoderConfig`]. + projection_dim (`int`, *optional*, defaults to 768): + Dimensionality of text and speech projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original CLVP implementation. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ClvpConfig, ClvpModelForConditionalGeneration + + >>> # Initializing a ClvpConfig with susnato/clvp_dev style configuration + >>> configuration = ClvpConfig() + + >>> # Initializing a ClvpModelForConditionalGeneration (with random weights) from the susnato/clvp_dev style configuration + >>> model = ClvpModelForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLVPConfig from a CLVPTextConfig, CLVPSpeechConfig and a CLVPAutoRegressiveConfig + >>> from transformers import ClvpEncoderConfig, ClvpDecoderConfig + + >>> # Initializing a CLVP text, CLVP speech and CLVP decoder configuration + >>> config_text = ClvpEncoderConfig() + >>> config_speech = ClvpEncoderConfig() + >>> decoder_config = ClvpDecoderConfig() + + >>> config = ClvpConfig.from_sub_model_configs(config_text, config_speech, decoder_config) + ```""" + + model_type = "clvp" + sub_configs = { + "text_config": ClvpEncoderConfig, + "speech_config": ClvpEncoderConfig, + "decoder_config": ClvpDecoderConfig, + } + + def __init__( + self, + text_config=None, + speech_config=None, + decoder_config=None, + projection_dim=768, + logit_scale_init_value=2.6592, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `ClvpEncoderConfig` with default values.") + + if speech_config is None: + speech_config = {} + logger.info("`speech_config` is `None`. initializing the `ClvpEncoderConfig` with default values.") + + if decoder_config is None: + decoder_config = {} + logger.info("`decoder_config` is `None`. initializing the `ClvpDecoderConfig` with default values.") + + self.text_config = ClvpEncoderConfig(**text_config) + self.speech_config = ClvpEncoderConfig(**speech_config) + self.decoder_config = ClvpDecoderConfig(**decoder_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = initializer_factor + + @classmethod + def from_sub_model_configs( + cls, + text_config: ClvpEncoderConfig, + speech_config: ClvpEncoderConfig, + decoder_config: ClvpDecoderConfig, + **kwargs, + ): + r""" + Instantiate a [`ClvpConfig`] (or a derived class) from CLVP text model configuration, CLVP speech model + configuration and CLVP decoder model configuration. + + Args: + text_config (`ClvpEncoderConfig`): + Text model configuration of type [`ClvpEncoderConfig`]. 
+ speech_config (`ClvpEncoderConfig`): + Speech model configuration of type [`ClvpEncoderConfig`]. + decoder_config (`ClvpDecoderConfig`): + Decoder model configuration of type [`ClvpDecoderConfig`]. + + Returns: + [`ClvpConfig`]: An instance of a configuration object + """ + + return cls( + text_config=text_config.to_dict(), + speech_config=speech_config.to_dict(), + decoder_config=decoder_config.to_dict(), + **kwargs, + ) + + +__all__ = ["ClvpConfig", "ClvpDecoderConfig", "ClvpEncoderConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/clvp/feature_extraction_clvp.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/feature_extraction_clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..160666ef78c073ac249c2192e4d589fb9e1786f1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/feature_extraction_clvp.py @@ -0,0 +1,241 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Feature extractor class for CLVP +""" + +from typing import Optional, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class ClvpFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a CLVP feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts log-mel-spectrogram features from raw speech using a custom numpy implementation of the `Short + Time Fourier Transform` which should match pytorch's `torch.stft` equivalent. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 22050): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + default_audio_length (`int`, *optional*, defaults to 6): + The default length of raw audio in seconds. If `max_length` is not set during `__call__` then it will + automatically be set to default_audio_length * `self.sampling_rate`. + hop_length (`int`, *optional*, defaults to 256): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. + chunk_length (`int`, *optional*, defaults to 30): + The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio + sequences. + n_fft (`int`, *optional*, defaults to 1024): + Size of the Fourier transform. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. 
+ mel_norms (`list` of length `feature_size`, *optional*): + If `mel_norms` is provided, then it will be used to normalize the log-mel spectrograms along each + mel-filter. + return_attention_mask (`bool`, *optional*, defaults to `False`): + Whether to return the attention mask. If left to the default, the attention mask will not be returned. + + [What are attention masks?](../glossary#attention-mask) + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=22050, + default_audio_length=6, + hop_length=256, + chunk_length=30, + n_fft=1024, + padding_value=0.0, + mel_norms=None, + return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.n_fft = n_fft + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_samples = chunk_length * sampling_rate + self.nb_max_frames = self.n_samples // hop_length + self.sampling_rate = sampling_rate + self.default_audio_length = default_audio_length + self.mel_norms = mel_norms + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + (n_fft // 2), + num_mel_filters=feature_size, + min_frequency=0.0, + max_frequency=8000.0, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="htk", + ) + + def _np_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray: + """ + This method first computes the log-mel spectrogram of the provided audio and then, if `mel_norms` is + provided, applies normalization along each mel filter bank. + """ + log_spec = spectrogram( + waveform, + window_function(self.n_fft, "hann"), + frame_length=self.n_fft, + hop_length=self.hop_length, + power=2.0, + mel_filters=self.mel_filters, + log_mel=None, + ) + + log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None)) + + if self.mel_norms is not None: + log_spec = log_spec / np.array(self.mel_norms)[:, None] + + return log_spec + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + sampling_rate: Optional[int] = None, + truncation: bool = True, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = True, + padding: Optional[str] = "max_length", + max_length: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + """ + `ClvpFeatureExtractor` is used to extract various voice-specific properties such as the pitch and tone of the + voice, speaking speed, and even speaking defects like a lisp or stuttering from a sample voice or `raw_speech`. + + First the audio is padded or truncated so that it becomes a waveform of `self.default_audio_length` + seconds, and then the log-mel spectrogram is extracted from it. + + Args: + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of lists of float values. Must be mono-channel audio, not + stereo, i.e. a single float per timestep. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors and to allow automatic speech recognition + pipelines to work correctly.
+ truncation (`bool`, *optional*, default to `True`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask. If left to the default, it will return the attention mask. + + [What are attention masks?](../glossary#attention-mask) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values / vectors. + max_length (`int`, *optional*): + The maximum input length of the inputs. + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + max_length = self.default_audio_length * self.sampling_rate if max_length is None else max_length + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + # make sure list is in array format + input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + + input_features = [ + self._np_extract_fbank_features(waveform).astype(np.float32) for waveform in input_features[0] + ] + + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature) for feature in input_features] + else: + padded_inputs["input_features"] = input_features + + return padded_inputs.convert_to_tensors(return_tensors) + + +__all__ = ["ClvpFeatureExtractor"] diff --git 
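Below is a minimal usage sketch of the feature extractor defined above (illustrative only, not part of the vendored file); the synthetic 6-second sine wave stands in for real speech, and only the layout of the output shape is shown:

```python
>>> import numpy as np
>>> from transformers import ClvpFeatureExtractor

>>> sampling_rate = 22050
>>> t = np.linspace(0.0, 6.0, 6 * sampling_rate, dtype=np.float32)
>>> waveform = 0.1 * np.sin(2 * np.pi * 220.0 * t)  # stand-in for real speech

>>> feature_extractor = ClvpFeatureExtractor()  # defaults: 80 mel bins at 22050 Hz
>>> inputs = feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="np")
>>> inputs["input_features"].shape  # (batch_size, feature_size, num_frames)
```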
a/venv/lib/python3.13/site-packages/transformers/models/clvp/modeling_clvp.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/modeling_clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..552434b5bb2290d4b722adc95e9aca5750a0d51b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/modeling_clvp.py @@ -0,0 +1,1961 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch CLVP model.""" + +import copy +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationConfig, GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + CausalLMOutputWithCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, isin_mps_friendly +from ...utils import ( + ModelOutput, + auto_docstring, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_clvp import ( + ClvpConfig, + ClvpDecoderConfig, + ClvpEncoderConfig, +) + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clvp, image_loss->speech_loss +def clvp_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + speech_loss = contrastive_loss(similarity.t()) + return (caption_loss + speech_loss) / 2.0 + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising the query, key and value tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + v_embed = (v * cos) + (rotate_half(v) * sin) + return q_embed, k_embed, v_embed + + +def _pad_extra_bos_eos_tokens( + input_ids, + attention_mask=None, + pad_token_id=0, + bos_token_id=255, + eos_token_id=0, + add_bos_token=True, + add_eos_token=True, +): + """ + This method adds extra bos and eos tokens to `input_ids` and modifies the `attention_mask` accordingly. It is + used in `ClvpConditioningEncoder` and in the generation loop of `ClvpModelForConditionalGeneration`. + """ + + # add the bos token at the beginning + if add_bos_token: + input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id) + attention_mask = ( + torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask + ) + + modified_input_ids = input_ids + if add_eos_token: + modified_input_ids = torch.zeros( + (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device + ) + for i, each_input_id in enumerate(input_ids): + # locate where the valid tokens end and then add the eos token + if isin_mps_friendly(each_input_id, pad_token_id).sum(): + pos = torch.where(each_input_id == pad_token_id)[0].min() + modified_input_ids[i] = torch.concatenate( + [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]] + ) + else: + # if there are no pad tokens present, then add eos to the end + modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id) + attention_mask = ( + torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask + ) + + return modified_input_ids, attention_mask + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for CLVP encoder's outputs that contains a pooling of the last hidden states as well as a projection + output (a linear layer on top of the pooled output). + """ +) +class ClvpEncoderOutput(ModelOutput): + r""" + embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`): + The embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The hidden state of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Pooled output of the `last_hidden_state`.
+ """ + + embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring +class ClvpOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for speech-text similarity. + speech_ids (`torch.LongTensor`, *optional*): + speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model. + logits_per_speech (`torch.FloatTensor` of shape `(speech_batch_size, text_batch_size)`): + The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, speech_batch_size)`): + The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of the text encoder + model. + speech_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder + model. + text_model_output (`BaseModelOutputWithPooling`): + The pooled output of the `last_hidden_state` of the text encoder Model. + speech_model_output (`BaseModelOutputWithPooling`): + The pooled output of the `last_hidden_state` of the speech encoder Model. + decoder_hidden_states (`torch.FloatTensor`, *optional*): + The hidden states of the decoder model. + text_encoder_hidden_states (`torch.FloatTensor`, *optional*): + The hidden states of the text encoder model. + speech_encoder_hidden_states (`torch.FloatTensor`, *optional*): + The hidden states of the speech encoder model. + """ + + loss: Optional[torch.FloatTensor] = None + speech_ids: Optional[torch.LongTensor] = None + logits_per_speech: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + speech_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + speech_model_output: BaseModelOutputWithPooling = None + decoder_hidden_states: Optional[torch.FloatTensor] = None + text_encoder_hidden_states: Optional[torch.FloatTensor] = None + speech_encoder_hidden_states: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Clvp +class ClvpRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ClvpRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class ClvpRotaryPositionalEmbedding(nn.Module): + """ + Rotary Position Embedding Class for CLVP. 
It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY + POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864v1.pdf . + """ + + def __init__(self, config): + super().__init__() + dim = max(config.projection_dim // (config.num_attention_heads * 2), 32) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + self.cached_rotary_positional_embedding = embeddings.unsqueeze(0) + return self.cached_rotary_positional_embedding + + +class ClvpSelfAttention(nn.Module): + """ + Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module. + """ + + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.layer_idx = layer_idx + + if hasattr(config, "max_position_embeddings"): + max_positions = config.max_position_embeddings + bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)) + bias = bias.view(1, 1, max_positions, max_positions) + self.register_buffer("bias", bias, persistent=False) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.FloatTensor, + rotary_pos_emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[tuple[torch.FloatTensor]]]: + # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying + # rotary_pos_emb to query and key states. 
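+        # Note: only the first `rotary_emb_dim` channels of each attention head are rotated below; the
+        # remaining channels pass through unchanged (a partial rotary embedding). The value states are
+        # rotated as well (see `apply_rotary_pos_emb` above).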
+ if rotary_pos_emb is not None and position_ids is None: + raise ValueError("`position_ids` must be provided when `rotary_pos_emb` is not None.") + + bsz, _, embed_dim = hidden_states.size() + + # get query proj + query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if past_key_values is not None: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + if rotary_pos_emb is not None: + rotary_emb_dim = rotary_pos_emb.shape[-1] + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., :rotary_emb_dim], + query_states[..., rotary_emb_dim:], + ) + key_rot, key_pass = ( + key_states[..., :rotary_emb_dim], + key_states[..., rotary_emb_dim:], + ) + value_rot, value_pass = ( + value_states[..., :rotary_emb_dim], + value_states[..., rotary_emb_dim:], + ) + + cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0) + query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids) + + # [batch_size, num_heads, seq_length, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + value_states = torch.cat((value_rot, value_pass), dim=-1) + + tgt_len = query_states.shape[2] + src_len = key_states.shape[2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class ClvpGatedLinearUnit(nn.Module): + """ + `ClvpGatedLinearUnit` uses the second half of the `hidden_states` to act as a gate for the first half of the + `hidden_states` which controls the flow of data from the first of the tensor. + """ + + def __init__(self, config): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.activation_fn(gate) + + +class ClvpEncoderMLP(nn.Module): + """ + This MLP is used in CLVP speech or text encoder models. 
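+    It projects the hidden states to `2 * config.intermediate_size`, gates one half with the other half via
+    `ClvpGatedLinearUnit`, then projects the gated result back down to `config.hidden_size`.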
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.fc1 = ClvpGatedLinearUnit(config) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout_layer = nn.Dropout(config.dropout) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.dropout_layer(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class ClvpEncoderLayer(nn.Module): + def __init__(self, config: ClvpConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.self_attn = ClvpSelfAttention(config) + self.mlp = ClvpEncoderMLP(config) + + self.input_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.post_attention_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.FloatTensor, + rotary_pos_emb: torch.FloatTensor, + attention_mask: torch.LongTensor, + position_ids: torch.LongTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`): + input to the layer. + rotary_pos_emb (`torch.FloatTensor`): + rotary position embeddings generated by `ClvpRotaryPositionalEmbedding` module. + attention_mask (`torch.FloatTensor` of shape `(batch, 1, tgt_len, src_len)`): + attention mask where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor`): + Denotes position ids of the input tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.input_rmsnorm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_rmsnorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, attn_weights + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp +class ClvpSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`ClvpConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). 
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: ClvpConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP +class ClvpDecoderMLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ClvpDecoderLayer(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = ClvpSelfAttention(config, layer_idx=layer_idx) + self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = ClvpDecoderMLP(inner_dim, config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: Optional[tuple[torch.FloatTensor]], + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_outputs = self.attn( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attn_output = attn_outputs[0] + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return (hidden_states,) + attn_outputs[1:] + + +class ClvpConditioningEncoder(nn.Module): + """ + This class 
processes the log-mel spectrograms (extracted by the feature extractor) and text tokens (produced by the + tokenizer) as inputs for the decoder model. + + First, each log-mel spectrogram is processed into a single vector that captures its salient characteristics; + then the text tokens are converted into token embeddings, and position embeddings are added to them. + Both of these vectors are concatenated and then passed to the decoder model. + + The text tokens help to incorporate the "text information", while the log-mel spectrogram specifies the + "voice characteristics" of the generated mel tokens. + """ + + def __init__(self, config: ClvpConfig): + super().__init__() + + self.text_config = config.text_config + self.decoder_config = config.decoder_config + + self.text_token_embedding = nn.Embedding(self.text_config.vocab_size, self.decoder_config.hidden_size) + self.text_position_embedding = nn.Embedding( + self.decoder_config.max_text_tokens, self.decoder_config.hidden_size + ) + + self.mel_conv = nn.Conv1d(self.decoder_config.feature_size, self.decoder_config.hidden_size, kernel_size=1) + + # define group norms to be used before each attention layer + num_groups = self.compute_groupnorm_groups(self.decoder_config.hidden_size) + self.group_norms = nn.ModuleList( + [ + nn.GroupNorm(num_groups, self.decoder_config.hidden_size, eps=1e-5, affine=True) + for _ in range(self.decoder_config.num_mel_attn_blocks) + ] + ) + + # define the attention layers + self.mel_attn_blocks = nn.ModuleList( + [ClvpSelfAttention(self.decoder_config) for _ in range(self.decoder_config.num_mel_attn_blocks)] + ) + + self.gradient_checkpointing = False + + def compute_groupnorm_groups(self, channels: int, groups: int = 32): + """ + Calculates the value of `num_groups` for nn.GroupNorm. This logic is taken from the official tortoise + repository: + https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/models/arch_util.py#L26 + """ + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + + if groups <= 2: + raise ValueError( + f"Number of groups for the GroupNorm must be greater than 2, but it is {groups}. " + f"Please consider using a different `hidden_size`." + ) + + return groups + + def forward( + self, + input_features: torch.FloatTensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + # process text + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.size() + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # construct attention mask if not given + if attention_mask is None: + attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device) + + # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple. + # This logic is specific to ClvpConditioningEncoder and not used by other modules.
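+        # For illustration (token ids assumed for this sketch: bos 255, eos/pad 0): input_ids [[10, 11, 0, 0]]
+        # becomes [[255, 10, 11, 0, 0, 0]]: bos is prepended and eos is inserted where the padding begins,
+        # with the attention mask extended to match.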
+ input_ids, attention_mask = _pad_extra_bos_eos_tokens( + input_ids, + attention_mask, + bos_token_id=self.text_config.bos_token_id, + eos_token_id=self.text_config.eos_token_id, + ) + + inputs_embeds = self.text_token_embedding(input_ids) + position_ids = attention_mask.cumsum(-1) - 1 + position_embeds = self.text_position_embedding(position_ids) + text_embeds = inputs_embeds + position_embeds + + if self.gradient_checkpointing and self.training: + # process each log-mel spectrogram into a single vector + mel_spec = torch.utils.checkpoint.checkpoint(self.mel_conv, input_features) + + for i, mel_attn_block in enumerate(self.mel_attn_blocks): + residual_mel_spec = mel_spec.transpose(1, 2) + + mel_spec = torch.utils.checkpoint.checkpoint(self.group_norms[i], mel_spec).transpose(1, 2) + mel_spec = torch.utils.checkpoint.checkpoint(mel_attn_block, mel_spec)[0] + residual_mel_spec + mel_spec = mel_spec.transpose(1, 2) + + else: + # process each log-mel spectrogram into a single vector + mel_spec = self.mel_conv(input_features) + + for i, mel_attn_block in enumerate(self.mel_attn_blocks): + residual_mel_spec = mel_spec.transpose(1, 2) + + mel_spec = self.group_norms[i](mel_spec).transpose(1, 2) + mel_spec = mel_attn_block(mel_spec)[0] + residual_mel_spec + mel_spec = mel_spec.transpose(1, 2) + + mel_spec = mel_spec[:, :, 0] + mel_spec = mel_spec.unsqueeze(1) + + # repeat if there is either (1 text vs N audios) or (N texts vs 1 audio) + if text_embeds.shape[0] == 1 and mel_spec.shape[0] != 1: + text_embeds = text_embeds.repeat(mel_spec.shape[0], 1, 1) + elif text_embeds.shape[0] != 1 and mel_spec.shape[0] == 1: + mel_spec = mel_spec.repeat(text_embeds.shape[0], 1, 1) + # If there are N texts and M audios, we raise an error since the number of texts and audios must be the same. + elif text_embeds.shape[0] != mel_spec.shape[0]: + raise ValueError( + f"The number of texts and number of audios must be the same. 
" + f"Found {text_embeds.shape[0]} texts vs {mel_spec.shape[0]} audios" + ) + + return torch.concat([mel_spec, text_embeds], dim=1) + + +@auto_docstring +class ClvpPreTrainedModel(PreTrainedModel): + config: ClvpConfig + base_model_prefix = "clvp" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=factor * 0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, ClvpRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, ClvpEncoderMLP): + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, ClvpEncoder): + config = self.config.get_text_config() + factor = config.initializer_factor + module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5)) + elif isinstance(module, ClvpConditioningEncoder): + module.mel_conv.weight.data.normal_(mean=0.0, std=factor) + module.mel_conv.bias.data.zero_() + elif isinstance(module, ClvpForCausalLM): + for name, p in module.named_parameters(): + if name == "c_proj.weight": + p.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)) + ) + elif isinstance(module, ClvpModelForConditionalGeneration): + module.logit_scale.data.fill_(self.config.logit_scale_init_value) + + if isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class ClvpEncoder(ClvpPreTrainedModel): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`ClvpEncoderLayer`]. 
+ + Args: + config: ClvpConfig + """ + + def __init__(self, config: ClvpConfig): + super().__init__(config) + + self.config = config + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None + self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + self.sequence_summary = ClvpSequenceSummary(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.token_embedding + + def set_input_embeddings(self, value): + self.token_embedding = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + input embeddings for the model. This bypasses the model's internal embedding lookup matrix. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor`, *optional*): + Denotes the position ids of `input_ids`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + inputs_embeds = self.token_embedding(input_ids) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # expand attention_mask and create position_ids if needed + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(input_shape[1], dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + rotary_pos_emb = self.rotary_pos_emb(inputs_embeds) if self.rotary_pos_emb is not None else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + encoder_layer.__call__, + hidden_states, + rotary_pos_emb, + attention_mask, + position_ids, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + rotary_pos_emb, + attention_mask, + position_ids, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + last_hidden_state = hidden_states + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # take the mean over axis 1 and get pooled output + pooled_output = self.sequence_summary(last_hidden_state) + + # apply the projection layer + embeds = self.projection(pooled_output) + + if not return_dict: + return tuple( + v for v in [embeds, last_hidden_state, pooled_output, encoder_states, all_attentions] if v is not None + ) + + return ClvpEncoderOutput( + embeds=embeds, + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class ClvpDecoder(ClvpPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ClvpDecoderLayer`] + """ + + def __init__(self, config): + super().__init__(config) + + self.config = config + + self.input_embeds_layer = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.position_embeds_layer = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size) + + self.drop = nn.Dropout(self.config.embd_pdrop) + self.layers = nn.ModuleList( + [ClvpDecoderLayer(self.config, layer_idx=i) for i in range(self.config.num_hidden_layers)] + ) + self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.input_embeds_layer + + def set_input_embeddings(self, new_embeddings): + self.input_embeds_layer = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.layers[layer].attn.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.input_embeds_layer(input_ids) + position_embeds = self.position_embeds_layer(position_ids) + inputs_embeds = inputs_embeds + position_embeds + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape num_hidden_layers x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.input_embeds_layer(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = torch.utils.checkpoint.checkpoint( + block.__call__, + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + cache_position, + ) + else: + outputs = block( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class ClvpModel(ClvpPreTrainedModel): + def __init__(self, config: ClvpDecoderConfig): + super().__init__(config) + self.config = config + self.decoder = ClvpDecoder(self.config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.input_embeds_layer + + def set_input_embeddings(self, value): + self.decoder.input_embeds_layer = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: 
Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + The CLVP decoder model with a language modelling head on top. + """ +) +class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + + self.config = config + self.model = ClvpModel(self.config) + + self.final_norm = nn.LayerNorm(self.config.hidden_size) + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=True) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return None + + def get_input_embeddings(self): + return self.model.decoder.input_embeds_layer + + def set_input_embeddings(self, new_embeddings): + self.model.decoder.input_embeds_layer = new_embeddings + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} + + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=..." 
+ ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + if input_name == "input_ids" and "inputs_embeds" in model_kwargs: + model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # Check if conditioning_embeds are provided or not, if yes then concatenate the bos_token_id at the end of the conditioning_embeds. + # Then we must subtract the positional_ids because during the forward pass it will be added anyways, so we must cancel them out here. + conditioning_embeds = model_kwargs.get("conditioning_embeds") + + if conditioning_embeds is not None: + mel_start_token_embedding = self.model.decoder.input_embeds_layer( + torch.full( + (conditioning_embeds.shape[0], 1), + fill_value=self.config.bos_token_id, + device=conditioning_embeds.device, + ) + ) + mel_start_token_embedding += self.model.decoder.position_embeds_layer( + torch.full((conditioning_embeds.shape[0], 1), fill_value=0, device=conditioning_embeds.device) + ) + conditioning_embeds = torch.concat([conditioning_embeds, mel_start_token_embedding], dim=1) + + # subtract the positional_ids here + if hasattr(model_kwargs, "attention_mask"): + position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1 + else: + position_ids = torch.arange( + 0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device + ) + position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1) + + model_kwargs["inputs_embeds"] = conditioning_embeds - self.model.decoder.position_embeds_layer( + position_ids + ) + model_kwargs["input_ids"] = ( + torch.ones((model_kwargs["inputs_embeds"].shape[0], 1), dtype=torch.long, device=self.device) + * self.config.bos_token_id + ) + + return model_kwargs["inputs_embeds"], "inputs_embeds", model_kwargs + + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, input_name, model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + conditioning_embeds=None, + cache_position=None, + **kwargs, + ): + # Overwritten: has `conditioning_embeds`-related logic + + input_ids_length = input_ids.shape[-1] + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + if conditioning_embeds is not None and cache_position[0] != 0: + model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device) + + return model_inputs + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. 
Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + lm_logits = self.final_norm(hidden_states) + lm_logits = self.lm_head(lm_logits) + + loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + The composite CLVP model with a text encoder, speech encoder and speech decoder model. + """ +) +class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin): + config: ClvpConfig + + def __init__(self, config: ClvpConfig): + super().__init__(config) + + if not isinstance(config.text_config, ClvpEncoderConfig): + raise TypeError( + "config.text_config is expected to be of type `ClvpEncoderConfig` but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.speech_config, ClvpEncoderConfig): + raise TypeError( + "config.speech_config is expected to be of type `ClvpEncoderConfig` but is of type" + f" {type(config.speech_config)}." + ) + + if not isinstance(config.decoder_config, ClvpDecoderConfig): + raise TypeError( + "config.decoder_config is expected to be of type `ClvpDecoderConfig` but is of type" + f" {type(config.decoder_config)}." 
+ ) + + self.conditioning_encoder = ClvpConditioningEncoder(config) + + self.speech_decoder_model = ClvpForCausalLM(config.decoder_config) + + self.text_encoder_model = ClvpEncoder(config.text_config) + self.speech_encoder_model = ClvpEncoder(config.speech_config) + + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + + # taken from the original repo, + # link : https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/api.py#L117 + def fix_speech_decoder_output(self, speech_ids: torch.LongTensor) -> torch.LongTensor: + """ + This method modifies the output of the decoder model, such as replacing the `eos_token_id` and changing the + last few tokens of each sequence. + + Args: + speech_ids (`torch.LongTensor`): + This refers to the output of the decoder model. + """ + decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes + speech_ids = speech_ids[:, 1:] + + stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0) + speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0]) + + for i, each_seq_stop_token_index in enumerate(stop_token_indices): + # This means that no stop tokens were found so the sentence was still being generated, in that case we don't need + # to apply any padding so just skip to the next sequence of tokens. + if each_seq_stop_token_index.sum() == 0: + continue + + stm = each_seq_stop_token_index.argmax() + speech_ids[i, stm:] = decoder_fixing_codes[0] + if stm - 3 < speech_ids.shape[1]: + speech_ids[i, -3:] = torch.tensor( + [decoder_fixing_codes[1:]], device=speech_ids.device, dtype=torch.long + ) + + return speech_ids + + def get_text_features( + self, + input_ids: Optional[torch.LongTensor] = None, + text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + r""" + This method can be used to extract text_embeds from a text. The text embeddings obtained by applying the + projection layer to the pooled output of the CLVP text encoder model. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + [What are input IDs?](../glossary#input-ids) + text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + inputs_embeds for the text encoder model passed in place of `input_ids`. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Returns: + `torch.FloatTensor` of shape `(batch_size, output_dim)`: + The text embeddings obtained by applying the projection layer to the pooled output of the CLVP Text + Model. + + Examples: + + ```python + >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration + + >>> # Define the Text + >>> text = "This is an example text." 
+ + >>> # Define processor and model + >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev") + >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev") + + >>> # Generate processor output and text embeds + >>> processor_output = processor(text=text, return_tensors="pt") + >>> text_embeds = model.get_text_features(input_ids=processor_output["input_ids"]) + ``` + """ + + outputs = self.text_encoder_model( + input_ids=input_ids, + inputs_embeds=text_encoder_inputs_embeds, + attention_mask=attention_mask, + ) + + return outputs[0] + + def get_speech_features( + self, + speech_ids: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> torch.FloatTensor: + r""" + This method can be used to extract speech_embeds. The speech embeddings are obtained by applying the speech + model on speech_ids. If speech_ids is not present but both input_ids and input_features are given then the + decoder model will be used to first generate the speech_ids and then applying the speech model. + + Args: + speech_ids (`torch.LongTensor` of shape `(batch_size, num_speech_ids)`, *optional*): + Speech Tokens. Padding will be ignored by default should you provide it. If speech_ids are provided + then input_ids and input_features will be automatically ignored. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Input text Tokens. Processed from the [`ClvpTokenizer`]. If speech_ids is not provided, then input_ids + and input_features will be used. + conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding speech token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + generation_config (`GenerationConfig`, *optional*): + generation config to control the generation of speech_ids if they are not provided. + + Returns: + `torch.FloatTensor` of shape `(batch_size, output_dim)`: + The speech embeddings obtained by applying the projection layer to the pooled output of the CLVP Speech + Model. + + Examples: + + ```python + >>> import datasets + >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration + + >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library) + >>> text = "This is an example text." 
+ >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050)) + >>> audio = ds.sort("id")["audio"][0] + >>> audio_sample, sr = audio["array"], audio["sampling_rate"] + + >>> # Define processor and model + >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev") + >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev") + + >>> # Generate processor output and model output + >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt") + >>> speech_embeds = model.get_speech_features( + ... input_ids=processor_output["input_ids"], input_features=processor_output["input_features"] + ... ) + ``` + """ + + if speech_ids is None: + if (input_ids is None and conditioning_encoder_inputs_embeds is None) or input_features is None: + raise ValueError( + "Either speech_ids or input_ids/conditioning_encoder_inputs_embeds and input_features must be provided." + ) + + if generation_config is None: + generation_config = self.generation_config + generation_config.update(**kwargs) + + conditioning_embeds = self.conditioning_encoder( + input_features=input_features, + input_ids=input_ids, + inputs_embeds=conditioning_encoder_inputs_embeds, + attention_mask=attention_mask, + ) + + speech_ids = self.speech_decoder_model.generate( + conditioning_embeds=conditioning_embeds, + generation_config=generation_config, + ) + + speech_ids = self.fix_speech_decoder_output(speech_ids[0]) + + outputs = self.speech_encoder_model( + input_ids=speech_ids, + attention_mask=attention_mask, + ) + + return outputs[0] + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None, + text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = False, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, ClvpOutput]: + r""" + conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`. + text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + inputs_embeds for the text encoder model passed in place of `input_ids`. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> import datasets + >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration + + >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library) + >>> text = "This is an example text." 
+ + >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050)) + >>> audio = ds.sort("id")["audio"][0] + >>> audio_sample, sr = audio["array"], audio["sampling_rate"] + + >>> # Define processor and model + >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev") + >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev") + + >>> # processor outputs and model outputs + >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt") + >>> outputs = model( + ... input_ids=processor_output["input_ids"], + ... input_features=processor_output["input_features"], + ... return_dict=True, + ... ) + ``` + """ + + # Use CLVP model's config for some fields (if specified) instead of those of speech & text components. + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + conditioning_embeds = self.conditioning_encoder( + input_features=input_features, + input_ids=input_ids, + inputs_embeds=conditioning_encoder_inputs_embeds, + attention_mask=attention_mask, + ) + + decoder_outputs = self.speech_decoder_model( + inputs_embeds=conditioning_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + speech_ids = decoder_outputs[0] + + # since we will get the embeds of shape `(batch_size, seq_len, embedding_dim)` during the forward pass + # we must convert it to tokens, to make it compaitable with speech_transformer + if speech_ids.ndim == 3: + speech_ids = speech_ids.argmax(2) + speech_ids = self.fix_speech_decoder_output(speech_ids) + + speech_outputs = self.speech_encoder_model( + input_ids=speech_ids, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_encoder_model( + input_ids=input_ids, + inputs_embeds=text_encoder_inputs_embeds, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + speech_embeds = speech_outputs[0] + text_embeds = text_outputs[0] + + # normalized features + speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale + logits_per_speech = logits_per_text.t() + + loss = None + if return_loss: + loss = clvp_loss(logits_per_text) + + if not return_dict: + output = ( + logits_per_speech, + logits_per_text, + text_embeds, + speech_embeds, + text_outputs[2], + speech_outputs[2], + ) + if output_hidden_states: + output += ( + decoder_outputs[-1], + text_outputs[-1], + speech_outputs[-1], + ) + + return ((loss,) + output) if loss is not None else output + + return ClvpOutput( + loss=loss, + logits_per_speech=logits_per_speech, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + speech_embeds=speech_embeds, + text_model_output=text_outputs[2], + speech_model_output=speech_outputs[2], + decoder_hidden_states=decoder_outputs.hidden_states, + text_encoder_hidden_states=text_outputs.hidden_states, + speech_encoder_hidden_states=speech_outputs.hidden_states, + ) + + @torch.no_grad() + def generate( + self, + input_ids: 
Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + generation_config: Optional[GenerationConfig] = None, + pad_to_max_mel_tokens: Optional[int] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ): + """ + Generate method for `ClvpModelForConditionalGeneration`, this method calls the `generate` method of + `ClvpForCausalLM` and then uses those generated `speech_ids` to process `text_embeds` and `speech_embeds` using + `ClvpEncoder`. + + Args: + input_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Input text Tokens. Processed from the [`ClvpTokenizer`]. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + pad_to_max_mel_tokens (`int`, *optional*): + Pads generated speech_ids to the specified value. This is to implement the same logic from the official + repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430 + and to make sure the logits are same. + This does not affect generation quality so please don't consider using it since it is less efficient. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of decoder model, text encoder and speech encoder models. + + Returns: + `ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when + `config.return_dict_in_generate=True`) or a tuple. + """ + + # If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error, + # because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to + # properly sample + sequence_length = input_ids.shape[-1] + if sequence_length > (self.config.decoder_config.max_text_tokens - 3): + raise ValueError( + f"Maximum sequence length reached! Found input_ids of length {sequence_length}." 
+ f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}" + ) + + if generation_config is None: + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + # pad input_ids as specified in the original repo + # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380 + input_ids, attention_mask = _pad_extra_bos_eos_tokens( + input_ids, + attention_mask, + add_bos_token=False, + bos_token_id=self.config.text_config.bos_token_id, + eos_token_id=self.config.text_config.eos_token_id, + ) + + conditioning_embeds = self.conditioning_encoder( + input_features=input_features, + input_ids=input_ids, + attention_mask=attention_mask, + ) + + decoder_outputs = self.speech_decoder_model.generate( + conditioning_embeds=conditioning_embeds, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + return_dict=generation_config.return_dict_in_generate, + ) + if isinstance(decoder_outputs, ModelOutput): + speech_ids = decoder_outputs.sequences + + # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic + # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430 + if pad_to_max_mel_tokens is not None: + padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1] + speech_ids = torch.nn.functional.pad( + speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id + ) + + speech_ids = self.fix_speech_decoder_output(speech_ids) + + speech_outputs = self.speech_encoder_model( + input_ids=speech_ids, + output_hidden_states=output_hidden_states, + return_dict=generation_config.return_dict_in_generate, + ) + text_outputs = self.text_encoder_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + return_dict=generation_config.return_dict_in_generate, + ) + + speech_embeds = speech_outputs[0] + text_embeds = text_outputs[0] + + # normalized features + speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale + logits_per_speech = logits_per_text.t() + + if not generation_config.return_dict_in_generate: + output = ( + speech_ids, + logits_per_speech, + logits_per_text, + text_embeds, + speech_embeds, + text_outputs[2], + speech_outputs[2], + ) + if output_hidden_states: + output += ( + decoder_outputs[-1], + text_outputs[-1], + speech_outputs[-1], + ) + + return output + + return ClvpOutput( + speech_ids=speech_ids, + logits_per_speech=logits_per_speech, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + speech_embeds=speech_embeds, + text_model_output=text_outputs[2], + speech_model_output=speech_outputs[2], + decoder_hidden_states=decoder_outputs.hidden_states, + text_encoder_hidden_states=text_outputs.hidden_states, + speech_encoder_hidden_states=speech_outputs.hidden_states, + ) + + +__all__ = [ + "ClvpModelForConditionalGeneration", + "ClvpForCausalLM", + "ClvpModel", + "ClvpPreTrainedModel", + "ClvpEncoder", + "ClvpDecoder", +] diff --git 
a/venv/lib/python3.13/site-packages/transformers/models/clvp/number_normalizer.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/number_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b8dc085b8beffe02536669e6f0b04be38348184a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/number_normalizer.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""English Normalizer class for CLVP.""" + +import sys + + +if sys.version_info >= (3, 11): + # Atomic grouping support was only added to the core RE in Python 3.11 + import re +else: + import regex as re + + +class EnglishNormalizer: + def __init__(self): + # List of (regular expression, replacement) pairs for abbreviations: + self._abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] + ] + + self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] + self.teens = [ + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ] + self.tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + + def number_to_words(self, num: int) -> str: + """ + Converts numbers(`int`) to words(`str`). + + Please note that it only supports upto - "'nine hundred ninety-nine quadrillion, nine hundred ninety-nine + trillion, nine hundred ninety-nine billion, nine hundred ninety-nine million, nine hundred ninety-nine + thousand, nine hundred ninety-nine'" or `number_to_words(999_999_999_999_999_999)`. 
+ """ + if num == 0: + return "zero" + elif num < 0: + return "minus " + self.number_to_words(abs(num)) + elif num < 10: + return self.ones[num] + elif num < 20: + return self.teens[num - 10] + elif num < 100: + return self.tens[num // 10] + ("-" + self.number_to_words(num % 10) if num % 10 != 0 else "") + elif num < 1000: + return ( + self.ones[num // 100] + " hundred" + (" " + self.number_to_words(num % 100) if num % 100 != 0 else "") + ) + elif num < 1_000_000: + return ( + self.number_to_words(num // 1000) + + " thousand" + + (", " + self.number_to_words(num % 1000) if num % 1000 != 0 else "") + ) + elif num < 1_000_000_000: + return ( + self.number_to_words(num // 1_000_000) + + " million" + + (", " + self.number_to_words(num % 1_000_000) if num % 1_000_000 != 0 else "") + ) + elif num < 1_000_000_000_000: + return ( + self.number_to_words(num // 1_000_000_000) + + " billion" + + (", " + self.number_to_words(num % 1_000_000_000) if num % 1_000_000_000 != 0 else "") + ) + elif num < 1_000_000_000_000_000: + return ( + self.number_to_words(num // 1_000_000_000_000) + + " trillion" + + (", " + self.number_to_words(num % 1_000_000_000_000) if num % 1_000_000_000_000 != 0 else "") + ) + elif num < 1_000_000_000_000_000_000: + return ( + self.number_to_words(num // 1_000_000_000_000_000) + + " quadrillion" + + ( + ", " + self.number_to_words(num % 1_000_000_000_000_000) + if num % 1_000_000_000_000_000 != 0 + else "" + ) + ) + else: + return "number out of range" + + def convert_to_ascii(self, text: str) -> str: + """ + Converts unicode to ascii + """ + return text.encode("ascii", "ignore").decode("utf-8") + + def _expand_dollars(self, m: str) -> str: + """ + This method is used to expand numerical dollar values into spoken words. + """ + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + def _remove_commas(self, m: str) -> str: + """ + This method is used to remove commas from sentences. + """ + return m.group(1).replace(",", "") + + def _expand_decimal_point(self, m: str) -> str: + """ + This method is used to expand '.' into spoken word ' point '. + """ + return m.group(1).replace(".", " point ") + + def _expand_ordinal(self, num: str) -> str: + """ + This method is used to expand ordinals such as '1st', '2nd' into spoken words. 
+ """ + ordinal_suffixes = {1: "st", 2: "nd", 3: "rd"} + + num = int(num.group(0)[:-2]) + if 10 <= num % 100 and num % 100 <= 20: + suffix = "th" + else: + suffix = ordinal_suffixes.get(num % 10, "th") + return self.number_to_words(num) + suffix + + def _expand_number(self, m: str) -> str: + """ + This method acts as a preprocessing step for numbers between 1000 and 3000 (same as the original repository, + link : + https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/utils/tokenizer.py#L86) + """ + num = int(m.group(0)) + + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + self.number_to_words(num % 100) + elif num % 100 == 0: + return self.number_to_words(num // 100) + " hundred" + else: + return self.number_to_words(num) + else: + return self.number_to_words(num) + + def normalize_numbers(self, text: str) -> str: + """ + This method is used to normalize numbers within a text such as converting the numbers to words, removing + commas, etc. + """ + text = re.sub(r"([0-9][0-9,]+[0-9])", self._remove_commas, text) + text = re.sub(r"£([0-9,]*[0-9])", r"\1 pounds", text) + text = re.sub(r"\$([0-9.,]*[0-9])", self._expand_dollars, text) + text = re.sub(r"([0-9]++\.[0-9]+)", self._expand_decimal_point, text) + text = re.sub(r"[0-9]++(st|nd|rd|th)", self._expand_ordinal, text) + text = re.sub(r"[0-9]+", self._expand_number, text) + return text + + def expand_abbreviations(self, text: str) -> str: + """ + Expands the abbreviate words. + """ + for regex, replacement in self._abbreviations: + text = re.sub(regex, replacement, text) + return text + + def collapse_whitespace(self, text: str) -> str: + """ + Removes multiple whitespaces + """ + return re.sub(re.compile(r"\s+"), " ", text) + + def __call__(self, text): + """ + Converts text to ascii, numbers / number-like quantities to their spelt-out counterparts and expands + abbreviations + """ + + text = self.convert_to_ascii(text) + text = text.lower() + text = self.normalize_numbers(text) + text = self.expand_abbreviations(text) + text = self.collapse_whitespace(text) + text = text.replace('"', "") + + return text diff --git a/venv/lib/python3.13/site-packages/transformers/models/clvp/processing_clvp.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/processing_clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..8fad43cd2f30e1bde11e3f94e8393449d7026126 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/processing_clvp.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for CLVP +""" + +from ...processing_utils import ProcessorMixin +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ClvpProcessor(ProcessorMixin): + r""" + Constructs a CLVP processor which wraps a CLVP Feature Extractor and a CLVP Tokenizer into a single processor. 
+ + [`ClvpProcessor`] offers all the functionalities of [`ClvpFeatureExtractor`] and [`ClvpTokenizer`]. See the + [`~ClvpProcessor.__call__`], [`~ClvpProcessor.decode`] and [`~ClvpProcessor.batch_decode`] for more information. + + Args: + feature_extractor (`ClvpFeatureExtractor`): + An instance of [`ClvpFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`ClvpTokenizer`): + An instance of [`ClvpTokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "ClvpFeatureExtractor" + tokenizer_class = "ClvpTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__(self, *args, **kwargs): + """ + Forwards the `audio` and `sampling_rate` arguments to [`~ClvpFeatureExtractor.__call__`] and the `text` + argument to [`~ClvpTokenizer.__call__`]. Please refer to the docstring of the above two methods for more + information. + """ + raw_speech = kwargs.pop("raw_speech", None) + if raw_speech is not None: + logger.warning( + "Using `raw_speech` keyword argument is deprecated when calling ClvpProcessor, instead use `audio`." + ) + kwargs["audio"] = raw_speech + return super().__call__(*args, **kwargs) + + +__all__ = ["ClvpProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/clvp/tokenization_clvp.py b/venv/lib/python3.13/site-packages/transformers/models/clvp/tokenization_clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0b285561c5475a624c4e33de429ec204951ae3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/clvp/tokenization_clvp.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for CLVP.""" + +import json +import os +from functools import lru_cache +from typing import Optional + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from .number_normalizer import EnglishNormalizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +@lru_cache +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. 
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class ClvpTokenizer(PreTrainedTokenizer): + """ + Construct a CLVP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import ClvpTokenizer + + >>> tokenizer = ClvpTokenizer.from_pretrained("susnato/clvp_dev") + >>> tokenizer("Hello world")["input_ids"] + [62, 84, 28, 2, 179, 79] + + >>> tokenizer(" Hello world")["input_ids"] + [2, 62, 84, 28, 2, 179, 79] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `"[STOP]"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"[STOP]"`): + The pad token of the sequence. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (CLVP tokenizer detect beginning of words by the preceding space). + add_bos_token (`bool`, *optional*, defaults to `False`): + Whether to add `bos_token` in front of the sequence when add_special_tokens=True. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether to add `eos_token` in end of the sequence when add_special_tokens=True. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = [ + "input_ids", + "attention_mask", + ] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="[UNK]", + bos_token="<|endoftext|>", + eos_token="[STOP]", + pad_token="[STOP]", + add_prefix_space=False, + add_bos_token=False, + add_eos_token=False, + **kwargs, + ): + bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token + + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self._normalizer = None + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + @property + def normalizer(self): + if self._normalizer is None: + self._normalizer = EnglishNormalizer() + return self._normalizer + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output 
= output + bos_token_id + token_ids_1 + eos_token_id + + return output + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if not self.add_bos_token: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + text = self.normalizer(text) + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + + # if the token is "Ġ" we replace it with "[SPACE]" (if "[SPACE]" is present in the vocab), otherwise we keep the "Ġ". 
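+            # Note: `self.bpe(token)` returns the merged sub-word units joined by single spaces, so splitting on " " below recovers the individual BPE tokens for this chunk.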
+ bpe_tokens.extend( + "[SPACE]" if bpe_token == "\u0120" and "[SPACE]" in self.encoder else bpe_token + for bpe_token in self.bpe(token).split(" ") + ) + + return bpe_tokens + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def clean_up_tokenization(self, text): + text = "".join(text) + vocab_tokens = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys()) + + text = text.replace("[SPACE]", " ") if "[SPACE]" in vocab_tokens else text + text = text.replace("[STOP]", " ") if "[STOP]" in vocab_tokens else text + + text = text.replace(self.unk_token, "").replace("   ", " ").replace("  ", " ") + return text + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + +__all__ = ["ClvpTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/colqwen2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d313293c0a0a7b5c56ac5bec1e54e8eb1e5a6e4a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_colqwen2 import * + from .modeling_colqwen2 import * + from .processing_colqwen2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/colqwen2/configuration_colqwen2.py b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/configuration_colqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..21f6e46f1f0037a58503b7a0432f6c3f16657f58 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/configuration_colqwen2.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from copy import deepcopy +from typing import Any + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class ColQwen2Config(PretrainedConfig): + r""" + Configuration class to store the configuration of a [`ColQwen2ForRetrieval`]. It is used to instantiate an instance + of `ColQwen2ForRetrieval` according to the specified arguments, defining the model architecture following the methodology + from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + + Instantiating a configuration with the defaults will yield a similar configuration to the vision encoder used by the pre-trained + ColQwen2-v1.0 model, e.g. [vidore/colqwen2-v1.0-hf](https://huggingface.co/vidore/colqwen2-v1.0-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vlm_config (`PretrainedConfig`, *optional*): + Configuration of the VLM backbone model. + embedding_dim (`int`, *optional*, defaults to 128): + Dimension of the multi-vector embeddings produced by the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ Example: + + ```python + from transformers.models.colqwen2 import ColQwen2Config, ColQwen2ForRetrieval + + config = ColQwen2Config() + model = ColQwen2ForRetrieval(config) + ``` + """ + + model_type = "colqwen2" + sub_configs: dict[str, Any] = {"vlm_config": PretrainedConfig} + + def __init__( + self, + vlm_config=None, + embedding_dim: int = 128, + initializer_range: float = 0.02, + **kwargs, + ): + if vlm_config is None: + vlm_config = CONFIG_MAPPING["qwen2_vl"]() + logger.info( + "`vlm_config` is `None`. Initializing `vlm_config` with the `Qwen2VLConfig` with default values." + ) + elif isinstance(vlm_config, dict): + vlm_config = deepcopy(vlm_config) + if "model_type" not in vlm_config: + raise KeyError( + "The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type." + ) + vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config) + elif not isinstance(vlm_config, PretrainedConfig): + raise TypeError( + f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}." + ) + + self.vlm_config = vlm_config + self.embedding_dim = embedding_dim + self.initializer_range = initializer_range + super().__init__(**kwargs) + + def get_text_config(self, *args, **kwargs) -> PretrainedConfig: + return self.vlm_config.get_text_config(*args, **kwargs) + + +__all__ = ["ColQwen2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modeling_colqwen2.py b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modeling_colqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0f585531aedd4703bac149ebef450fd529c548 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modeling_colqwen2.py @@ -0,0 +1,248 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_colqwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Optional + +from torch import nn + +from transformers import AutoModelForImageTextToText + +from ...cache_utils import Cache +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available +from .configuration_colqwen2 import ColQwen2Config + + +if is_torch_available(): + import torch + + +@auto_docstring +class ColQwen2PreTrainedModel(PreTrainedModel): + config: ColQwen2Config + base_model_prefix = "model" + _no_split_modules = [] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.vlm_config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for ColQwen2 embeddings output. + """ +) +class ColQwen2ForRetrievalOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The embeddings of the model. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + embeddings: Optional[torch.Tensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring( + custom_intro=""" + Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly + from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity + between these document embeddings and the corresponding query embeddings, using the late interaction method + introduced in ColBERT. + + Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with + a single model that can take into account both the textual and visual content (layout, charts, ...) of a document. + + ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper: + [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449). 
+ """ +) +class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: ColQwen2Config): + super().__init__(config) + self.config = config + self.vocab_size = config.vlm_config.text_config.vocab_size + + self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config) + + self.embedding_dim = self.config.embedding_dim + self.embedding_proj_layer = nn.Linear( + self.config.vlm_config.text_config.hidden_size, + self.embedding_dim, + ) + self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> ColQwen2ForRetrievalOutput: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding + if pixel_values is not None and image_grid_thw is not None: + # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size) + offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (num_patches_h, num_patches_w) + pixel_values = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)], + dim=0, + ) # (num_patches_h * num_patches_w, pixel_values) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + position_ids, rope_deltas = self.vlm.model.get_rope_index( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=attention_mask, + ) + + # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. 
+ if inputs_embeds is None: + inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) + + if pixel_values is not None: + pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) + image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + vlm_output = self.vlm.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None + + last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size) + proj_dtype = self.embedding_proj_layer.weight.dtype + embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim) + + # L2 normalization + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + return ColQwen2ForRetrievalOutput( + embeddings=embeddings, + past_key_values=vlm_output.past_key_values, + hidden_states=vlm_hidden_states, + attentions=vlm_output.attentions, + ) + + def get_input_embeddings(self): + return self.vlm.get_input_embeddings() + + def set_input_embeddings(self, value): + self.vlm.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.vlm.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.vlm.set_output_embeddings(new_embeddings) + + def tie_weights(self): + return self.vlm.tie_weights() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + model_embeds = self.vlm.resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + + self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vlm_config.vocab_size = model_embeds.num_embeddings + self.vlm.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + + return model_embeds + + +__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modular_colqwen2.py b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modular_colqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ae79abf6fa2e3825a84b5c65e6e6fc416caa1d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/modular_colqwen2.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Union + +from ...cache_utils import Cache +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging +from ..colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel +from ..colpali.processing_colpali import ColPaliProcessor +from .configuration_colqwen2 import ColQwen2Config + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "longest", + }, + "images_kwargs": { + "data_format": "channels_first", + "do_convert_rgb": True, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class ColQwen2Processor(ColPaliProcessor): + r""" + Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and special methods to process images and queries, as + well as to compute the late-interaction retrieval score. + + [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`] + for more information. + + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens. + query_prefix (`str`, *optional*): A prefix to be used for the query. 
+ """ + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + visual_prompt_prefix: Optional[str] = None, + query_prefix: Optional[str] = None, + **kwargs, + ): + ProcessorMixin.__init__(self, image_processor, tokenizer, chat_template=chat_template) + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + + if visual_prompt_prefix is None: + visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + self.visual_prompt_prefix = visual_prompt_prefix + + if query_prefix is None: + query_prefix = "Query: " + self.query_prefix = query_prefix + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom + wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process + both text and images at the same time. + + When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's + [`~Qwen2TokenizerFast.__call__`]. + When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's + [`~Qwen2VLImageProcessor.__call__`]. + Please refer to the doctsring of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + ColQwen2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = suffix is not None + + if text is None and images is None: + raise ValueError("Either text or images must be provided") + if text is not None and images is not None: + raise ValueError("Only one of text or images can be processed at a time") + + if images is not None: + if is_valid_image(images): + images = [images] + elif isinstance(images, list) and is_valid_image(images[0]): + pass + elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): + raise ValueError("images must be an image, list of images or list of list of images") + + texts_doc = [self.visual_prompt_prefix] * len(images) + + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(texts_doc)): + while self.image_token in texts_doc[i]: + texts_doc[i] = texts_doc[i].replace( + self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1 + ) + index += 1 + texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token) + + text_inputs = self.tokenizer( + texts_doc, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return_data = BatchFeature(data={**text_inputs, **image_inputs}) + + # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs. + offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,) + + # Split the pixel_values tensor into a list of tensors, one per image + pixel_values = list( + torch.split(return_data["pixel_values"], offsets.tolist()) + ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)] + + # Pad the list of pixel_value tensors to the same length along the sequence dimension + return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence( + pixel_values, batch_first=True + ) # (batch_size, max_num_patches, pixel_values) + + if return_token_type_ids: + labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + + return return_data + + elif text is not None: + if isinstance(text, str): + text = [text] + elif not (isinstance(text, list) and isinstance(text[0], str)): + raise ValueError("Text must be a string or a list of strings") + + if suffix is None: + suffix = self.query_augmentation_token * 10 + + texts_query: list[str] = [] + + for query in text: + augmented_query = self.query_prefix + query + suffix + texts_query.append(augmented_query) + + batch_query = self.tokenizer( + texts_query, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return batch_query + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = ColQwen2ProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + + # ColQwen doesn't process videos. Make a copy of list when removing + # otherwise `self.feature_extractor.model_input_names` is also modified + image_processor_input_names = [ + name for name in image_processor_input_names if name not in ["pixel_values_videos", "video_grid_thw"] + ] + return tokenizer_input_names + image_processor_input_names + + +class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): + pass + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for ColQwen2 embeddings output. + """ +) +class ColQwen2ForRetrievalOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The embeddings of the model. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + embeddings: Optional[torch.Tensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring( + custom_intro=""" + Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly + from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity + between these document embeddings and the corresponding query embeddings, using the late interaction method + introduced in ColBERT. + + Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with + a single model that can take into account both the textual and visual content (layout, charts, ...) of a document. + + ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper: + [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449). 
+ """ +) +class ColQwen2ForRetrieval(ColPaliForRetrieval): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: ColQwen2Config): + super().__init__(config) + del self._tied_weights_keys + self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> ColQwen2ForRetrievalOutput: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding + if pixel_values is not None and image_grid_thw is not None: + # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size) + offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (num_patches_h, num_patches_w) + pixel_values = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)], + dim=0, + ) # (num_patches_h * num_patches_w, pixel_values) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + position_ids, rope_deltas = self.vlm.model.get_rope_index( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=attention_mask, + ) + + # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. 
+ if inputs_embeds is None: + inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) + + if pixel_values is not None: + pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) + image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + vlm_output = self.vlm.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None + + last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size) + proj_dtype = self.embedding_proj_layer.weight.dtype + embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype)) # (batch_size, sequence_length, dim) + + # L2 normalization + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + return ColQwen2ForRetrievalOutput( + embeddings=embeddings, + past_key_values=vlm_output.past_key_values, + hidden_states=vlm_hidden_states, + attentions=vlm_output.attentions, + ) + + +__all__ = [ + "ColQwen2ForRetrieval", + "ColQwen2PreTrainedModel", + "ColQwen2Processor", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/colqwen2/processing_colqwen2.py b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/processing_colqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..1609f6e182dad154f3b85a2a5835dc027024550e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/colqwen2/processing_colqwen2.py @@ -0,0 +1,407 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_colqwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
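The `forward` method above projects the VLM's last hidden states to `embedding_dim`, L2-normalizes them, and zeroes out padded positions, so each token contributes one vector to the multi-vector representation. The sketch below is an illustrative, self-contained version of the ColBERT-style MaxSim scoring these embeddings are built for, mirroring the `einsum` used by `ColQwen2Processor.score_retrieval` later in this diff; random tensors stand in for real embeddings:

```python
# Illustrative late-interaction (MaxSim) scoring over multi-vector embeddings
# shaped (batch, sequence_length, embedding_dim), as produced by forward() above.
import torch

def maxsim_score(query_embeddings: torch.Tensor, passage_embeddings: torch.Tensor) -> torch.Tensor:
    # Token-level dot products: (n_queries, n_passages, query_len, passage_len).
    sim = torch.einsum("bnd,csd->bcns", query_embeddings, passage_embeddings)
    # Max over passage tokens, then sum over query tokens -> (n_queries, n_passages).
    return sim.max(dim=3).values.sum(dim=2)

queries = torch.nn.functional.normalize(torch.randn(2, 16, 128), dim=-1)
passages = torch.nn.functional.normalize(torch.randn(3, 700, 128), dim=-1)
print(maxsim_score(queries, passages).shape)  # torch.Size([2, 3])
```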
+ +from typing import Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + + +class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "longest", + }, + "images_kwargs": { + "data_format": "channels_first", + "do_convert_rgb": True, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class ColQwen2Processor(ProcessorMixin): + r""" + Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and special methods to process images and queries, as + well as to compute the late-interaction retrieval score. + + [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`] + for more information. + + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens. + query_prefix (`str`, *optional*): A prefix to be used for the query. + """ + + attributes = ["image_processor", "tokenizer"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + visual_prompt_prefix: Optional[str] = None, + query_prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + + if visual_prompt_prefix is None: + visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + self.visual_prompt_prefix = visual_prompt_prefix + + if query_prefix is None: + query_prefix = "Query: " + self.query_prefix = query_prefix + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom + wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process + both text and images at the same time. + + When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's + [`~Qwen2TokenizerFast.__call__`]. + When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's + [`~Qwen2VLImageProcessor.__call__`]. + Please refer to the doctsring of the above two methods for more information. 
+ + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + ColQwen2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = suffix is not None + + if text is None and images is None: + raise ValueError("Either text or images must be provided") + if text is not None and images is not None: + raise ValueError("Only one of text or images can be processed at a time") + + if images is not None: + if is_valid_image(images): + images = [images] + elif isinstance(images, list) and is_valid_image(images[0]): + pass + elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): + raise ValueError("images must be an image, list of images or list of list of images") + + texts_doc = [self.visual_prompt_prefix] * len(images) + + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(texts_doc)): + while self.image_token in texts_doc[i]: + texts_doc[i] = texts_doc[i].replace( + self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1 + ) + index += 1 + texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token) + + text_inputs = self.tokenizer( + texts_doc, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return_data = BatchFeature(data={**text_inputs, **image_inputs}) + + # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs. 
+ offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,) + + # Split the pixel_values tensor into a list of tensors, one per image + pixel_values = list( + torch.split(return_data["pixel_values"], offsets.tolist()) + ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)] + + # Pad the list of pixel_value tensors to the same length along the sequence dimension + return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence( + pixel_values, batch_first=True + ) # (batch_size, max_num_patches, pixel_values) + + if return_token_type_ids: + labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + + return return_data + + elif text is not None: + if isinstance(text, str): + text = [text] + elif not (isinstance(text, list) and isinstance(text[0], str)): + raise ValueError("Text must be a string or a list of strings") + + if suffix is None: + suffix = self.query_augmentation_token * 10 + + texts_query: list[str] = [] + + for query in text: + augmented_query = self.query_prefix + query + suffix + texts_query.append(augmented_query) + + batch_query = self.tokenizer( + texts_query, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return batch_query + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = ColQwen2ProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + @property + def query_augmentation_token(self) -> str: + """ + Return the query augmentation token. + + Query augmentation buffers are used as reasoning buffers during inference. + """ + return self.tokenizer.pad_token + + def process_images( + self, + images: Optional[ImageInput] = None, + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's + [`ColQwen2Processor.__call__`]. + + This method forwards the `images` and `kwargs` arguments to the image processor. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. 
Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + return self.__call__(images=images, **kwargs) + + def process_queries( + self, + text: Union[TextInput, list[TextInput]], + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's + [`ColQwen2Processor.__call__`]. + + This method forwards the `text` and `kwargs` arguments to the tokenizer. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + """ + return self.__call__(text=text, **kwargs) + + def score_retrieval( + self, + query_embeddings: Union["torch.Tensor", list["torch.Tensor"]], + passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]], + batch_size: int = 128, + output_dtype: Optional["torch.dtype"] = None, + output_device: Union["torch.device", str] = "cpu", + ) -> "torch.Tensor": + """ + Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector + query embeddings (`qs`) and passage embeddings (`ps`). For ColQwen2, a passage is the + image of a document page. + + Because the embedding tensors are multi-vector and can thus have different shapes, they + should be fed as: + (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim) + (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually + obtained by padding the list of tensors. + + Args: + query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings. + passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings. + batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. + output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor. + If `None`, the dtype of the input embeddings is used. 
+ output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor. + + Returns: + `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score + tensor is saved on the "cpu" device. + """ + + if len(query_embeddings) == 0: + raise ValueError("No queries provided") + if len(passage_embeddings) == 0: + raise ValueError("No passages provided") + + if query_embeddings[0].device != passage_embeddings[0].device: + raise ValueError("Queries and passages must be on the same device") + + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + + scores: list[torch.Tensor] = [] + + for i in range(0, len(query_embeddings), batch_size): + batch_scores: list[torch.Tensor] = [] + batch_queries = torch.nn.utils.rnn.pad_sequence( + query_embeddings[i : i + batch_size], batch_first=True, padding_value=0 + ) + for j in range(0, len(passage_embeddings), batch_size): + batch_passages = torch.nn.utils.rnn.pad_sequence( + passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0 + ) + batch_scores.append( + torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) + ) + scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device)) + + return torch.cat(scores, dim=0) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + + # ColQwen doesn't process videos. Make a copy of list when removing + # otherwise `self.feature_extractor.model_input_names` is also modified + image_processor_input_names = [ + name for name in image_processor_input_names if name not in ["pixel_values_videos", "video_grid_thw"] + ] + return tokenizer_input_names + image_processor_input_names + + +__all__ = ["ColQwen2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2979f2bd028fad29f97427754e25969d60683425 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
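Putting the processor and retrieval model from this diff together, here is a hedged end-to-end sketch. The checkpoint id is an assumption used only for illustration, and the blank PIL image merely stands in for real document screenshots; any ColQwen2 checkpoint compatible with these classes should behave the same way:

```python
# End-to-end retrieval sketch using ColQwen2Processor and ColQwen2ForRetrieval.
# "vidore/colqwen2-v1.0-hf" is an assumed checkpoint id, shown only for illustration.
import torch
from PIL import Image
from transformers import ColQwen2ForRetrieval, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0-hf"  # assumed checkpoint
processor = ColQwen2Processor.from_pretrained(model_name)
model = ColQwen2ForRetrieval.from_pretrained(model_name).eval()

pages = [Image.new("RGB", (448, 448), color="white")]  # stand-in document screenshots
queries = ["Which year had the highest revenue?"]

batch_images = processor(images=pages).to(model.device)
batch_queries = processor(text=queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings   # (n_pages, seq_len, embedding_dim)
    query_embeddings = model(**batch_queries).embeddings  # (n_queries, seq_len, embedding_dim)

# Late-interaction scores, shape (n_queries, n_pages).
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print(scores)
```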
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_conditional_detr import * + from .feature_extraction_conditional_detr import * + from .image_processing_conditional_detr import * + from .image_processing_conditional_detr_fast import * + from .modeling_conditional_detr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/configuration_conditional_detr.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/configuration_conditional_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..26a8ed05140c681c8eb9a97d6d34226afab26719 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/configuration_conditional_detr.py @@ -0,0 +1,286 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conditional DETR model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class ConditionalDetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ConditionalDetrModel`]. It is used to instantiate + a Conditional DETR model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Conditional DETR + [microsoft/conditional-detr-resnet-50](https://huggingface.co/microsoft/conditional-detr-resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which + case it will default to `ResNetConfig()`. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_queries (`int`, *optional*, defaults to 100): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries. 
+ d_model (`int`, *optional*, defaults to 256): + This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556) + for more details. + auxiliary_loss (`bool`, *optional*, defaults to `False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + position_embedding_type (`str`, *optional*, defaults to `"sine"`): + Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. + backbone (`str`, *optional*, defaults to `"resnet50"`): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `True`): + Whether to use pretrained weights for the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when + `use_timm_backbone` = `True`. 
+ class_cost (`float`, *optional*, defaults to 1): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + mask_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Focal loss in the panoptic segmentation loss. + dice_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + + Examples: + + ```python + >>> from transformers import ConditionalDetrConfig, ConditionalDetrModel + + >>> # Initializing a Conditional DETR microsoft/conditional-detr-resnet-50 style configuration + >>> configuration = ConditionalDetrConfig() + + >>> # Initializing a model (with random weights) from the microsoft/conditional-detr-resnet-50 style configuration + >>> model = ConditionalDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "conditional_detr" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + use_timm_backbone=True, + backbone_config=None, + num_channels=3, + num_queries=300, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + auxiliary_loss=False, + position_embedding_type="sine", + backbone="resnet50", + use_pretrained_backbone=True, + backbone_kwargs=None, + dilation=False, + class_cost=2, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + cls_loss_coefficient=2, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + focal_alpha=0.25, + **kwargs, + ): + # We default to values which were previously hard-coded in the model. This enables configurability of the config + # while keeping the default behavior the same. + if use_timm_backbone and backbone_kwargs is None: + backbone_kwargs = {} + if dilation: + backbone_kwargs["output_stride"] = 16 + backbone_kwargs["out_indices"] = [1, 2, 3, 4] + backbone_kwargs["in_chans"] = num_channels + # Backwards compatibility + elif not use_timm_backbone and backbone in (None, "resnet50"): + if backbone_config is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.use_timm_backbone = use_timm_backbone + self.backbone_config = backbone_config + self.num_channels = num_channels + self.num_queries = num_queries + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.num_hidden_layers = encoder_layers + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.backbone_kwargs = backbone_kwargs + self.dilation = dilation + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.cls_loss_coefficient = cls_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.focal_alpha = focal_alpha + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + +class ConditionalDetrOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("pixel_mask", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + @property + def default_onnx_opset(self) -> int: + return 12 + + +__all__ = ["ConditionalDetrConfig", "ConditionalDetrOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/feature_extraction_conditional_detr.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/feature_extraction_conditional_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..eeed4db007e5a6b19fbc0f83bb3983d344b93e7f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/feature_extraction_conditional_detr.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Conditional DETR.""" + +import warnings + +from ...image_transforms import rgb_to_id as _rgb_to_id +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_conditional_detr import ConditionalDetrImageProcessor + + +logger = logging.get_logger(__name__) + + +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + +@requires(backends=("vision",)) +class ConditionalDetrFeatureExtractor(ConditionalDetrImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class ConditionalDetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use ConditionalDetrImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["ConditionalDetrFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..9221c463fe8536fd1e3a6523962192ef15515520 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -0,0 +1,1858 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
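`ConditionalDetrFeatureExtractor` above is kept only as a deprecation shim: it subclasses `ConditionalDetrImageProcessor` and emits a `FutureWarning` when constructed. A small sketch of what that looks like in practice, assuming the vision extra (Pillow) is installed:

```python
# Demonstrates the deprecation shim: the legacy feature extractor still works,
# but it warns and is just the image processor under another name.
import warnings

from transformers import ConditionalDetrFeatureExtractor, ConditionalDetrImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature_extractor = ConditionalDetrFeatureExtractor()

print(isinstance(feature_extractor, ConditionalDetrImageProcessor))  # True
print(any(issubclass(w.category, FutureWarning) for w in caught))    # True
```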
+"""Image processor class for Conditional DETR.""" + +import io +import pathlib +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + id_to_rgb, + pad, + rescale, + resize, + rgb_to_id, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + is_flax_available, + is_jax_tensor, + is_scipy_available, + is_tf_available, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.import_utils import requires + + +if is_torch_available(): + import torch + from torch import nn + + +if is_vision_available(): + import PIL + + +if is_scipy_available(): + import scipy.special + import scipy.stats + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + +# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, tuple[int, int], list[int]], + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. If the desired output size + is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output + image size is computed by keeping the aspect ratio of the input image size. + + Args: + input_image (`np.ndarray`): + The image to resize. 
+        size (`int` or `tuple[int, int]` or `list[int]`):
+            The desired output size.
+        max_size (`int`, *optional*):
+            The maximum allowed output size.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+    """
+    image_size = get_image_size(input_image, input_data_format)
+    if isinstance(size, (list, tuple)):
+        return size
+
+    return get_size_with_aspect_ratio(image_size, size, max_size)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
+def get_image_size_for_max_height_width(
+    input_image: np.ndarray,
+    max_height: int,
+    max_width: int,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+    """
+    Computes the output image size given the input image and the maximum allowed height and width, keeping the
+    aspect ratio. Importantly, even if image_height < max_height and image_width < max_width, the image will be
+    resized so that at least one of the edges equals max_height or max_width.
+
+    For example:
+    - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+    - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+
+    Args:
+        input_image (`np.ndarray`):
+            The image to resize.
+        max_height (`int`):
+            The maximum allowed height.
+        max_width (`int`):
+            The maximum allowed width.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+    """
+    image_size = get_image_size(input_image, input_data_format)
+    height, width = image_size
+    height_scale = max_height / height
+    width_scale = max_width / width
+    min_scale = min(height_scale, width_scale)
+    new_height = int(height * min_scale)
+    new_width = int(width * min_scale)
+    return new_height, new_width
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
+def get_numpy_to_framework_fn(arr) -> Callable:
+    """
+    Returns a function that converts a numpy array to the framework of the input array.
+
+    Args:
+        arr (`np.ndarray`): The array to convert.
+    """
+    if isinstance(arr, np.ndarray):
+        return np.array
+    if is_tf_available() and is_tf_tensor(arr):
+        import tensorflow as tf
+
+        return tf.convert_to_tensor
+    if is_torch_available() and is_torch_tensor(arr):
+        import torch
+
+        return torch.tensor
+    if is_flax_available() and is_jax_tensor(arr):
+        import jax.numpy as jnp
+
+        return jnp.array
+    raise ValueError(f"Cannot convert arrays of type {type(arr)}")
+
+
+# Copied from transformers.models.detr.image_processing_detr.safe_squeeze
+def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
+    """
+    Squeezes an array, but only if the axis specified has dim 1.
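To make the sizing rules above concrete, a small sketch exercising both helpers; the expected values follow directly from the formulas (the shortest-edge case scales 480x640 by 800/480, and 640 * 800 / 480 = 1066 stays under the 1333 cap).

```python
# Sketch: output sizes picked by the resizing helpers defined above.
import numpy as np
from transformers.models.conditional_detr.image_processing_conditional_detr import (
    get_image_size_for_max_height_width,
    get_size_with_aspect_ratio,
)

# shortest_edge=800, longest_edge=1333 on a 480x640 image
print(get_size_with_aspect_ratio((480, 640), 800, max_size=1333))  # (800, 1066)

# max_height/max_width: scale by min(max_h / h, max_w / w)
img = np.zeros((100, 200, 3), dtype=np.uint8)
print(get_image_size_for_max_height_width(img, 50, 50))    # (25, 50)
print(get_image_size_for_max_height_width(img, 200, 500))  # (200, 400)
```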
+ """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +# Copied from transformers.models.detr.image_processing_detr.normalize_annotation +def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> list[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width( + images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> list[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# Copied from transformers.models.detr.image_processing_detr.convert_coco_poly_to_mask +def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: + """ + Convert a COCO polygon annotation to a mask. + + Args: + segmentations (`list[list[float]]`): + List of polygons, each polygon represented by a list of x-y coordinates. + height (`int`): + Height of the mask. + width (`int`): + Width of the mask. 
+ """ + try: + from pycocotools import mask as coco_mask + except ImportError: + raise ImportError("Pycocotools is not installed in your environment.") + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = np.asarray(mask, dtype=np.uint8) + mask = np.any(mask, axis=2) + masks.append(mask) + if masks: + masks = np.stack(masks, axis=0) + else: + masks = np.zeros((0, height, width), dtype=np.uint8) + + return masks + + +# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->ConditionalDetr +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by ConditionalDetr. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + if return_segmentation_masks: + segmentation_masks = [obj["segmentation"] for obj in annotations] + masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width) + new_target["masks"] = masks[keep] + + return new_target + + +# Copied from transformers.models.detr.image_processing_detr.masks_to_boxes +def masks_to_boxes(masks: np.ndarray) -> np.ndarray: + """ + Compute the bounding boxes around the provided panoptic segmentation masks. 
+ + Args: + masks: masks in format `[number_masks, height, width]` where N is the number of masks + + Returns: + boxes: bounding boxes in format `[number_masks, 4]` in xyxy format + """ + if masks.size == 0: + return np.zeros((0, 4)) + + h, w = masks.shape[-2:] + y = np.arange(0, h, dtype=np.float32) + x = np.arange(0, w, dtype=np.float32) + # see https://github.com/pytorch/pytorch/issues/50276 + y, x = np.meshgrid(y, x, indexing="ij") + + x_mask = masks * np.expand_dims(x, axis=0) + x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1) + x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool))) + x_min = x.filled(fill_value=1e8) + x_min = x_min.reshape(x_min.shape[0], -1).min(-1) + + y_mask = masks * np.expand_dims(y, axis=0) + y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1) + y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool))) + y_min = y.filled(fill_value=1e8) + y_min = y_min.reshape(y_min.shape[0], -1).min(-1) + + return np.stack([x_min, y_min, x_max, y_max], 1) + + +# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->ConditionalDetr +def prepare_coco_panoptic_annotation( + image: np.ndarray, + target: dict, + masks_path: Union[str, pathlib.Path], + return_masks: bool = True, + input_data_format: Union[ChannelDimension, str] = None, +) -> dict: + """ + Prepare a coco panoptic annotation for ConditionalDetr. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + annotation_path = pathlib.Path(masks_path) / target["file_name"] + + new_target = {} + new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64) + new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64) + new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64) + + if "segments_info" in target: + masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32) + masks = rgb_to_id(masks) + + ids = np.array([segment_info["id"] for segment_info in target["segments_info"]]) + masks = masks == ids[:, None, None] + masks = masks.astype(np.uint8) + if return_masks: + new_target["masks"] = masks + new_target["boxes"] = masks_to_boxes(masks) + new_target["class_labels"] = np.array( + [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["iscrowd"] = np.asarray( + [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["area"] = np.asarray( + [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32 + ) + + return new_target + + +# Copied from transformers.models.detr.image_processing_detr.get_segmentation_image +def get_segmentation_image( + masks: np.ndarray, input_size: tuple, target_size: tuple, stuff_equiv_classes, deduplicate=False +): + h, w = input_size + final_h, final_w = target_size + + m_id = scipy.special.softmax(masks.transpose(0, 1), -1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = np.zeros((h, w), dtype=np.int64) + else: + m_id = m_id.argmax(-1).reshape(h, w) + + if deduplicate: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + for eq_id in equiv: + m_id[m_id == eq_id] = equiv[0] + + seg_img = id_to_rgb(m_id) + seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST) + return seg_img + + +# Copied from transformers.models.detr.image_processing_detr.get_mask_area +def 
get_mask_area(seg_img: np.ndarray, target_size: tuple[int, int], n_classes: int) -> np.ndarray:
+    final_h, final_w = target_size
+    np_seg_img = seg_img.astype(np.uint8)
+    np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
+    m_id = rgb_to_id(np_seg_img)
+    area = [(m_id == i).sum() for i in range(n_classes)]
+    return area
+
+
+# Copied from transformers.models.detr.image_processing_detr.score_labels_from_class_probabilities
+def score_labels_from_class_probabilities(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    probs = scipy.special.softmax(logits, axis=-1)
+    labels = probs.argmax(-1, keepdims=True)
+    scores = np.take_along_axis(probs, labels, axis=-1)
+    scores, labels = scores.squeeze(-1), labels.squeeze(-1)
+    return scores, labels
+
+
+# Copied from transformers.models.detr.image_processing_detr.post_process_panoptic_sample with DetrForSegmentation->ConditionalDetrForSegmentation
+def post_process_panoptic_sample(
+    out_logits: np.ndarray,
+    masks: np.ndarray,
+    boxes: np.ndarray,
+    processed_size: tuple[int, int],
+    target_size: tuple[int, int],
+    is_thing_map: dict,
+    threshold=0.85,
+) -> dict:
+    """
+    Converts the output of [`ConditionalDetrForSegmentation`] into panoptic segmentation predictions for a single sample.
+
+    Args:
+        out_logits (`torch.Tensor`):
+            The logits for this sample.
+        masks (`torch.Tensor`):
+            The predicted segmentation masks for this sample.
+        boxes (`torch.Tensor`):
+            The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
+            width, height)`, with values in `[0, 1]`, relative to the size of the image (disregarding padding).
+        processed_size (`tuple[int, int]`):
+            The processed size of the image `(height, width)`, as returned by the preprocessing step, i.e. the size
+            after data augmentation but before batching.
+        target_size (`tuple[int, int]`):
+            The target size of the image, `(height, width)` corresponding to the requested final size of the
+            prediction.
+        is_thing_map (`Dict`):
+            A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
+        threshold (`float`, *optional*, defaults to 0.85):
+            The threshold used to binarize the segmentation masks.
+    """
+    # we filter out empty queries and detections below the threshold
+    scores, labels = score_labels_from_class_probabilities(out_logits)
+    keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)
+
+    cur_scores = scores[keep]
+    cur_classes = labels[keep]
+    cur_boxes = center_to_corners_format(boxes[keep])
+
+    if len(cur_boxes) != len(cur_classes):
+        raise ValueError("Not as many boxes as there are classes")
+
+    cur_masks = masks[keep]
+    cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
+    cur_masks = safe_squeeze(cur_masks, 1)
+    b, h, w = cur_masks.shape
+
+    # It may be that we have several predicted masks for the same stuff class.
+    # In the following, we track the list of mask ids for each stuff class (they are merged later on)
+    cur_masks = cur_masks.reshape(b, -1)
+    stuff_equiv_classes = defaultdict(list)
+    for k, label in enumerate(cur_classes):
+        if not is_thing_map[label]:
+            stuff_equiv_classes[label].append(k)
+
+    seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
+    area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
+
+    # We filter out any mask that is too small
+    if cur_classes.size > 0:
+        # We now filter empty masks, as long as we find some
+        filtered_small = np.array([a <= 4 for a in area], dtype=bool)
+        while filtered_small.any():
+            cur_masks = cur_masks[~filtered_small]
+            cur_scores = cur_scores[~filtered_small]
+            cur_classes = cur_classes[~filtered_small]
+            seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
+            area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
+            filtered_small = np.array([a <= 4 for a in area], dtype=bool)
+    else:
+        cur_classes = np.ones((1, 1), dtype=np.int64)
+
+    segments_info = [
+        {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
+        for i, (cat, a) in enumerate(zip(cur_classes, area))
+    ]
+    del cur_classes
+
+    with io.BytesIO() as out:
+        PIL.Image.fromarray(seg_img).save(out, format="PNG")
+        predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+
+    return predictions
+
+
+# Copied from transformers.models.detr.image_processing_detr.resize_annotation
+def resize_annotation(
+    annotation: dict[str, Any],
+    orig_size: tuple[int, int],
+    target_size: tuple[int, int],
+    threshold: float = 0.5,
+    resample: PILImageResampling = PILImageResampling.NEAREST,
+):
+    """
+    Resizes an annotation to a target size.
+
+    Args:
+        annotation (`dict[str, Any]`):
+            The annotation dictionary.
+        orig_size (`tuple[int, int]`):
+            The original size of the input image.
+        target_size (`tuple[int, int]`):
+            The target size of the image, as returned by the preprocessing `resize` step.
+        threshold (`float`, *optional*, defaults to 0.5):
+            The threshold used to binarize the segmentation masks.
+        resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
+            The resampling filter to use when resizing the masks.
+    """
+    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
+    ratio_height, ratio_width = ratios
+
+    new_annotation = {}
+    new_annotation["size"] = target_size
+
+    for key, value in annotation.items():
+        if key == "boxes":
+            boxes = value
+            scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
+            new_annotation["boxes"] = scaled_boxes
+        elif key == "area":
+            area = value
+            scaled_area = area * (ratio_width * ratio_height)
+            new_annotation["area"] = scaled_area
+        elif key == "masks":
+            masks = value[:, None]
+            masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
+            masks = masks.astype(np.float32)
+            masks = masks[:, 0] > threshold
+            new_annotation["masks"] = masks
+        elif key == "size":
+            new_annotation["size"] = target_size
+        else:
+            new_annotation[key] = value
+
+    return new_annotation
+
+
+# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
+def binary_mask_to_rle(mask):
+    """
+    Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
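A worked example of the module-level `resize_annotation` above; with both ratios at 0.5 the boxes halve per axis and areas quarter:

```python
# Sketch: resizing a toy annotation from a 100x200 image to 50x100.
import numpy as np
from transformers.models.conditional_detr.image_processing_conditional_detr import resize_annotation

annotation = {"boxes": np.array([[10.0, 20.0, 40.0, 60.0]]), "area": np.array([1200.0])}
resized = resize_annotation(annotation, orig_size=(100, 200), target_size=(50, 100))
print(resized["boxes"])  # [[ 5. 10. 20. 30.]] -- both ratios are 0.5
print(resized["area"])   # [300.] -- 1200 * 0.5 * 0.5
print(resized["size"])   # (50, 100)
```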
+
+    Args:
+        mask (`torch.Tensor` or `numpy.array`):
+            A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
+            segment_id or class_id.
+    Returns:
+        `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
+        format.
+    """
+    if is_torch_tensor(mask):
+        mask = mask.numpy()
+
+    pixels = mask.flatten()
+    pixels = np.concatenate([[0], pixels, [0]])
+    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+    runs[1::2] -= runs[::2]
+    return list(runs)
+
+
+# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
+def convert_segmentation_to_rle(segmentation):
+    """
+    Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
+
+    Args:
+        segmentation (`torch.Tensor` or `numpy.array`):
+            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
+    Returns:
+        `list[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
+    """
+    segment_ids = torch.unique(segmentation)
+
+    run_length_encodings = []
+    for idx in segment_ids:
+        mask = torch.where(segmentation == idx, 1, 0)
+        rle = binary_mask_to_rle(mask)
+        run_length_encodings.append(rle)
+
+    return run_length_encodings
+
+
+# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
+def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
+    """
+    Binarizes the given masks using `object_mask_threshold` and returns the associated values of `masks`, `scores`
+    and `labels`.
+
+    Args:
+        masks (`torch.Tensor`):
+            A tensor of shape `(num_queries, height, width)`.
+        scores (`torch.Tensor`):
+            A tensor of shape `(num_queries)`.
+        labels (`torch.Tensor`):
+            A tensor of shape `(num_queries)`.
+        object_mask_threshold (`float`):
+            A number between 0 and 1 used to binarize the masks.
+    Raises:
+        `ValueError`: Raised when the first dimension doesn't match across all input tensors.
+    Returns:
+        `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` with the regions
+        scoring at or below `object_mask_threshold` removed.
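A worked RLE example for `binary_mask_to_rle` above: flattening row-major, the ones start at (1-indexed) position 2 for a run of 2, and at position 6 for a run of 1:

```python
# Sketch: run-length encoding of a tiny binary mask; the result alternates
# [start1, length1, start2, length2, ...] with 1-indexed start positions.
import numpy as np
from transformers.models.conditional_detr.image_processing_conditional_detr import binary_mask_to_rle

mask = np.array([[0, 1, 1],
                 [0, 0, 1]])
print(list(map(int, binary_mask_to_rle(mask))))  # [2, 2, 6, 1]
```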
+ """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +# Copied from transformers.models.detr.image_processing_detr.check_segment_validity +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +# Copied from transformers.models.detr.image_processing_detr.compute_segments +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_size: Optional[tuple[int, int]] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: list[dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +@requires(backends=("vision",)) +class ConditionalDetrImageProcessor(BaseImageProcessor): + r""" + Constructs a Conditional Detr image processor. + + Args: + format (`str`, *optional*, defaults to `"coco_detection"`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`): + Size of the image's `(height, width)` dimensions after resizing. 
Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize: + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
+ """ + + model_input_names = ["pixel_values", "pixel_mask"] + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.__init__ + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_annotations: Optional[bool] = None, + do_pad: bool = True, + pad_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> None: + if "pad_and_return_pixel_mask" in kwargs: + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None if size is None else 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + size = get_size_dict(size, max_size=max_size, default_to_square=False) + + # Backwards compatibility + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + self._valid_processor_keys = [ + "images", + "annotations", + "return_segmentation_masks", + "masks_path", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "do_convert_annotations", + "image_mean", + "image_std", + "do_pad", + "pad_size", + "format", + "return_tensors", + "data_format", + "input_data_format", + ] + + @classmethod + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->ConditionalDetr + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `ConditionalDetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr + def prepare_annotation( + self, + image: np.ndarray, + target: dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> dict: + """ + Prepare an annotation for feeding into ConditionalDetr model. 
+ """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + elif format == AnnotationFormat.COCO_PANOPTIC: + return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_panoptic_annotation( + image, + target, + masks_path=masks_path, + return_masks=return_segmentation_masks, + input_data_format=input_data_format, + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." 
+ ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image + def _update_annotation_for_padded_image( + self, + annotation: dict, + input_image_size: tuple[int, int], + output_image_size: tuple[int, int], + padding, + update_bboxes, + ) -> dict: + """ + Update the annotation for a padded image. 
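A worked example of the box normalization performed by `normalize_annotation` above, done by hand for one box:

```python
# Sketch: xyxy pixel box -> relative (center_x, center_y, width, height).
import numpy as np
from transformers.models.conditional_detr.image_processing_conditional_detr import normalize_annotation

annotation = {"boxes": np.array([[10.0, 20.0, 40.0, 60.0]])}       # xyxy, absolute pixels
out = normalize_annotation(annotation, image_size=(100, 200))      # (height, width)
# center_x = (10+40)/2 / 200, center_y = (20+60)/2 / 100, w = 30/200, h = 40/100
print(out["boxes"])  # [[0.125 0.4   0.15  0.4  ]]
```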
+ """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: tuple[int, int], + annotation: Optional[dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: list[np.ndarray], + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (list[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.preprocess + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. 
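A quick sketch of `pad` on a ragged two-image batch of channels-first numpy arrays:

```python
# Sketch: both images land on a shared 600x640 canvas; pixel_mask marks real pixels.
import numpy as np
from transformers import ConditionalDetrImageProcessor

processor = ConditionalDetrImageProcessor()
images = [np.zeros((3, 480, 640), dtype=np.float32), np.zeros((3, 600, 400), dtype=np.float32)]
batch = processor.pad(images, return_tensors="np")
print(batch["pixel_values"].shape)  # (2, 3, 600, 640)
print(batch["pixel_mask"].shape)    # (2, 600, 640)
```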
If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `list[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. 
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + if "pad_and_return_pixel_mask" in kwargs: + logger.warning_once( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, " + "use `do_pad` instead." + ) + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` argument is deprecated and will be removed in a future version, use" + " `size['longest_edge']` instead." + ) + size = kwargs.pop("max_size") + + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=False) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. 
+ validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + if ( + masks_path is not None + and format == AnnotationFormat.COCO_PANOPTIC + and not isinstance(masks_path, (pathlib.Path, str)) + ): + raise ValueError( + "The path to the directory containing the mask PNG files should be provided as a" + f" `pathlib.Path` or string object, but is {type(masks_path)} instead." + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + # transformations + if do_resize: + if annotations is not None: + resized_images, resized_annotations = [], [] + for image, target in zip(images, annotations): + orig_size = get_image_size(image, input_data_format) + resized_image = self.resize( + image, size=size, resample=resample, input_data_format=input_data_format + ) + resized_annotation = self.resize_annotation( + target, orig_size, get_image_size(resized_image, input_data_format) + ) + resized_images.append(resized_image) + resized_annotations.append(resized_annotation) + images = resized_images + annotations = resized_annotations + del resized_images, resized_annotations + else: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + encoded_inputs = self.pad( + images, + annotations=annotations, + return_pixel_mask=True, + 
data_format=data_format, + input_data_format=input_data_format, + update_bboxes=do_convert_annotations, + return_tensors=return_tensors, + pad_size=pad_size, + ) + else: + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + + return encoded_inputs + + # POSTPROCESSING METHODS - TODO: add support for other frameworks + def post_process(self, outputs, target_sizes): + """ + Converts the output of [`ConditionalDetrForObjectDetection`] into the format expected by the Pascal VOC format (xmin, ymin, xmax, ymax). + Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + logging.warning_once( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + ) + + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1) + scores = topk_values + topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor") + labels = topk_indexes % out_logits.shape[2] + boxes = center_to_corners_format(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + # Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr + def post_process_object_detection( + self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100 + ): + """ + Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x, + top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. 
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + top_k (`int`, *optional*, defaults to 100): + Keep only top k bounding boxes before filtering by thresholding. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = out_logits.sigmoid() + prob = prob.view(out_logits.shape[0], -1) + k_value = min(top_k, prob.size(1)) + topk_values, topk_indexes = torch.topk(prob, k_value, dim=1) + scores = topk_values + topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor") + labels = topk_indexes % out_logits.shape[2] + boxes = center_to_corners_format(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation with Detr->ConditionalDetr + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None): + """ + Converts the output of [`ConditionalDetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the + batch. If unset, predictions will not be resized. + Returns: + `list[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. 
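+
+        Example (a minimal sketch; `microsoft/conditional-detr-resnet-50` is a detection checkpoint used here only
+        for illustration, so a model actually fine-tuned for segmentation would normally be loaded instead):
+
+        ```python
+        >>> import torch
+        >>> import requests
+        >>> from PIL import Image
+        >>> from transformers import AutoImageProcessor, ConditionalDetrForSegmentation
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
+        >>> model = ConditionalDetrForSegmentation.from_pretrained("microsoft/conditional-detr-resnet-50")
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> # `image.size` is (width, height), so it is reversed to obtain (height, width)
+        >>> maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
+        >>> maps[0].shape  # (height, width) map of semantic class ids
+        ```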
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation with Detr->ConditionalDetr + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`ConditionalDetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If unset, predictions will not be resized. + return_coco_annotation (`bool`, *optional*): + Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) + format. + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation with Detr->ConditionalDetr + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`ConditionalDetrForSegmentation`] into image panoptic segmentation predictions. Only supports + PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + The outputs from [`ConditionalDetrForSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If unset, predictions will not be resized. 
+ Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to + the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["ConditionalDetrImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..351d4fa1470f154afcc8f04d14aa3fcbd6fb43b9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py @@ -0,0 +1,1019 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/conditional_detr/modular_conditional_detr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_conditional_detr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import pathlib +from typing import Any, Optional, Union + +import torch +from torch import nn +from torchvision.io import read_image +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, + get_image_size_for_max_height_width, + get_max_height_width, + safe_squeeze, +) +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + validate_annotations, +) +from ...processing_utils import Unpack +from ...utils import TensorType, auto_docstring, logging +from ...utils.import_utils import requires +from .image_processing_conditional_detr import ( + compute_segments, + convert_segmentation_to_rle, + get_size_with_aspect_ratio, + remove_low_and_no_objects, +) + + +logger = logging.get_logger(__name__) + + +class ConditionalDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the CONDITIONAL_DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + return_segmentation_masks (`bool`, *optional*, defaults to `False`): + Whether to return segmentation masks. + """ + + format: Optional[Union[str, AnnotationFormat]] + do_convert_annotations: Optional[bool] + return_segmentation_masks: Optional[bool] + + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +# inspired by https://github.com/facebookresearch/conditional_detr/blob/master/datasets/coco.py#L33 +def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor: + """ + Convert a COCO polygon annotation to a mask. + + Args: + segmentations (`list[list[float]]`): + List of polygons, each polygon represented by a list of x-y coordinates. + height (`int`): + Height of the mask. + width (`int`): + Width of the mask. 
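+        device (`torch.device`):
+            Device on which the returned mask tensors are created.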
+ """ + try: + from pycocotools import mask as coco_mask + except ImportError: + raise ImportError("Pycocotools is not installed in your environment.") + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8, device=device) + mask = torch.any(mask, axis=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, axis=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device) + + return masks + + +# inspired by https://github.com/facebookresearch/conditional_detr/blob/master/datasets/coco.py#L50 +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by CONDITIONAL_DETR. + """ + image_height, image_width = image.size()[-2:] + + image_id = target["image_id"] + image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) + + classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } + + if keypoints: + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + if return_segmentation_masks: + segmentation_masks = [obj["segmentation"] for obj in annotations] + masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device) + new_target["masks"] = masks[keep] + + return new_target + + +def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: + """ + Compute the bounding boxes around the provided panoptic segmentation masks. 
+ + Args: + masks: masks in format `[number_masks, height, width]` where N is the number of masks + + Returns: + boxes: bounding boxes in format `[number_masks, 4]` in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + y = torch.arange(0, h, dtype=torch.float32, device=masks.device) + x = torch.arange(0, w, dtype=torch.float32, device=masks.device) + # see https://github.com/pytorch/pytorch/issues/50276 + y, x = torch.meshgrid(y, x, indexing="ij") + + x_mask = masks * torch.unsqueeze(x, 0) + x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0] + x_min = ( + torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0] + ) + + y_mask = masks * torch.unsqueeze(y, 0) + y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0] + y_min = ( + torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0] + ) + + return torch.stack([x_min, y_min, x_max, y_max], 1) + + +# 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py +# Copyright (c) 2018, Alexander Kirillov +# All rights reserved. +def rgb_to_id(color): + """ + Converts RGB color to unique ID. + """ + if isinstance(color, torch.Tensor) and len(color.shape) == 3: + if color.dtype == torch.uint8: + color = color.to(torch.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +def prepare_coco_panoptic_annotation( + image: torch.Tensor, + target: dict, + masks_path: Union[str, pathlib.Path], + return_masks: bool = True, + input_data_format: Union[ChannelDimension, str] = None, +) -> dict: + """ + Prepare a coco panoptic annotation for CONDITIONAL_DETR. 
+ """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + annotation_path = pathlib.Path(masks_path) / target["file_name"] + + new_target = {} + new_target["image_id"] = torch.as_tensor( + [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device + ) + new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) + new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) + + if "segments_info" in target: + masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device) + masks = rgb_to_id(masks) + + ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device) + masks = masks == ids[:, None, None] + masks = masks.to(torch.bool) + if return_masks: + new_target["masks"] = masks + new_target["boxes"] = masks_to_boxes(masks) + new_target["class_labels"] = torch.as_tensor( + [segment_info["category_id"] for segment_info in target["segments_info"]], + dtype=torch.int64, + device=image.device, + ) + new_target["iscrowd"] = torch.as_tensor( + [segment_info["iscrowd"] for segment_info in target["segments_info"]], + dtype=torch.int64, + device=image.device, + ) + new_target["area"] = torch.as_tensor( + [segment_info["area"] for segment_info in target["segments_info"]], + dtype=torch.float32, + device=image.device, + ) + + return new_target + + +@auto_docstring +@requires(backends=("torchvision", "torch")) +class ConditionalDetrImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + format = AnnotationFormat.COCO_DETECTION + do_resize = True + do_rescale = True + do_normalize = True + do_pad = True + size = {"shortest_edge": 800, "longest_edge": 1333} + default_to_square = False + model_input_names = ["pixel_values", "pixel_mask"] + valid_kwargs = ConditionalDetrFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[ConditionalDetrFastImageProcessorKwargs]) -> None: + if "pad_and_return_pixel_mask" in kwargs: + kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask") + + size = kwargs.pop("size", None) + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None if size is None else 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + self.size = get_size_dict(size, max_size=max_size, default_to_square=False) + + # Backwards compatibility + do_convert_annotations = kwargs.get("do_convert_annotations") + do_normalize = kwargs.get("do_normalize") + if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None: + self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. 
`ConditionalDetrImageProcessorFast.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + + def prepare_annotation( + self, + image: torch.Tensor, + target: dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> dict: + """ + Prepare an annotation for feeding into CONDITIONAL_DETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + elif format == AnnotationFormat.COCO_PANOPTIC: + return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_panoptic_annotation( + image, + target, + masks_path=masks_path, + return_masks=return_segmentation_masks, + input_data_format=input_data_format, + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use if resizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. 
+ new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size["shortest_edge"], + size["longest_edge"], + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"]) + elif size.height and size.width: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + + image = F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + return image + + def resize_annotation( + self, + annotation: dict[str, Any], + orig_size: tuple[int, int], + target_size: tuple[int, int], + threshold: float = 0.5, + interpolation: Optional["F.InterpolationMode"] = None, + ): + """ + Resizes an annotation to a target size. + + Args: + annotation (`dict[str, Any]`): + The annotation dictionary. + orig_size (`tuple[int, int]`): + The original size of the input image. + target_size (`tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`InterpolationMode`, defaults to `F.InterpolationMode.NEAREST_EXACT`): + The resampling filter to use when resizing the masks. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST_EXACT + ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)] + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device + ) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks] + masks = torch.stack(masks).to(torch.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= torch.as_tensor( + [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device + ) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + def _update_annotation_for_padded_image( + self, + annotation: dict, + input_image_size: tuple[int, int], + output_image_size: tuple[int, int], + padding, + update_bboxes, + ) -> dict: + """ + Update the annotation for a padded image. 
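+
+        Args:
+            annotation (`dict`):
+                The annotation to update, e.g. with "boxes", "masks" and "size" entries.
+            input_image_size (`tuple[int, int]`):
+                Height and width of the image before padding.
+            output_image_size (`tuple[int, int]`):
+                Height and width of the image after padding.
+            padding (`list[int]`):
+                Padding applied to the image, in (left, top, right, bottom) order.
+            update_bboxes (`bool`):
+                Whether to rescale the normalized bounding boxes to account for the padded image size.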
+ """ + new_annotation = {} + new_annotation["size"] = output_image_size + ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size)) + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = F.pad( + masks, + padding, + fill=0, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + def pad( + self, + image: torch.Tensor, + padded_size: tuple[int, int], + annotation: Optional[dict[str, Any]] = None, + update_bboxes: bool = True, + fill: int = 0, + ): + original_size = image.size()[-2:] + padding_bottom = padded_size[0] - original_size[0] + padding_right = padded_size[1] - original_size[1] + if padding_bottom < 0 or padding_right < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the padded size is larger than the " + f"original size. Got padded size: {padded_size}, original size: {original_size}." + ) + if original_size != padded_size: + padding = [0, 0, padding_right, padding_bottom] + image = F.pad(image, padding, fill=fill) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, original_size, padded_size, padding, update_bboxes + ) + + # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device) + pixel_mask[: original_size[0], : original_size[1]] = 1 + + return image, pixel_mask, annotation + + @auto_docstring + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + **kwargs: Unpack[ConditionalDetrFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + """ + if "pad_and_return_pixel_mask" in kwargs: + kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask") + logger.warning_once( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, " + "use `do_pad` instead." + ) + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` argument is deprecated and will be removed in a future version, use" + " `size['longest_edge']` instead." 
+ ) + kwargs["size"] = kwargs.pop("max_size") + + return super().preprocess(images, annotations, masks_path, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + annotations: Optional[Union[AnnotationType, list[AnnotationType]]], + masks_path: Optional[Union[str, pathlib.Path]], + return_segmentation_masks: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_convert_annotations: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_pad: bool, + pad_size: Optional[SizeDict], + format: Optional[Union[str, AnnotationFormat]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + """ + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + if ( + masks_path is not None + and format == AnnotationFormat.COCO_PANOPTIC + and not isinstance(masks_path, (pathlib.Path, str)) + ): + raise ValueError( + "The path to the directory containing the mask PNG files should be provided as a" + f" `pathlib.Path` or string object, but is {type(masks_path)} instead." + ) + + data = {} + + processed_images = [] + processed_annotations = [] + pixel_masks = [] # Initialize pixel_masks here + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # prepare (COCO annotations as a list of Dict -> CONDITIONAL_DETR target as a single Dict per image) + if annotations is not None: + annotation = self.prepare_annotation( + image, + annotation, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=ChannelDimension.FIRST, + ) + + if do_resize: + resized_image = self.resize(image, size=size, interpolation=interpolation) + if annotations is not None: + annotation = self.resize_annotation( + annotation, + orig_size=image.size()[-2:], + target_size=resized_image.size()[-2:], + ) + image = resized_image + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + if do_convert_annotations and annotations is not None: + annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) + + processed_images.append(image) + processed_annotations.append(annotation) + images = processed_images + annotations = processed_annotations if annotations is not None else None + + if do_pad: + # depends on all resized image shapes so we need another loop + if pad_size is not None: + padded_size = (pad_size.height, pad_size.width) + else: + padded_size = get_max_height_width(images) + + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + if padded_size == image.size()[-2:]: + padded_images.append(image) + pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) 
+ padded_annotations.append(annotation) + continue + image, pixel_mask, annotation = self.pad( + image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations + ) + padded_images.append(image) + padded_annotations.append(annotation) + pixel_masks.append(pixel_mask) + images = padded_images + annotations = padded_annotations if annotations is not None else None + data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) + + data.update({"pixel_values": torch.stack(images, dim=0)}) + encoded_inputs = BatchFeature(data, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + return encoded_inputs + + def post_process(self, outputs, target_sizes): + """ + Converts the output of [`ConditionalDetrForObjectDetection`] into the format expected by the Pascal VOC format (xmin, ymin, xmax, ymax). + Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + logging.warning_once( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + ) + + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1) + scores = topk_values + topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor") + labels = topk_indexes % out_logits.shape[2] + boxes = center_to_corners_format(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + def post_process_object_detection( + self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100 + ): + """ + Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x, + top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. 
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + top_k (`int`, *optional*, defaults to 100): + Keep only top k bounding boxes before filtering by thresholding. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = out_logits.sigmoid() + prob = prob.view(out_logits.shape[0], -1) + k_value = min(top_k, prob.size(1)) + topk_values, topk_indexes = torch.topk(prob, k_value, dim=1) + scores = topk_values + topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor") + labels = topk_indexes % out_logits.shape[2] + boxes = center_to_corners_format(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None): + """ + Converts the output of [`ConditionalDetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the + batch. If unset, predictions will not be resized. + Returns: + `list[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. 
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`ConditionalDetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If unset, predictions will not be resized. + return_coco_annotation (`bool`, *optional*): + Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) + format. + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`ConditionalDetrForSegmentation`] into image panoptic segmentation predictions. Only supports + PyTorch. + + Args: + outputs ([`ConditionalDetrForSegmentation`]): + The outputs from [`ConditionalDetrForSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If unset, predictions will not be resized. + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `None` if no mask if found above `threshold`. 
If `target_sizes` is specified, segmentation is resized to + the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["ConditionalDetrImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modeling_conditional_detr.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modeling_conditional_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..ada06d87f7cc0c6a89b8cca9fc0e309b383d0158 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -0,0 +1,1995 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Conditional DETR model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends +from ...utils.backbone_utils import load_backbone +from .configuration_conditional_detr import ConditionalDetrConfig + + +if is_timm_available(): + from timm import create_model + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Conditional DETR decoder. This class adds one attribute to + BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output + of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary + decoding losses. + """ +) +class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions): + r""" + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to + Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder + layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding + losses. + """ +) +class ConditionalDetrModelOutput(Seq2SeqModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). 
+ """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`ConditionalDetrForObjectDetection`]. + """ +) +# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->ConditionalDetr +class ConditionalDetrObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~ConditionalDetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`ConditionalDetrForSegmentation`]. + """ +) +# Copied from transformers.models.detr.modeling_detr.DetrSegmentationOutput with Detr->ConditionalDetr +class ConditionalDetrSegmentationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. 
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~ConditionalDetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): + Segmentation mask logits for all queries. See also + [`~ConditionalDetrImageProcessor.post_process_semantic_segmentation`], + [`~ConditionalDetrImageProcessor.post_process_instance_segmentation`] or + [`~ConditionalDetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic + segmentation masks respectively. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->ConditionalDetr +class ConditionalDetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans.
+ """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr +class ConditionalDetrConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + + # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API + if config.use_timm_backbone: + # We default to values which were previously hard-coded. This enables configurability from the config + # using backbone arguments, while keeping the default behavior the same. 
+ requires_backends(self, ["timm"]) + kwargs = getattr(config, "backbone_kwargs", {}) + kwargs = {} if kwargs is None else kwargs.copy() + out_indices = kwargs.pop("out_indices", (1, 2, 3, 4)) + num_channels = kwargs.pop("in_chans", config.num_channels) + if config.dilation: + kwargs["output_stride"] = kwargs.get("output_stride", 16) + backbone = create_model( + config.backbone, + pretrained=config.use_pretrained_backbone, + features_only=True, + out_indices=out_indices, + in_chans=num_channels, + **kwargs, + ) + else: + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = ( + self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels + ) + + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + + if "resnet" in backbone_model_type: + for name, parameter in self.model.named_parameters(): + if config.use_timm_backbone: + if "layer2" not in name and "layer3" not in name and "layer4" not in name: + parameter.requires_grad_(False) + else: + if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name: + parameter.requires_grad_(False) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->ConditionalDetr +class ConditionalDetrConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. + """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +class ConditionalDetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. 
+ """ + + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->ConditionalDetr +class ConditionalDetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +# Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->ConditionalDetr +def build_position_encoding(config): + n_steps = config.d_model // 2 + if config.position_embedding_type == "sine": + # TODO find a better way of exposing other arguments + position_embedding = ConditionalDetrSinePositionEmbedding(n_steps, normalize=True) + elif config.position_embedding_type == "learned": + position_embedding = ConditionalDetrLearnedPositionEmbedding(n_steps) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") + + return position_embedding + + +# function to generate sine positional embedding for 2d coordinates +def gen_sine_position_embeddings(pos_tensor, d_model): + scale = 2 * math.pi + dim = d_model // 2 + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), 
dim=3).flatten(2) + pos = torch.cat((pos_y, pos_x), dim=2) + return pos.to(pos_tensor.dtype) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# Copied from transformers.models.detr.modeling_detr.DetrAttention +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]): + return tensor if object_queries is None else tensor + object_queries + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + spatial_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if object_queries is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, object_queries) + + # add key-value position embeddings to the key value states + if spatial_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if 
attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( + attention_mask, -torch.inf + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class ConditionalDetrAttention(nn.Module): + """ + Cross-Attention used in Conditional DETR 'Conditional DETR for Fast Training Convergence' paper. + + The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be + different to v. + """ + + def __init__( + self, + embed_dim: int, + out_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + # head dimension of values + self.v_head_dim = out_dim // num_heads + if self.v_head_dim * num_heads != self.out_dim: + raise ValueError( + f"out_dim must be divisible by num_heads (got `out_dim`: {self.out_dim} and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.out_proj = nn.Linear(out_dim, out_dim, bias=bias) + + def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, _ = hidden_states.size() + + # get query proj + query_states = hidden_states * self.scaling + # get key, value proj + key_states = self._qk_shape(key_states, -1, batch_size) + value_states = self._v_shape(value_states, -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim) + query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*v_proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( + attention_mask, -torch.inf + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, self.out_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->ConditionalDetrEncoderLayer,DetrConfig->ConditionalDetrConfig +class ConditionalDetrEncoderLayer(nn.Module): + def __init__(self, config: ConditionalDetrConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + object_queries: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class ConditionalDetrDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ConditionalDetrConfig): + super().__init__() + self.embed_dim = config.d_model + + d_model = config.d_model + # Decoder Self-Attention projections + self.sa_qcontent_proj = nn.Linear(d_model, d_model) + self.sa_qpos_proj = nn.Linear(d_model, d_model) + self.sa_kcontent_proj = nn.Linear(d_model, d_model) + self.sa_kpos_proj = nn.Linear(d_model, d_model) + self.sa_v_proj = nn.Linear(d_model, d_model) + + self.self_attn = ConditionalDetrAttention( + embed_dim=self.embed_dim, + out_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + # Decoder Cross-Attention projections + self.ca_qcontent_proj = nn.Linear(d_model, d_model) + self.ca_qpos_proj = nn.Linear(d_model, d_model) + self.ca_kcontent_proj = nn.Linear(d_model, d_model) + self.ca_kpos_proj = nn.Linear(d_model, d_model) + self.ca_v_proj = nn.Linear(d_model, d_model) + self.ca_qpos_sine_proj = nn.Linear(d_model, d_model) + + self.encoder_attn = ConditionalDetrAttention( + self.embed_dim * 2, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.nhead = config.decoder_attention_heads + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + query_sine_embed: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + is_first: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, 
source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # ========== Begin of Self-Attention ============= + # Apply projections here + # shape: num_queries x batch_size x 256 + q_content = self.sa_qcontent_proj( + hidden_states + ) # target is the input of the first decoder layer. zero by default. + q_pos = self.sa_qpos_proj(query_position_embeddings) + k_content = self.sa_kcontent_proj(hidden_states) + k_pos = self.sa_kpos_proj(query_position_embeddings) + v = self.sa_v_proj(hidden_states) + + _, num_queries, n_model = q_content.shape + + q = q_content + q_pos + k = k_content + k_pos + hidden_states, self_attn_weights = self.self_attn( + hidden_states=q, + attention_mask=attention_mask, + key_states=k, + value_states=v, + output_attentions=output_attentions, + ) + # ============ End of Self-Attention ============= + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # ========== Begin of Cross-Attention ============= + # Apply projections here + # shape: num_queries x batch_size x 256 + q_content = self.ca_qcontent_proj(hidden_states) + k_content = self.ca_kcontent_proj(encoder_hidden_states) + v = self.ca_v_proj(encoder_hidden_states) + + batch_size, num_queries, n_model = q_content.shape + _, source_len, _ = k_content.shape + + k_pos = self.ca_kpos_proj(object_queries) + + # For the first decoder layer, we concatenate the positional embedding predicted from + # the object query (the positional embedding) into the original query (key) in DETR. 
+ if is_first: + q_pos = self.ca_qpos_proj(query_position_embeddings) + q = q_content + q_pos + k = k_content + k_pos + else: + q = q_content + k = k_content + + q = q.view(batch_size, num_queries, self.nhead, n_model // self.nhead) + query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) + query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead) + q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2) + k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead) + k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead) + k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=q, + attention_mask=encoder_attention_mask, + key_states=k, + value_states=v, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # ============ End of Cross-Attention ============= + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->MLP +class MLP(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. 
+ + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@auto_docstring +# Copied from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->ConditionalDetr +class ConditionalDetrPreTrainedModel(PreTrainedModel): + config: ConditionalDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + xavier_std = self.config.init_xavier_std + + if isinstance(module, ConditionalDetrMHAttentionMap): + nn.init.zeros_(module.k_linear.bias) + nn.init.zeros_(module.q_linear.bias) + nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + elif isinstance(module, ConditionalDetrLearnedPositionEmbedding): + nn.init.uniform_(module.row_embeddings.weight) + nn.init.uniform_(module.column_embeddings.weight) + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +# Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR +class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`ConditionalDetrEncoderLayer`]. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweak for ConditionalDETR: + + - object_queries are added to the forward pass. + + Args: + config: ConditionalDetrConfig + """ + + def __init__(self, config: ConditionalDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([ConditionalDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # in the original ConditionalDETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + object_queries=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. 
Mask values selected in `[0, 1]`: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Object queries that are added to the queries in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + # we add object_queries as extra input to the encoder_layer + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`ConditionalDetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for Conditional DETR: + + - object_queries and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. 
+ + Args: + config: ConditionalDetrConfig + """ + + def __init__(self, config: ConditionalDetrConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([ConditionalDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in Conditional DETR, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + d_model = config.d_model + self.gradient_checkpointing = False + + # query_scale is the FFN applied on f to generate transformation T + self.query_scale = MLP(d_model, d_model, d_model, 2) + self.ref_point_head = MLP(d_model, d_model, 2, 2) + for layer_id in range(config.decoder_layers - 1): + self.layers[layer_id + 1].ca_qpos_proj = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + object_queries=None, + query_position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # optional intermediate hidden states + intermediate = () if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + reference_points_before_sigmoid = self.ref_point_head( + query_position_embeddings + ) # [num_queries, batch_size, 2] + reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1) + obj_center = reference_points[..., :2].transpose(0, 1) + # get sine embedding for the query vector + query_sine_embed_before_transformation = gen_sine_position_embeddings(obj_center, self.config.d_model) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + if idx == 0: + pos_transformation = 1 + else: + pos_transformation = self.query_scale(hidden_states) + # apply transformation + query_sine_embed = query_sine_embed_before_transformation * pos_transformation + + layer_outputs = decoder_layer( + hidden_states, + None, # attention_mask + object_queries, + query_position_embeddings, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + is_first=(idx == 0), + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + hidden_states = self.layernorm(hidden_states) + intermediate += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attns, + all_cross_attentions, + intermediate, + reference_points, + ] + if v is not None + ) + return ConditionalDetrDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + reference_points=reference_points, + ) + + +@auto_docstring( + custom_intro=""" + The bare Conditional DETR Model (consisting of a backbone 
and encoder-decoder Transformer) outputting raw + hidden-states without any specific head on top. + """ +) +class ConditionalDetrModel(ConditionalDetrPreTrainedModel): + def __init__(self, config: ConditionalDetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = ConditionalDetrConvEncoder(config) + object_queries = build_position_encoding(config) + self.backbone = ConditionalDetrConvModel(backbone, object_queries) + + # Create projection layer + self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1) + + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = ConditionalDetrEncoder(config) + self.decoder = ConditionalDetrDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], ConditionalDetrModelOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50") + >>> model = AutoModel.from_pretrained("microsoft/conditional-detr-resnet-50") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> # the last hidden states are the final query embeddings of the Transformer decoder + >>> # these are of shape (batch_size, num_queries, hidden_size) + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + # First, sent pixel_values + pixel_mask through Backbone to obtain the features + # pixel_values should be of shape (batch_size, num_channels, height, width) + # pixel_mask should be of shape (batch_size, height, width) + features, object_queries_list = self.backbone(pixel_values, pixel_mask) + + # get final feature map and downsampled mask + feature_map, mask = features[-1] + + if mask is None: + raise ValueError("Backbone does not return downsampled pixel mask") + + # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + projected_feature_map = self.input_projection(feature_map) + + # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC + # In other words, turn their shape into (batch_size, sequence_length, hidden_size) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) + object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) + + flattened_mask = mask.flatten(1) + + # Fourth, sent flattened_features + flattened_mask + object_queries through encoder + # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size) + # flattened_mask is a Tensor of shape (batch_size, height*width) + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=flattened_features, + attention_mask=flattened_mask, + object_queries=object_queries, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + 
queries = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return ConditionalDetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + reference_points=decoder_outputs.reference_points, + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr +class ConditionalDetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@auto_docstring( + custom_intro=""" + Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """ +) +class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel): + def __init__(self, config: ConditionalDetrConfig): + super().__init__(config) + + # CONDITIONAL DETR encoder-decoder model + self.model = ConditionalDetrModel(config) + + # Object detection heads + self.class_labels_classifier = nn.Linear( + config.d_model, config.num_labels + ) # We add one for the "no object" class + self.bbox_predictor = ConditionalDetrMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + + # Initialize weights and apply final processing + self.post_init() + + # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
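+        # e.g. with a 6-layer decoder, outputs_class / outputs_coord are stacked per layer and
+        # the [:-1] slice below keeps the five intermediate layers' predictions; the final
+        # layer's predictions are returned separately as `logits` / `pred_boxes`.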
+        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[list[dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.FloatTensor], ConditionalDetrObjectDetectionOutput]:
+        r"""
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
+            Not used by default. Can be used to mask object queries.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+            can choose to directly pass a flattened representation of an image.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+            embedded representation.
+        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
+        >>> from PIL import Image
+        >>> import requests
+        >>> import torch
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
+        >>> model = AutoModelForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+
+        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+        >>> target_sizes = torch.tensor([image.size[::-1]])
+        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+        ...     0
+        ... ]
+        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
+        Detected remote with confidence 0.833 at location [38.31, 72.1, 177.63, 118.45]
+        Detected cat with confidence 0.831 at location [9.2, 51.38, 321.13, 469.0]
+        Detected cat with confidence 0.804 at location [340.3, 16.85, 642.93, 370.95]
+        Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
+        Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # First, send images through the CONDITIONAL_DETR base model to obtain encoder + decoder outputs
+        outputs = self.model(
+            pixel_values,
+            pixel_mask=pixel_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # class logits + predicted bounding boxes
+        logits = self.class_labels_classifier(sequence_output)
+
+        # Index [-2] is valid only if `output_attentions` and `output_hidden_states`
+        # are not specified, otherwise it will be another index which is hard to determine.
+        # Leave it as is, because it's not a common case to use
+        # return_dict=False + output_attentions=True / output_hidden_states=True
+        reference = outputs.reference_points if return_dict else outputs[-2]
+        reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
+
+        hs = sequence_output
+        tmp = self.bbox_predictor(hs)
+        tmp[..., :2] += reference_before_sigmoid
+        pred_boxes = tmp.sigmoid()
+        # pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
+
+        loss, loss_dict, auxiliary_outputs = None, None, None
+        if labels is not None:
+            outputs_class, outputs_coord = None, None
+            if self.config.auxiliary_loss:
+                outputs_coords = []
+                intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
+                outputs_class = self.class_labels_classifier(intermediate)
+                for lvl in range(intermediate.shape[0]):
+                    tmp = self.bbox_predictor(intermediate[lvl])
+                    tmp[..., :2] += reference_before_sigmoid
+                    outputs_coord = tmp.sigmoid()
+                    outputs_coords.append(outputs_coord)
+                outputs_coord = torch.stack(outputs_coords)
+            loss, loss_dict, auxiliary_outputs = self.loss_function(
+                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
+            )
+
+        if not return_dict:
+            if auxiliary_outputs is not None:
+                output = (logits, pred_boxes) + auxiliary_outputs + outputs
+            else:
+                output = (logits, pred_boxes) + outputs
+            return ((loss, loss_dict) + output) if loss is not None else output
+
+        return ConditionalDetrObjectDetectionOutput(
+            loss=loss,
+            loss_dict=loss_dict,
+            logits=logits,
+            pred_boxes=pred_boxes,
+            auxiliary_outputs=auxiliary_outputs,
+            last_hidden_state=outputs.last_hidden_state,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top,
+    for tasks such as COCO panoptic.
+ """ +) +class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel): + def __init__(self, config: ConditionalDetrConfig): + super().__init__(config) + + # object detection model + self.conditional_detr = ConditionalDetrForObjectDetection(config) + + # segmentation head + hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads + intermediate_channel_sizes = self.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes + + self.mask_head = ConditionalDetrMaskHeadSmallConv( + hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size + ) + + self.bbox_attention = ConditionalDetrMHAttentionMap( + hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], ConditionalDetrSegmentationOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each + dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels, + bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves + should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a + `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a + `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`. + + Examples: + + ```python + >>> import io + >>> import requests + >>> from PIL import Image + >>> import torch + >>> import numpy + + >>> from transformers import ( + ... AutoImageProcessor, + ... ConditionalDetrConfig, + ... ConditionalDetrForSegmentation, + ... 
)
+        >>> from transformers.image_transforms import rgb_to_id
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
+
+        >>> # randomly initialize all weights of the model
+        >>> config = ConditionalDetrConfig()
+        >>> model = ConditionalDetrForSegmentation(config)
+
+        >>> # prepare image for the model
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+
+        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
+        >>> # Segmentation results are returned as a list of dictionaries
+        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
+        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
+        >>> panoptic_seg = result[0]["segmentation"]
+        >>> # Get prediction score and segment_id to class_id mapping of each segment
+        >>> panoptic_segments_info = result[0]["segments_info"]
+        ```"""
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, num_channels, height, width = pixel_values.shape
+        device = pixel_values.device
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones((batch_size, height, width), device=device)
+
+        # First, get list of feature maps and object_queries
+        features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
+
+        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
+        feature_map, mask = features[-1]
+        batch_size, num_channels, height, width = feature_map.shape
+        projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
+
+        # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute them to NxHWxC
+        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
+        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
+
+        flattened_mask = mask.flatten(1)
+
+        # Fourth, send flattened_features + flattened_mask + object_queries through the encoder
+        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
+        # flattened_mask is a Tensor of shape (batch_size, height*width)
+        if encoder_outputs is None:
+            encoder_outputs = self.conditional_detr.model.encoder(
+                inputs_embeds=flattened_features,
+                attention_mask=flattened_mask,
+                object_queries=object_queries,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
+        query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
+            batch_size, 1, 1
+        )
+        queries =
torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.conditional_detr.model.decoder( + inputs_embeds=queries, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Sixth, compute logits, pred_boxes and pred_masks + logits = self.conditional_detr.class_labels_classifier(sequence_output) + pred_boxes = self.conditional_detr.bbox_predictor(sequence_output).sigmoid() + + memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width) + mask = flattened_mask.view(batch_size, height, width) + + # FIXME h_boxes takes the last one computed, keep this in mind + # important: we need to reverse the mask, since in the original implementation the mask works reversed + # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32) + bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask) + + seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]]) + + pred_masks = seg_masks.view( + batch_size, self.conditional_detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1] + ) + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1] + outputs_class = self.conditional_detr.class_labels_classifier(intermediate) + outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid() + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, self.device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs + else: + output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return ConditionalDetrSegmentationOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + pred_masks=pred_masks, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +def _expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + +# Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr +class ConditionalDetrMaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. 
Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + if dim % 8 != 0: + raise ValueError( + "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in" + " GroupNorm is set to 8" + ) + + inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] + + self.lay1 = nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = nn.GroupNorm(8, dim) + self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1]) + self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2]) + self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3]) + self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4]) + self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): + # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with + # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32). + # We expand the projected feature map to match the number of heads. + x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = nn.functional.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = nn.functional.relu(x) + + x = self.out_lay(x) + return x + + +# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->ConditionalDetr +class ConditionalDetrMHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: 
Optional[Tensor] = None):
+        q = self.q_linear(q)
+        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
+        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
+        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
+        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
+
+        if mask is not None:
+            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
+        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
+        weights = self.dropout(weights)
+        return weights
+
+
+__all__ = [
+    "ConditionalDetrForObjectDetection",
+    "ConditionalDetrForSegmentation",
+    "ConditionalDetrModel",
+    "ConditionalDetrPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modular_conditional_detr.py b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modular_conditional_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0faf2c4b9ed19352bf28755206a82fbcd9a701
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/conditional_detr/modular_conditional_detr.py
@@ -0,0 +1,134 @@
+from typing import Union
+
+import torch
+
+from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast
+
+from ...image_transforms import (
+    center_to_corners_format,
+)
+from ...utils import (
+    TensorType,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConditionalDetrImageProcessorFast(DetrImageProcessorFast):
+    def post_process(self, outputs, target_sizes):
+        """
+        Converts the output of [`ConditionalDetrForObjectDetection`] into the Pascal VOC box format (xmin, ymin, xmax,
+        ymax). Only supports PyTorch.
+
+        Args:
+            outputs ([`ConditionalDetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+                image size (before any data augmentation). For visualization, this should be the image size after data
+                augmentation, but before padding.
+        Returns:
+            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
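+
+        Note: this method is deprecated and slated for removal in Transformers v5; prefer
+        `post_process_object_detection` with `threshold=0.` for equivalent results.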
+        """
+        logger.warning_once(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+        )
+
+        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+        if len(out_logits) != len(target_sizes):
+            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+        if target_sizes.shape[1] != 2:
+            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+        prob = out_logits.sigmoid()
+        # 300 = the default number of object queries for this model
+        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
+        scores = topk_values
+        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+        labels = topk_indexes % out_logits.shape[2]
+        boxes = center_to_corners_format(out_bbox)
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+        # and from relative [0, 1] to absolute [0, height] coordinates
+        img_h, img_w = target_sizes.unbind(1)
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+        boxes = boxes * scale_fct[:, None, :]
+
+        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+        return results
+
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100
+    ):
+        """
+        Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+        top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+        Args:
+            outputs ([`ConditionalDetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+                (height, width) of each image in the batch. If left to None, predictions will not be resized.
+            top_k (`int`, *optional*, defaults to 100):
+                Keep only top k bounding boxes before filtering by thresholding.
+
+        Returns:
+            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
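+
+        Example (a minimal sketch, reusing the checkpoint shown in the modeling docstrings):
+
+        ```python
+        >>> import requests
+        >>> from PIL import Image
+        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
+        >>> model = AutoModelForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
+
+        >>> outputs = model(**image_processor(images=image, return_tensors="pt"))
+        >>> results = image_processor.post_process_object_detection(
+        ...     outputs, threshold=0.5, target_sizes=[image.size[::-1]]
+        ... )
+        >>> sorted(results[0].keys())
+        ['boxes', 'labels', 'scores']
+        ```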
+ """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = out_logits.sigmoid() + prob = prob.view(out_logits.shape[0], -1) + k_value = min(top_k, prob.size(1)) + topk_values, topk_indexes = torch.topk(prob, k_value, dim=1) + scores = topk_values + topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor") + labels = topk_indexes % out_logits.shape[2] + boxes = center_to_corners_format(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + def post_process_segmentation(self): + raise NotImplementedError("Segmentation post-processing is not implemented for Conditional DETR yet.") + + def post_process_instance(self): + raise NotImplementedError("Instance post-processing is not implemented for Conditional DETR yet.") + + def post_process_panoptic(self): + raise NotImplementedError("Panoptic post-processing is not implemented for Conditional DETR yet.") + + +__all__ = ["ConditionalDetrImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d826745f5b2e011997179ac0dd3d3cfc14389d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
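The `__init__.py` below follows transformers' lazy-import pattern: names are visible to type checkers eagerly, while at runtime the module is swapped for a `_LazyModule` that loads submodules on first attribute access. A simplified sketch of the underlying idea using PEP 562 module `__getattr__` (an illustrative stand-in, not the actual `_LazyModule` implementation):

```python
# simplified, hypothetical sketch of lazy exports via PEP 562 __getattr__
import importlib

_EXPORTS = {"ConvNextConfig": ".configuration_convnext"}  # public name -> submodule


def __getattr__(name):
    # the submodule is imported only on first access, keeping package import cheap
    if name in _EXPORTS:
        return getattr(importlib.import_module(_EXPORTS[name], __package__), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```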
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_convnext import *
+    from .feature_extraction_convnext import *
+    from .image_processing_convnext import *
+    from .image_processing_convnext_fast import *
+    from .modeling_convnext import *
+    from .modeling_tf_convnext import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/configuration_convnext.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/configuration_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..f54cba58cf296e8c8e3bae70a9f2e2ab21e3c660
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/configuration_convnext.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate a
+    ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the ConvNeXT
+    [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        patch_size (`int`, *optional*, defaults to 4):
+            Patch size to use in the patch embedding layer.
+        num_stages (`int`, *optional*, defaults to 4):
+            The number of stages in the model.
+        hidden_sizes (`list[int]`, *optional*, defaults to [96, 192, 384, 768]):
+            Dimensionality (hidden size) at each stage.
+        depths (`list[int]`, *optional*, defaults to [3, 3, 9, 3]):
+            Depth (number of blocks) for each stage.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + layer_scale_init_value (`float`, *optional*, defaults to 1e-6): + The initial value for the layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The drop rate for stochastic depth. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + ```python + >>> from transformers import ConvNextConfig, ConvNextModel + + >>> # Initializing a ConvNext convnext-tiny-224 style configuration + >>> configuration = ConvNextConfig() + + >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration + >>> model = ConvNextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "convnext" + + def __init__( + self, + num_channels=3, + patch_size=4, + num_stages=4, + hidden_sizes=None, + depths=None, + hidden_act="gelu", + initializer_range=0.02, + layer_norm_eps=1e-12, + layer_scale_init_value=1e-6, + drop_path_rate=0.0, + image_size=224, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.patch_size = patch_size + self.num_stages = num_stages + self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes + self.depths = [3, 3, 9, 3] if depths is None else depths + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.image_size = image_size + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class ConvNextOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + +__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/feature_extraction_convnext.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/feature_extraction_convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbb5184cf37cb0aedcafeab3a6b363ab047d9a0 --- 
/dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/feature_extraction_convnext.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for ConvNeXT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_convnext import ConvNextImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ConvNextFeatureExtractor(ConvNextImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use ConvNextImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["ConvNextFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..af89274500ddb33160e92b5591bc9ac83c55f24c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext.py @@ -0,0 +1,325 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for ConvNeXT.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + center_crop, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ConvNextImageProcessor(BaseImageProcessor): + r""" + Constructs a ConvNeXT image processor. 
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
+            by `do_resize` in the `preprocess` method.
+        size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
+            Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+            resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+            be matched to `int(size["shortest_edge"] / crop_pct)`, after which the image is cropped to
+            `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+            be overridden by `size` in the `preprocess` method.
+        crop_pct (`float`, *optional*, defaults to 224 / 256):
+            Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+            overridden by `crop_pct` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
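+
+    Example (a minimal sketch with default settings):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import ConvNextImageProcessor
+
+    >>> image_processor = ConvNextImageProcessor()  # size defaults to {"shortest_edge": 384}
+    >>> image = np.zeros((256, 256, 3), dtype=np.uint8)
+    >>> inputs = image_processor(images=image, return_tensors="np")
+    >>> inputs["pixel_values"].shape
+    (1, 3, 384, 384)
+    ```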
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + crop_pct: Optional[float] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 384} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + # Default value set here for backwards compatibility where the value in config is None + self.crop_pct = crop_pct if crop_pct is not None else 224 / 256 + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + crop_pct: float, + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If + `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`. + Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`, + after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`. + crop_pct (`float`): + Percentage of the image to crop. Only has an effect if size < 384. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input + image. + """ + size = get_size_dict(size, default_to_square=False) + if "shortest_edge" not in size: + raise ValueError(f"Size dictionary must contain 'shortest_edge' key. 
Got {size.keys()}") + shortest_edge = size["shortest_edge"] + + if shortest_edge < 384: + # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct + resize_shortest_edge = int(shortest_edge / crop_pct) + resize_size = get_resize_output_image_size( + image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format + ) + image = resize( + image=image, + size=resize_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + # then crop to (shortest_edge, shortest_edge) + return center_crop( + image=image, + size=(shortest_edge, shortest_edge), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + else: + # warping (no cropping) when evaluated at 384 or larger + return resize( + image, + size=(shortest_edge, shortest_edge), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + crop_pct: Optional[float] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop if size < 384. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. 
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + crop_pct = crop_pct if crop_pct is not None else self.crop_pct + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
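+            # (the layout inferred from images[0], channels-first or channels-last,
+            # is applied to every image in the batch)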
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize( + image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format + ) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["ConvNextImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext_fast.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab00c0fd091369715e636bac663fbbbdc9239a0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/image_processing_convnext_fast.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for ConvNeXT.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_transforms import get_resize_output_image_size +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) + + +class ConvNextFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + crop_pct (`float`, *optional*): + Percentage of the image to crop. Only has an effect if size < 384. Can be + overridden by `crop_pct` in the`preprocess` method. 
+ """ + + crop_pct: Optional[float] + + +@auto_docstring +class ConvNextImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"shortest_edge": 384} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + crop_pct = 224 / 256 + valid_kwargs = ConvNextFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def resize( + self, + image: "torch.Tensor", + size: dict[str, int], + crop_pct: float, + interpolation: PILImageResampling = PILImageResampling.BICUBIC, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`dict[str, int]`): + Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If + `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`. + Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`, + after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`. + crop_pct (`float`): + Percentage of the image to crop. Only has an effect if size < 384. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resizing the image. + + Returns: + `torch.Tensor`: Resized image. + """ + if not size.shortest_edge: + raise ValueError(f"Size dictionary must contain 'shortest_edge' key. 
Got {size.keys()}") + shortest_edge = size["shortest_edge"] + + if shortest_edge < 384: + # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct + resize_shortest_edge = int(shortest_edge / crop_pct) + resize_size = get_resize_output_image_size( + image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST + ) + image = F.resize( + image, + resize_size, + interpolation=interpolation, + **kwargs, + ) + # then crop to (shortest_edge, shortest_edge) + return F.center_crop( + image, + (shortest_edge, shortest_edge), + **kwargs, + ) + else: + # warping (no cropping) when evaluated at 384 or larger + return F.resize( + image, + (shortest_edge, shortest_edge), + interpolation=interpolation, + **kwargs, + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: dict[str, int], + crop_pct: float, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: int, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["ConvNextImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_convnext.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..3120c140d2ed41c16e659a780c7d4603dcfda7b7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_convnext.py @@ -0,0 +1,424 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ConvNext model.""" + +from typing import Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.backbone_utils import BackboneMixin +from ...utils.generic import can_return_tuple +from .configuration_convnext import ConvNextConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext +class ConvNextDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class ConvNextLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
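+
+ Illustrative example (a minimal sketch with random inputs):
+
+ ```python
+ >>> import torch
+ >>> from transformers.models.convnext.modeling_convnext import ConvNextLayerNorm
+
+ >>> layer_norm = ConvNextLayerNorm(64, data_format="channels_first")
+ >>> features = torch.randn(2, 64, 56, 56)  # (batch_size, channels, height, width)
+ >>> layer_norm(features).shape  # normalization runs over the 64 channels, shape is unchanged
+ torch.Size([2, 64, 56, 56])
+ ```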
+ """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class ConvNextEmbeddings(nn.Module): + """This class is comparable to (and inspired by) the SwinEmbeddings class + found in src/transformers/models/swin/modeling_swin.py. + """ + + def __init__(self, config): + super().__init__() + self.patch_embeddings = nn.Conv2d( + config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size + ) + self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first") + self.num_channels = config.num_channels + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.layernorm(embeddings) + return embeddings + + +class ConvNextLayer(nn.Module): + """This corresponds to the `Block` class in the original implementation. + + There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C, + H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back + + The authors used (2) as they find it slightly faster in PyTorch. + + Args: + config ([`ConvNextConfig`]): Model configuration class. + dim (`int`): Number of input channels. + drop_path (`float`): Stochastic depth rate. Default: 0.0. 
+ """ + + def __init__(self, config, dim, drop_path=0): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.layernorm = ConvNextLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = ACT2FN[config.hidden_act] + self.pwconv2 = nn.Linear(4 * dim, dim) + self.layer_scale_parameter = ( + nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True) + if config.layer_scale_init_value > 0 + else None + ) + self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + features = self.dwconv(features) + features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + features = self.layernorm(features) + features = self.pwconv1(features) + features = self.act(features) + features = self.pwconv2(features) + if self.layer_scale_parameter is not None: + features = self.layer_scale_parameter * features + features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + features = residual + self.drop_path(features) + return features + + +class ConvNextStage(nn.Module): + """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks. + + Args: + config ([`ConvNextConfig`]): Model configuration class. + in_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + depth (`int`): Number of residual blocks. + drop_path_rates(`list[float]`): Stochastic depth rates for each layer. + """ + + def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None): + super().__init__() + + if in_channels != out_channels or stride > 1: + self.downsampling_layer = nn.ModuleList( + [ + ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), + ] + ) + else: + self.downsampling_layer = nn.ModuleList() + drop_path_rates = drop_path_rates or [0.0] * depth + self.layers = nn.ModuleList( + [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)] + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + for layer in self.downsampling_layer: + features = layer(features) + for layer in self.layers: + features = layer(features) + return features + + +class ConvNextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.stages = nn.ModuleList() + drop_path_rates = [ + x.tolist() + for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths) + ] + prev_chs = config.hidden_sizes[0] + for i in range(config.num_stages): + out_chs = config.hidden_sizes[i] + stage = ConvNextStage( + config, + in_channels=prev_chs, + out_channels=out_chs, + stride=2 if i > 0 else 1, + depth=config.depths[i], + drop_path_rates=drop_path_rates[i], + ) + self.stages.append(stage) + prev_chs = out_chs + + def forward( + self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> BaseModelOutputWithNoAttention: + all_hidden_states = [hidden_states] if output_hidden_states else None + + for layer_module in self.stages: + hidden_states = layer_module(hidden_states) + if all_hidden_states is not None: + all_hidden_states.append(hidden_states) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, 
hidden_states=all_hidden_states) + + +@auto_docstring +class ConvNextPreTrainedModel(PreTrainedModel): + config: ConvNextConfig + base_model_prefix = "convnext" + main_input_name = "pixel_values" + _no_split_modules = ["ConvNextLayer"] + _can_record_outputs = {} # hidden states are collected explicitly + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ConvNextLayer): + if module.layer_scale_parameter is not None: + module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value) + + +@auto_docstring +class ConvNextModel(ConvNextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ConvNextEmbeddings(config) + self.encoder = ConvNextEncoder(config) + + # final layernorm layer + self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + encoder_outputs: BaseModelOutputWithNoAttention = self.encoder( + embedding_output, output_hidden_states=output_hidden_states + ) + last_hidden_state = encoder_outputs.last_hidden_state + + # global average pooling, (N, C, H, W) -> (N, C) + pooled_output = self.layernorm(last_hidden_state.mean([-2, -1])) + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class ConvNextForImageClassification(ConvNextPreTrainedModel): + accepts_loss_kwargs = False + + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.convnext = ConvNextModel(config) + + # Classifier head + if config.num_labels > 0: + self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels) + else: + self.classifier = nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
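+
+ Example (illustrative; reuses the checkpoint and image from the backbone example below):
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, ConvNextForImageClassification
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = ConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = image_processor(image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits
+ >>> print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
+ ```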
+ """ + outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnext(pixel_values, **kwargs) + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config) + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin): + has_attentions = False + + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.embeddings = ConvNextEmbeddings(config) + self.encoder = ConvNextEncoder(config) + self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self._out_features, self.channels): + hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first") + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224") + >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224") + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + embedding_output = self.embeddings(pixel_values) + outputs: BaseModelOutputWithPoolingAndNoAttention = self.encoder(embedding_output, output_hidden_states=True) + hidden_states = outputs.hidden_states + + feature_maps = [] + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + hidden_state = self.hidden_states_norms[stage](hidden_state) + feature_maps.append(hidden_state) + + return BackboneOutput( + feature_maps=tuple(feature_maps), + hidden_states=hidden_states if output_hidden_states else None, + ) + + +__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_tf_convnext.py b/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_tf_convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..7306877466d9b793682d01d90645f08420930b59 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/convnext/modeling_tf_convnext.py @@ -0,0 +1,667 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 ConvNext model.""" + +from __future__ import annotations + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_convnext import ConvNextConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "ConvNextConfig" +_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224" + + +class TFConvNextDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path: float, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x: tf.Tensor, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFConvNextEmbeddings(keras.layers.Layer): + """This class is comparable to (and inspired by) the SwinEmbeddings class + found in src/transformers/models/swin/modeling_swin.py. + """ + + def __init__(self, config: ConvNextConfig, **kwargs): + super().__init__(**kwargs) + self.patch_embeddings = keras.layers.Conv2D( + filters=config.hidden_sizes[0], + kernel_size=config.patch_size, + strides=config.patch_size, + name="patch_embeddings", + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer=keras.initializers.Zeros(), + ) + self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm") + self.num_channels = config.num_channels + self.config = config + + def call(self, pixel_values): + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + tf.debugging.assert_equal( + shape_list(pixel_values)[1], + self.num_channels, + message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", + ) + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. 
+ # shape = (batch_size, in_height, in_width, in_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.layernorm(embeddings) + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "patch_embeddings", None) is not None: + with tf.name_scope(self.patch_embeddings.name): + self.patch_embeddings.build([None, None, None, self.config.num_channels]) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, None, self.config.hidden_sizes[0]]) + + +class TFConvNextLayer(keras.layers.Layer): + """This corresponds to the `Block` class in the original implementation. + + There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C, + H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back + + The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow + NHWC ordering, we can just apply the operations straight-away without the permutation. + + Args: + config ([`ConvNextConfig`]): Model configuration class. + dim (`int`): Number of input channels. + drop_path (`float`): Stochastic depth rate. Default: 0.0. + """ + + def __init__(self, config, dim, drop_path=0.0, **kwargs): + super().__init__(**kwargs) + self.dim = dim + self.config = config + self.dwconv = keras.layers.Conv2D( + filters=dim, + kernel_size=7, + padding="same", + groups=dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="dwconv", + ) # depthwise conv + self.layernorm = keras.layers.LayerNormalization( + epsilon=1e-6, + name="layernorm", + ) + self.pwconv1 = keras.layers.Dense( + units=4 * dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="pwconv1", + ) # pointwise/1x1 convs, implemented with linear layers + self.act = get_tf_activation(config.hidden_act) + self.pwconv2 = keras.layers.Dense( + units=dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="pwconv2", + ) + # Using `layers.Activation` instead of `tf.identity` to better control `training` + # behaviour. 
+ self.drop_path = ( + TFConvNextDropPath(drop_path, name="drop_path") + if drop_path > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + + def build(self, input_shape: tf.TensorShape = None): + # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa) + self.layer_scale_parameter = ( + self.add_weight( + shape=(self.dim,), + initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_parameter", + ) + if self.config.layer_scale_init_value > 0 + else None + ) + + if self.built: + return + self.built = True + if getattr(self, "dwconv", None) is not None: + with tf.name_scope(self.dwconv.name): + self.dwconv.build([None, None, None, self.dim]) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, None, self.dim]) + if getattr(self, "pwconv1", None) is not None: + with tf.name_scope(self.pwconv1.name): + self.pwconv1.build([None, None, self.dim]) + if getattr(self, "pwconv2", None) is not None: + with tf.name_scope(self.pwconv2.name): + self.pwconv2.build([None, None, 4 * self.dim]) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + + def call(self, hidden_states, training=False): + input = hidden_states + x = self.dwconv(hidden_states) + x = self.layernorm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.layer_scale_parameter is not None: + x = self.layer_scale_parameter * x + + x = input + self.drop_path(x, training=training) + return x + + +class TFConvNextStage(keras.layers.Layer): + """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks. + + Args: + config (`ConvNextV2Config`): + Model configuration class. + in_channels (`int`): + Number of input channels. + out_channels (`int`): + Number of output channels. + depth (`int`): + Number of residual blocks. + drop_path_rates(`list[float]`): + Stochastic depth rates for each layer. + """ + + def __init__( + self, + config: ConvNextConfig, + in_channels: int, + out_channels: int, + kernel_size: int = 2, + stride: int = 2, + depth: int = 2, + drop_path_rates: list[float] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + if in_channels != out_channels or stride > 1: + self.downsampling_layer = [ + keras.layers.LayerNormalization( + epsilon=1e-6, + name="downsampling_layer.0", + ), + # Inputs to this layer will follow NHWC format since we + # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings` + # layer. All the outputs throughout the model will be in NHWC + # from this point on until the output where we again change to + # NCHW. 
+ keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer=keras.initializers.Zeros(), + name="downsampling_layer.1", + ), + ] + else: + self.downsampling_layer = [tf.identity] + + drop_path_rates = drop_path_rates or [0.0] * depth + self.layers = [ + TFConvNextLayer( + config, + dim=out_channels, + drop_path=drop_path_rates[j], + name=f"layers.{j}", + ) + for j in range(depth) + ] + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + def call(self, hidden_states): + for layer in self.downsampling_layer: + hidden_states = layer(hidden_states) + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + if self.in_channels != self.out_channels or self.stride > 1: + with tf.name_scope(self.downsampling_layer[0].name): + self.downsampling_layer[0].build([None, None, None, self.in_channels]) + with tf.name_scope(self.downsampling_layer[1].name): + self.downsampling_layer[1].build([None, None, None, self.in_channels]) + + +class TFConvNextEncoder(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.stages = [] + drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths)) + drop_path_rates = tf.split(drop_path_rates, config.depths) + drop_path_rates = [x.numpy().tolist() for x in drop_path_rates] + prev_chs = config.hidden_sizes[0] + for i in range(config.num_stages): + out_chs = config.hidden_sizes[i] + stage = TFConvNextStage( + config, + in_channels=prev_chs, + out_channels=out_chs, + stride=2 if i > 0 else 1, + depth=config.depths[i], + drop_path_rates=drop_path_rates[i], + name=f"stages.{i}", + ) + self.stages.append(stage) + prev_chs = out_chs + + def call(self, hidden_states, output_hidden_states=False, return_dict=True): + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.stages): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + def build(self, input_shape=None): + for stage in self.stages: + with tf.name_scope(stage.name): + stage.build(None) + + +@keras_serializable +class TFConvNextMainLayer(keras.layers.Layer): + config_class = ConvNextConfig + + def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embeddings = TFConvNextEmbeddings(config, name="embeddings") + self.encoder = TFConvNextEncoder(config, name="encoder") + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + # We are setting the `data_format` like so because from here on we will revert to the + # NCHW output format + self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None + + @unpack_inputs + def call( + self, + pixel_values: TFModelInputType | None = None, + 
output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + last_hidden_state = encoder_outputs[0] + # Change to NCHW output format have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) + pooled_output = self.layernorm(self.pooler(last_hidden_state)) + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) + + if not return_dict: + hidden_states = hidden_states if output_hidden_states else () + return (last_hidden_state, pooled_output) + hidden_states + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, self.config.hidden_sizes[-1]]) + + +class TFConvNextPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ConvNextConfig + base_model_prefix = "convnext" + main_input_name = "pixel_values" + + +CONVNEXT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! 
If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CONVNEXT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. 
+""" + + +@add_start_docstrings( + "The bare ConvNext model outputting raw features without any specific head on top.", + CONVNEXT_START_DOCSTRING, +) +class TFConvNextModel(TFConvNextPreTrainedModel): + def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext") + + @unpack_inputs + @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: TFModelInputType | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFConvNextModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224") + >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.convnext( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return (outputs[0],) + outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convnext", None) is not None: + with tf.name_scope(self.convnext.name): + self.convnext.build(None) + + +@add_start_docstrings( + """ + ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + CONVNEXT_START_DOCSTRING, +) +class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ConvNextConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.convnext = TFConvNextMainLayer(config, name="convnext") + + # Classifier head + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: TFModelInputType | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224") + >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] + >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.convnext( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convnext", None) is not None: + with tf.name_scope(self.convnext.name): + self.convnext.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, 
self.config.hidden_sizes[-1]]) + + +__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/cvt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/cvt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..756aded9e6ad2964ffa5add89af2c4af56071e88 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/cvt/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_cvt import * + from .modeling_cvt import * + from .modeling_tf_cvt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/cvt/configuration_cvt.py b/venv/lib/python3.13/site-packages/transformers/models/cvt/configuration_cvt.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3f6b339627d868f79e4c53a022fed117166074 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/cvt/configuration_cvt.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CvT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class CvtConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the CvT + [microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + patch_sizes (`list[int]`, *optional*, defaults to `[7, 3, 3]`): + The kernel size of each encoder's patch embedding. 
+ patch_stride (`list[int]`, *optional*, defaults to `[4, 2, 2]`): + The stride size of each encoder's patch embedding. + patch_padding (`list[int]`, *optional*, defaults to `[2, 1, 1]`): + The padding size of each encoder's patch embedding. + embed_dim (`list[int]`, *optional*, defaults to `[64, 192, 384]`): + Dimension of each of the encoder blocks. + num_heads (`list[int]`, *optional*, defaults to `[1, 3, 6]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + depth (`list[int]`, *optional*, defaults to `[1, 2, 10]`): + The number of layers in each encoder block. + mlp_ratios (`list[float]`, *optional*, defaults to `[4.0, 4.0, 4.0, 4.0]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + attention_drop_rate (`list[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`): + The dropout ratio for the attention probabilities. + drop_rate (`list[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`): + The dropout ratio for the patch embeddings probabilities. + drop_path_rate (`list[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + qkv_bias (`list[bool]`, *optional*, defaults to `[True, True, True]`): + The bias bool for query, key and value in attentions + cls_token (`list[bool]`, *optional*, defaults to `[False, False, True]`): + Whether or not to add a classification token to the output of each of the last 3 stages. + qkv_projection_method (`list[string]`, *optional*, defaults to ["dw_bn", "dw_bn", "dw_bn"]`): + The projection method for query, key and value Default is depth-wise convolutions with batch norm. For + Linear projection use "avg". + kernel_qkv (`list[int]`, *optional*, defaults to `[3, 3, 3]`): + The kernel size for query, key and value in attention layer + padding_kv (`list[int]`, *optional*, defaults to `[1, 1, 1]`): + The padding size for key and value in attention layer + stride_kv (`list[int]`, *optional*, defaults to `[2, 2, 2]`): + The stride size for key and value in attention layer + padding_q (`list[int]`, *optional*, defaults to `[1, 1, 1]`): + The padding size for query in attention layer + stride_q (`list[int]`, *optional*, defaults to `[1, 1, 1]`): + The stride size for query in attention layer + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. 
+ + Example: + + ```python + >>> from transformers import CvtConfig, CvtModel + + >>> # Initializing a Cvt msft/cvt style configuration + >>> configuration = CvtConfig() + + >>> # Initializing a model (with random weights) from the msft/cvt style configuration + >>> model = CvtModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "cvt" + + def __init__( + self, + num_channels=3, + patch_sizes=[7, 3, 3], + patch_stride=[4, 2, 2], + patch_padding=[2, 1, 1], + embed_dim=[64, 192, 384], + num_heads=[1, 3, 6], + depth=[1, 2, 10], + mlp_ratio=[4.0, 4.0, 4.0], + attention_drop_rate=[0.0, 0.0, 0.0], + drop_rate=[0.0, 0.0, 0.0], + drop_path_rate=[0.0, 0.0, 0.1], + qkv_bias=[True, True, True], + cls_token=[False, False, True], + qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"], + kernel_qkv=[3, 3, 3], + padding_kv=[1, 1, 1], + stride_kv=[2, 2, 2], + padding_q=[1, 1, 1], + stride_q=[1, 1, 1], + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs, + ): + super().__init__(**kwargs) + self.num_channels = num_channels + self.patch_sizes = patch_sizes + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.embed_dim = embed_dim + self.num_heads = num_heads + self.depth = depth + self.mlp_ratio = mlp_ratio + self.attention_drop_rate = attention_drop_rate + self.drop_rate = drop_rate + self.drop_path_rate = drop_path_rate + self.qkv_bias = qkv_bias + self.cls_token = cls_token + self.qkv_projection_method = qkv_projection_method + self.kernel_qkv = kernel_qkv + self.padding_kv = padding_kv + self.stride_kv = stride_kv + self.padding_q = padding_q + self.stride_q = stride_q + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + +__all__ = ["CvtConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_cvt.py b/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_cvt.py new file mode 100644 index 0000000000000000000000000000000000000000..9d935ee84893e05e925fde0ab2fd80397c7144e5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_cvt.py @@ -0,0 +1,670 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch CvT model.""" + +import collections.abc +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from .configuration_cvt import CvtConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model's outputs, with potential hidden states and attentions. 
+ """ +) +class BaseModelOutputWithCLSToken(ModelOutput): + r""" + cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`): + Classification token at the output of the last layer of the model. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cls_token_value: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class CvtDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class CvtEmbeddings(nn.Module): + """ + Construct the CvT embeddings. + """ + + def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate): + super().__init__() + self.convolution_embeddings = CvtConvEmbeddings( + patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding + ) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, pixel_values): + hidden_state = self.convolution_embeddings(pixel_values) + hidden_state = self.dropout(hidden_state) + return hidden_state + + +class CvtConvEmbeddings(nn.Module): + """ + Image to Conv Embedding. 
+ """ + + def __init__(self, patch_size, num_channels, embed_dim, stride, padding): + super().__init__() + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + self.patch_size = patch_size + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) + self.normalization = nn.LayerNorm(embed_dim) + + def forward(self, pixel_values): + pixel_values = self.projection(pixel_values) + batch_size, num_channels, height, width = pixel_values.shape + hidden_size = height * width + # rearrange "b c h w -> b (h w) c" + pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1) + if self.normalization: + pixel_values = self.normalization(pixel_values) + # rearrange "b (h w) c" -> b c h w" + pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width) + return pixel_values + + +class CvtSelfAttentionConvProjection(nn.Module): + def __init__(self, embed_dim, kernel_size, padding, stride): + super().__init__() + self.convolution = nn.Conv2d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=padding, + stride=stride, + bias=False, + groups=embed_dim, + ) + self.normalization = nn.BatchNorm2d(embed_dim) + + def forward(self, hidden_state): + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class CvtSelfAttentionLinearProjection(nn.Module): + def forward(self, hidden_state): + batch_size, num_channels, height, width = hidden_state.shape + hidden_size = height * width + # rearrange " b c h w -> b (h w) c" + hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1) + return hidden_state + + +class CvtSelfAttentionProjection(nn.Module): + def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"): + super().__init__() + if projection_method == "dw_bn": + self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride) + self.linear_projection = CvtSelfAttentionLinearProjection() + + def forward(self, hidden_state): + hidden_state = self.convolution_projection(hidden_state) + hidden_state = self.linear_projection(hidden_state) + return hidden_state + + +class CvtSelfAttention(nn.Module): + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + with_cls_token=True, + **kwargs, + ): + super().__init__() + self.scale = embed_dim**-0.5 + self.with_cls_token = with_cls_token + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.convolution_projection_query = CvtSelfAttentionProjection( + embed_dim, + kernel_size, + padding_q, + stride_q, + projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method, + ) + self.convolution_projection_key = CvtSelfAttentionProjection( + embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method + ) + self.convolution_projection_value = CvtSelfAttentionProjection( + embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method + ) + + self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + + self.dropout = nn.Dropout(attention_drop_rate) + + def 
rearrange_for_multi_head_attention(self, hidden_state): + batch_size, hidden_size, _ = hidden_state.shape + head_dim = self.embed_dim // self.num_heads + # rearrange 'b t (h d) -> b h t d' + return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3) + + def forward(self, hidden_state, height, width): + if self.with_cls_token: + cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1) + batch_size, hidden_size, num_channels = hidden_state.shape + # rearrange "b (h w) c -> b c h w" + hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width) + + key = self.convolution_projection_key(hidden_state) + query = self.convolution_projection_query(hidden_state) + value = self.convolution_projection_value(hidden_state) + + if self.with_cls_token: + query = torch.cat((cls_token, query), dim=1) + key = torch.cat((cls_token, key), dim=1) + value = torch.cat((cls_token, value), dim=1) + + head_dim = self.embed_dim // self.num_heads + + query = self.rearrange_for_multi_head_attention(self.projection_query(query)) + key = self.rearrange_for_multi_head_attention(self.projection_key(key)) + value = self.rearrange_for_multi_head_attention(self.projection_value(value)) + + attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale + attention_probs = torch.nn.functional.softmax(attention_score, dim=-1) + attention_probs = self.dropout(attention_probs) + + context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value]) + # rearrange"b h t d -> b t (h d)" + _, _, hidden_size, _ = context.shape + context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim) + return context + + +class CvtSelfOutput(nn.Module): + """ + The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
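+
+    The `input_tensor` argument of `forward` is kept for interface compatibility but is not used here.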
+ """ + + def __init__(self, embed_dim, drop_rate): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_rate) + + def forward(self, hidden_state, input_tensor): + hidden_state = self.dense(hidden_state) + hidden_state = self.dropout(hidden_state) + return hidden_state + + +class CvtAttention(nn.Module): + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + with_cls_token=True, + ): + super().__init__() + self.attention = CvtSelfAttention( + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + with_cls_token, + ) + self.output = CvtSelfOutput(embed_dim, drop_rate) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_state, height, width): + self_output = self.attention(hidden_state, height, width) + attention_output = self.output(self_output, hidden_state) + return attention_output + + +class CvtIntermediate(nn.Module): + def __init__(self, embed_dim, mlp_ratio): + super().__init__() + self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) + self.activation = nn.GELU() + + def forward(self, hidden_state): + hidden_state = self.dense(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class CvtOutput(nn.Module): + def __init__(self, embed_dim, mlp_ratio, drop_rate): + super().__init__() + self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim) + self.dropout = nn.Dropout(drop_rate) + + def forward(self, hidden_state, input_tensor): + hidden_state = self.dense(hidden_state) + hidden_state = self.dropout(hidden_state) + hidden_state = hidden_state + input_tensor + return hidden_state + + +class CvtLayer(nn.Module): + """ + CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps). 
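+
+    The block uses a pre-norm layout: `layernorm_before` -> attention -> drop path -> residual; then
+    `layernorm_after` -> mlp -> residual, with drop path applied to the result (see `forward` below).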
+ """ + + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + mlp_ratio, + drop_path_rate, + with_cls_token=True, + ): + super().__init__() + self.attention = CvtAttention( + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + with_cls_token, + ) + + self.intermediate = CvtIntermediate(embed_dim, mlp_ratio) + self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate) + self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_before = nn.LayerNorm(embed_dim) + self.layernorm_after = nn.LayerNorm(embed_dim) + + def forward(self, hidden_state, height, width): + self_attention_output = self.attention( + self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention + height, + width, + ) + attention_output = self_attention_output + attention_output = self.drop_path(attention_output) + + # first residual connection + hidden_state = attention_output + hidden_state + + # in Cvt, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_state) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_state) + layer_output = self.drop_path(layer_output) + return layer_output + + +class CvtStage(nn.Module): + def __init__(self, config, stage): + super().__init__() + self.config = config + self.stage = stage + if self.config.cls_token[self.stage]: + self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1])) + + self.embedding = CvtEmbeddings( + patch_size=config.patch_sizes[self.stage], + stride=config.patch_stride[self.stage], + num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1], + embed_dim=config.embed_dim[self.stage], + padding=config.patch_padding[self.stage], + dropout_rate=config.drop_rate[self.stage], + ) + + drop_path_rates = [ + x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu") + ] + + self.layers = nn.Sequential( + *[ + CvtLayer( + num_heads=config.num_heads[self.stage], + embed_dim=config.embed_dim[self.stage], + kernel_size=config.kernel_qkv[self.stage], + padding_q=config.padding_q[self.stage], + padding_kv=config.padding_kv[self.stage], + stride_kv=config.stride_kv[self.stage], + stride_q=config.stride_q[self.stage], + qkv_projection_method=config.qkv_projection_method[self.stage], + qkv_bias=config.qkv_bias[self.stage], + attention_drop_rate=config.attention_drop_rate[self.stage], + drop_rate=config.drop_rate[self.stage], + drop_path_rate=drop_path_rates[self.stage], + mlp_ratio=config.mlp_ratio[self.stage], + with_cls_token=config.cls_token[self.stage], + ) + for _ in range(config.depth[self.stage]) + ] + ) + + def forward(self, hidden_state): + cls_token = None + hidden_state = self.embedding(hidden_state) + batch_size, num_channels, height, width = hidden_state.shape + # rearrange b c h w -> b (h w) c" + hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1) + if self.config.cls_token[self.stage]: + cls_token = self.cls_token.expand(batch_size, -1, -1) + hidden_state = torch.cat((cls_token, hidden_state), dim=1) + + for layer in self.layers: + layer_outputs = layer(hidden_state, height, 
width) + hidden_state = layer_outputs + + if self.config.cls_token[self.stage]: + cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1) + hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width) + return hidden_state, cls_token + + +class CvtEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.stages = nn.ModuleList([]) + for stage_idx in range(len(config.depth)): + self.stages.append(CvtStage(config, stage_idx)) + + def forward(self, pixel_values, output_hidden_states=False, return_dict=True): + all_hidden_states = () if output_hidden_states else None + hidden_state = pixel_values + + cls_token = None + for _, (stage_module) in enumerate(self.stages): + hidden_state, cls_token = stage_module(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None) + + return BaseModelOutputWithCLSToken( + last_hidden_state=hidden_state, + cls_token_value=cls_token, + hidden_states=all_hidden_states, + ) + + +@auto_docstring +class CvtPreTrainedModel(PreTrainedModel): + config: CvtConfig + base_model_prefix = "cvt" + main_input_name = "pixel_values" + _no_split_modules = ["CvtLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, CvtStage): + if self.config.cls_token[module.stage]: + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data, mean=0.0, std=self.config.initializer_range + ) + + +@auto_docstring +class CvtModel(CvtPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + self.encoder = CvtEncoder(config) + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+        class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithCLSToken]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return BaseModelOutputWithCLSToken(
+            last_hidden_state=sequence_output,
+            cls_token_value=encoder_outputs.cls_token_value,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+    the [CLS] token) e.g. for ImageNet.
+    """
+)
+class CvtForImageClassification(CvtPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.cvt = CvtModel(config, add_pooling_layer=False)
+        self.layernorm = nn.LayerNorm(config.embed_dim[-1])
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
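+
+        Example (an illustrative sketch mirroring the TF example for this model family; assumes the
+        `microsoft/cvt-13` checkpoint):
+
+        ```python
+        >>> from transformers import AutoImageProcessor, CvtForImageClassification
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
+        >>> model = CvtForImageClassification.from_pretrained("microsoft/cvt-13")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        ```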
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.cvt( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + cls_token = outputs[1] + if self.config.cls_token[-1]: + sequence_output = self.layernorm(cls_token) + else: + batch_size, num_channels, height, width = sequence_output.shape + # rearrange "b c h w -> b (h w) c" + sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1) + sequence_output = self.layernorm(sequence_output) + + sequence_output_mean = sequence_output.mean(dim=1) + logits = self.classifier(sequence_output_mean) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["CvtForImageClassification", "CvtModel", "CvtPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_tf_cvt.py b/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_tf_cvt.py new file mode 100644 index 0000000000000000000000000000000000000000..9239e1918eeca7588e0a978af4ef524380cd3f5f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/cvt/modeling_tf_cvt.py @@ -0,0 +1,1095 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 Cvt model.""" + +from __future__ import annotations + +import collections.abc +from dataclasses import dataclass + +import tensorflow as tf + +from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_cvt import CvtConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "CvtConfig" + + +@dataclass +class TFBaseModelOutputWithCLSToken(ModelOutput): + """ + Base class for model's outputs. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`): + Classification token at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + """ + + last_hidden_state: tf.Tensor | None = None + cls_token_value: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + + +class TFCvtDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prob + + def call(self, x: tf.Tensor, training=None): + if self.drop_prob == 0.0 or not training: + return x + keep_prob = 1 - self.drop_prob + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + + +class TFCvtEmbeddings(keras.layers.Layer): + """Construct the Convolutional Token Embeddings.""" + + def __init__( + self, + config: CvtConfig, + patch_size: int, + num_channels: int, + embed_dim: int, + stride: int, + padding: int, + dropout_rate: float, + **kwargs, + ): + super().__init__(**kwargs) + self.convolution_embeddings = TFCvtConvEmbeddings( + config, + patch_size=patch_size, + num_channels=num_channels, + embed_dim=embed_dim, + stride=stride, + padding=padding, + name="convolution_embeddings", + ) + self.dropout = keras.layers.Dropout(dropout_rate) + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution_embeddings(pixel_values) + hidden_state = self.dropout(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution_embeddings", None) is not None: + with tf.name_scope(self.convolution_embeddings.name): + self.convolution_embeddings.build(None) + + +class TFCvtConvEmbeddings(keras.layers.Layer): + """Image to Convolution Embeddings. 
This convolutional operation aims to model local spatial contexts.""" + + def __init__( + self, + config: CvtConfig, + patch_size: int, + num_channels: int, + embed_dim: int, + stride: int, + padding: int, + **kwargs, + ): + super().__init__(**kwargs) + self.padding = keras.layers.ZeroPadding2D(padding=padding) + self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + self.projection = keras.layers.Conv2D( + filters=embed_dim, + kernel_size=patch_size, + strides=stride, + padding="valid", + data_format="channels_last", + kernel_initializer=get_initializer(config.initializer_range), + name="projection", + ) + # Using the same default epsilon as PyTorch + self.normalization = keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") + self.num_channels = num_channels + self.embed_dim = embed_dim + + def call(self, pixel_values: tf.Tensor) -> tf.Tensor: + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + pixel_values = self.projection(self.padding(pixel_values)) + + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" + batch_size, height, width, num_channels = shape_list(pixel_values) + hidden_size = height * width + pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) + pixel_values = self.normalization(pixel_values) + + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" + pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels)) + return pixel_values + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, self.embed_dim]) + + +class TFCvtSelfAttentionConvProjection(keras.layers.Layer): + """Convolutional projection layer.""" + + def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs): + super().__init__(**kwargs) + self.padding = keras.layers.ZeroPadding2D(padding=padding) + self.convolution = keras.layers.Conv2D( + filters=embed_dim, + kernel_size=kernel_size, + kernel_initializer=get_initializer(config.initializer_range), + padding="valid", + strides=stride, + use_bias=False, + name="convolution", + groups=embed_dim, + ) + # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum) + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.embed_dim = embed_dim + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution(self.padding(hidden_state)) + hidden_state = self.normalization(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.embed_dim]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.embed_dim]) + + +class TFCvtSelfAttentionLinearProjection(keras.layers.Layer): + """Linear projection layer used to flatten tokens 
into 1D.""" + + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" + batch_size, height, width, num_channels = shape_list(hidden_state) + hidden_size = height * width + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) + return hidden_state + + +class TFCvtSelfAttentionProjection(keras.layers.Layer): + """Convolutional Projection for Attention.""" + + def __init__( + self, + config: CvtConfig, + embed_dim: int, + kernel_size: int, + stride: int, + padding: int, + projection_method: str = "dw_bn", + **kwargs, + ): + super().__init__(**kwargs) + if projection_method == "dw_bn": + self.convolution_projection = TFCvtSelfAttentionConvProjection( + config, embed_dim, kernel_size, stride, padding, name="convolution_projection" + ) + self.linear_projection = TFCvtSelfAttentionLinearProjection() + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution_projection(hidden_state, training=training) + hidden_state = self.linear_projection(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution_projection", None) is not None: + with tf.name_scope(self.convolution_projection.name): + self.convolution_projection.build(None) + + +class TFCvtSelfAttention(keras.layers.Layer): + """ + Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for + query, key, and value embeddings. + """ + + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + qkv_bias: bool, + attention_drop_rate: float, + with_cls_token: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.scale = embed_dim**-0.5 + self.with_cls_token = with_cls_token + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.convolution_projection_query = TFCvtSelfAttentionProjection( + config, + embed_dim, + kernel_size, + stride_q, + padding_q, + projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method, + name="convolution_projection_query", + ) + self.convolution_projection_key = TFCvtSelfAttentionProjection( + config, + embed_dim, + kernel_size, + stride_kv, + padding_kv, + projection_method=qkv_projection_method, + name="convolution_projection_key", + ) + self.convolution_projection_value = TFCvtSelfAttentionProjection( + config, + embed_dim, + kernel_size, + stride_kv, + padding_kv, + projection_method=qkv_projection_method, + name="convolution_projection_value", + ) + + self.projection_query = keras.layers.Dense( + units=embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=qkv_bias, + bias_initializer="zeros", + name="projection_query", + ) + self.projection_key = keras.layers.Dense( + units=embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=qkv_bias, + bias_initializer="zeros", + name="projection_key", + ) + self.projection_value = keras.layers.Dense( + units=embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=qkv_bias, + bias_initializer="zeros", + name="projection_value", + ) + self.dropout = keras.layers.Dropout(attention_drop_rate) + + def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor: + 
batch_size, hidden_size, _ = shape_list(hidden_state) + head_dim = self.embed_dim // self.num_heads + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim)) + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3)) + return hidden_state + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: + if self.with_cls_token: + cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) + + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" + batch_size, hidden_size, num_channels = shape_list(hidden_state) + hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) + + key = self.convolution_projection_key(hidden_state, training=training) + query = self.convolution_projection_query(hidden_state, training=training) + value = self.convolution_projection_value(hidden_state, training=training) + + if self.with_cls_token: + query = tf.concat((cls_token, query), axis=1) + key = tf.concat((cls_token, key), axis=1) + value = tf.concat((cls_token, value), axis=1) + + head_dim = self.embed_dim // self.num_heads + + query = self.rearrange_for_multi_head_attention(self.projection_query(query)) + key = self.rearrange_for_multi_head_attention(self.projection_key(key)) + value = self.rearrange_for_multi_head_attention(self.projection_value(value)) + + attention_score = tf.matmul(query, key, transpose_b=True) * self.scale + attention_probs = stable_softmax(logits=attention_score, axis=-1) + attention_probs = self.dropout(attention_probs, training=training) + + context = tf.matmul(attention_probs, value) + # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)" + _, _, hidden_size, _ = shape_list(context) + context = tf.transpose(context, perm=(0, 2, 1, 3)) + context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim)) + return context + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution_projection_query", None) is not None: + with tf.name_scope(self.convolution_projection_query.name): + self.convolution_projection_query.build(None) + if getattr(self, "convolution_projection_key", None) is not None: + with tf.name_scope(self.convolution_projection_key.name): + self.convolution_projection_key.build(None) + if getattr(self, "convolution_projection_value", None) is not None: + with tf.name_scope(self.convolution_projection_value.name): + self.convolution_projection_value.build(None) + if getattr(self, "projection_query", None) is not None: + with tf.name_scope(self.projection_query.name): + self.projection_query.build([None, None, self.embed_dim]) + if getattr(self, "projection_key", None) is not None: + with tf.name_scope(self.projection_key.name): + self.projection_key.build([None, None, self.embed_dim]) + if getattr(self, "projection_value", None) is not None: + with tf.name_scope(self.projection_value.name): + self.projection_value.build([None, None, self.embed_dim]) + + +class TFCvtSelfOutput(keras.layers.Layer): + """Output of the Attention layer .""" + + def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(drop_rate) + self.embed_dim = embed_dim + + def call(self, hidden_state: 
tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.dense(inputs=hidden_state) + hidden_state = self.dropout(inputs=hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.embed_dim]) + + +class TFCvtAttention(keras.layers.Layer): + """Attention layer. First chunk of the convolutional transformer block.""" + + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + qkv_bias: bool, + attention_drop_rate: float, + drop_rate: float, + with_cls_token: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.attention = TFCvtSelfAttention( + config, + num_heads, + embed_dim, + kernel_size, + stride_q, + stride_kv, + padding_q, + padding_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + with_cls_token, + name="attention", + ) + self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False): + self_output = self.attention(hidden_state, height, width, training=training) + attention_output = self.dense_output(self_output, training=training) + return attention_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFCvtIntermediate(keras.layers.Layer): + """Intermediate dense layer. Second chunk of the convolutional transformer block.""" + + def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + units=int(embed_dim * mlp_ratio), + kernel_initializer=get_initializer(config.initializer_range), + activation="gelu", + name="dense", + ) + self.embed_dim = embed_dim + + def call(self, hidden_state: tf.Tensor) -> tf.Tensor: + hidden_state = self.dense(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.embed_dim]) + + +class TFCvtOutput(keras.layers.Layer): + """ + Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection. 
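+    The dense layer maps the intermediate features of size `int(embed_dim * mlp_ratio)` back to `embed_dim` before
+    dropout and the residual addition.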
+ """ + + def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(drop_rate) + self.embed_dim = embed_dim + self.mlp_ratio = mlp_ratio + + def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.dense(inputs=hidden_state) + hidden_state = self.dropout(inputs=hidden_state, training=training) + hidden_state = hidden_state + input_tensor + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)]) + + +class TFCvtLayer(keras.layers.Layer): + """ + Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It + consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the + `Block` class in the original implementation. + """ + + def __init__( + self, + config: CvtConfig, + num_heads: int, + embed_dim: int, + kernel_size: int, + stride_q: int, + stride_kv: int, + padding_q: int, + padding_kv: int, + qkv_projection_method: str, + qkv_bias: bool, + attention_drop_rate: float, + drop_rate: float, + mlp_ratio: float, + drop_path_rate: float, + with_cls_token: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.attention = TFCvtAttention( + config, + num_heads, + embed_dim, + kernel_size, + stride_q, + stride_kv, + padding_q, + padding_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + with_cls_token, + name="attention", + ) + self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate") + self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output") + # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour. 
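+        # `TFCvtDropPath` is only instantiated when `drop_path_rate > 0.0`; otherwise the identity-like linear
+        # activation below makes the drop path a no-op.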
+ self.drop_path = ( + TFCvtDropPath(drop_path_rate, name="drop_path") + if drop_path_rate > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + # Using the same default epsilon as PyTorch + self.layernorm_before = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before") + self.layernorm_after = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after") + self.embed_dim = embed_dim + + def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: + # in Cvt, layernorm is applied before self-attention + attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training) + attention_output = self.drop_path(attention_output, training=training) + + # first residual connection + hidden_state = attention_output + hidden_state + + # in Cvt, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_state) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.dense_output(layer_output, hidden_state) + layer_output = self.drop_path(layer_output, training=training) + return layer_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.embed_dim]) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.embed_dim]) + + +class TFCvtStage(keras.layers.Layer): + """ + Cvt stage (encoder block). Each stage has 2 parts : + - (1) A Convolutional Token Embedding layer + - (2) A Convolutional Transformer Block (layer). + The classification token is added only in the last stage. + + Args: + config ([`CvtConfig`]): Model configuration class. + stage (`int`): Stage number. 
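+
+    Note: the classification token weight is created under the fixed name `"cvt.encoder.stages.2.cls_token"` (see
+    `__init__` below), presumably to keep the variable name aligned with the PyTorch checkpoint layout.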
+ """ + + def __init__(self, config: CvtConfig, stage: int, **kwargs): + super().__init__(**kwargs) + self.config = config + self.stage = stage + if self.config.cls_token[self.stage]: + self.cls_token = self.add_weight( + shape=(1, 1, self.config.embed_dim[-1]), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + name="cvt.encoder.stages.2.cls_token", + ) + + self.embedding = TFCvtEmbeddings( + self.config, + patch_size=config.patch_sizes[self.stage], + num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1], + stride=config.patch_stride[self.stage], + embed_dim=config.embed_dim[self.stage], + padding=config.patch_padding[self.stage], + dropout_rate=config.drop_rate[self.stage], + name="embedding", + ) + + drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage]) + drop_path_rates = [x.numpy().item() for x in drop_path_rates] + self.layers = [ + TFCvtLayer( + config, + num_heads=config.num_heads[self.stage], + embed_dim=config.embed_dim[self.stage], + kernel_size=config.kernel_qkv[self.stage], + stride_q=config.stride_q[self.stage], + stride_kv=config.stride_kv[self.stage], + padding_q=config.padding_q[self.stage], + padding_kv=config.padding_kv[self.stage], + qkv_projection_method=config.qkv_projection_method[self.stage], + qkv_bias=config.qkv_bias[self.stage], + attention_drop_rate=config.attention_drop_rate[self.stage], + drop_rate=config.drop_rate[self.stage], + mlp_ratio=config.mlp_ratio[self.stage], + drop_path_rate=drop_path_rates[self.stage], + with_cls_token=config.cls_token[self.stage], + name=f"layers.{j}", + ) + for j in range(config.depth[self.stage]) + ] + + def call(self, hidden_state: tf.Tensor, training: bool = False): + cls_token = None + hidden_state = self.embedding(hidden_state, training) + + # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" + batch_size, height, width, num_channels = shape_list(hidden_state) + hidden_size = height * width + hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) + + if self.config.cls_token[self.stage]: + cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0) + hidden_state = tf.concat((cls_token, hidden_state), axis=1) + + for layer in self.layers: + layer_outputs = layer(hidden_state, height, width, training=training) + hidden_state = layer_outputs + + if self.config.cls_token[self.stage]: + cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) + + # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" + hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) + return hidden_state, cls_token + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedding", None) is not None: + with tf.name_scope(self.embedding.name): + self.embedding.build(None) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFCvtEncoder(keras.layers.Layer): + """ + Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers + (depth) being 1, 2 and 10. + + Args: + config ([`CvtConfig`]): Model configuration class. 
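+
+    Note: inputs arrive in `(batch_size, num_channels, height, width)` format and are transposed to channels-last
+    for the Keras convolutions, then transposed back before returning (see `call` below).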
+ """ + + config_class = CvtConfig + + def __init__(self, config: CvtConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.stages = [ + TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth)) + ] + + def call( + self, + pixel_values: TFModelInputType, + output_hidden_states: bool | None = False, + return_dict: bool | None = True, + training: bool | None = False, + ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: + all_hidden_states = () if output_hidden_states else None + hidden_state = pixel_values + # When running on CPU, `keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) + # as input format. So change the input format to (batch_size, height, width, num_channels). + hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1)) + + cls_token = None + for _, (stage_module) in enumerate(self.stages): + hidden_state, cls_token = stage_module(hidden_state, training=training) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules + hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2)) + if output_hidden_states: + all_hidden_states = tuple(tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None) + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=hidden_state, + cls_token_value=cls_token, + hidden_states=all_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "stages", None) is not None: + for layer in self.stages: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFCvtMainLayer(keras.layers.Layer): + """Construct the Cvt model.""" + + config_class = CvtConfig + + def __init__(self, config: CvtConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.encoder = TFCvtEncoder(config, name="encoder") + + @unpack_inputs + def call( + self, + pixel_values: TFModelInputType | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=sequence_output, + cls_token_value=encoder_outputs.cls_token_value, + hidden_states=encoder_outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + + +class TFCvtPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = CvtConfig + base_model_prefix = "cvt" + main_input_name = "pixel_values" + + +TFCVT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. 
Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TF 2.0 models accept two formats as inputs:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    This second option is useful when using the [`keras.Model.fit`] method, which currently requires having all the
+    tensors in the first argument of the model call function: `model(inputs)`.
+
+    </Tip>
+
+    Args:
+        config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TFCVT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
+            for details.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode; in graph mode the value will always be set to `True`.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+""" + + +@add_start_docstrings( + "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.", + TFCVT_START_DOCSTRING, +) +class TFCvtModel(TFCvtPreTrainedModel): + def __init__(self, config: CvtConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.cvt = TFCvtMainLayer(config, name="cvt") + + @unpack_inputs + @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFCvtModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13") + >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.cvt( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return (outputs[0],) + outputs[1:] + + return TFBaseModelOutputWithCLSToken( + last_hidden_state=outputs.last_hidden_state, + cls_token_value=outputs.cls_token_value, + hidden_states=outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "cvt", None) is not None: + with tf.name_scope(self.cvt.name): + self.cvt.build(None) + + +@add_start_docstrings( + """ + Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + TFCVT_START_DOCSTRING, +) +class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: CvtConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.cvt = TFCvtMainLayer(config, name="cvt") + # Using same default epsilon as in the original implementation. + self.layernorm = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") + + # Classifier head + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=True, + bias_initializer="zeros", + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFImageClassifierOutputWithNoAttention | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFCvtForImageClassification + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13") + >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] + >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) + ```""" + + outputs = self.cvt( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + cls_token = outputs[1] + if self.config.cls_token[-1]: + sequence_output = self.layernorm(cls_token) + else: + # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels" + batch_size, num_channels, height, width = shape_list(sequence_output) + sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width)) + sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1)) + sequence_output = self.layernorm(sequence_output) + + sequence_output_mean = tf.reduce_mean(sequence_output, axis=1) + logits = self.classifier(sequence_output_mean) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "cvt", None) is not None: + with tf.name_scope(self.cvt.name): + self.cvt.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.config.embed_dim[-1]]) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.embed_dim[-1]]) + + +__all__ = ["TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dab_detr/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa364bd2152049fc92f575e102bf7b934ba4a6c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dab_detr import * + from .modeling_dab_detr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/dab_detr/configuration_dab_detr.py b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/configuration_dab_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..e53d7783a6f431feafb8a3946b85118c5c436858 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/configuration_dab_detr.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DAB-DETR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class DabDetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate + a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DAB-DETR + [IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which + case it will default to `ResNetConfig()`. + backbone (`str`, *optional*, defaults to `"resnet50"`): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. 
If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `True`): + Whether to use pretrained weights for the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + encoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicates whether the transformer model architecture is an encoder-decoder or not. + activation_function (`str` or `function`, *optional*, defaults to `"prelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_size (`int`, *optional*, defaults to 256): + This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1.0): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + auxiliary_loss (`bool`, *optional*, defaults to `False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`. + class_cost (`float`, *optional*, defaults to 2): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. 
+ cls_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the classification loss in the object detection loss function. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + temperature_height (`int`, *optional*, defaults to 20): + Temperature parameter to tune the flatness of positional attention (HEIGHT). + temperature_width (`int`, *optional*, defaults to 20): + Temperature parameter to tune the flatness of positional attention (WIDTH). + query_dim (`int`, *optional*, defaults to 4): + Dimension of the anchor-box queries. Has to be 4, i.e. (x, y, width, height). + random_refpoints_xy (`bool`, *optional*, defaults to `False`): + Whether to randomly initialize the x and y coordinates of the anchor boxes and keep them fixed (not learned). + keep_query_pos (`bool`, *optional*, defaults to `False`): + Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer. + num_patterns (`int`, *optional*, defaults to 0): + Number of pattern embeddings. + normalize_before (`bool`, *optional*, defaults to `False`): + Whether to use a normalization layer in the encoder. + sine_position_embedding_scale (`float`, *optional*, defaults to `None`): + Scaling factor applied to the normalized positional encodings. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
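+
+     As an illustration of how this prior is consumed (the `_init_weights` method in the modeling file uses the
+     focal-loss prior formula `bias = -log((1 - p) / p)`); the label count below is a hypothetical value chosen
+     only for the example:
+
+     ```python
+     >>> import math
+     >>> num_labels = 91  # hypothetical label count, for illustration only
+     >>> prior_prob = 1 / (num_labels + 1)  # default when `initializer_bias_prior_prob` is None
+     >>> bias_value = -math.log((1 - prior_prob) / prior_prob)
+     >>> round(bias_value, 2)
+     -4.51
+     ```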
+ + + Examples: + + ```python + >>> from transformers import DabDetrConfig, DabDetrModel + + >>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration + >>> configuration = DabDetrConfig() + + >>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration + >>> model = DabDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dab-detr" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + use_timm_backbone=True, + backbone_config=None, + backbone="resnet50", + use_pretrained_backbone=True, + backbone_kwargs=None, + num_queries=300, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + is_encoder_decoder=True, + activation_function="prelu", + hidden_size=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + auxiliary_loss=False, + dilation=False, + class_cost=2, + bbox_cost=5, + giou_cost=2, + cls_loss_coefficient=2, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + focal_alpha=0.25, + temperature_height=20, + temperature_width=20, + query_dim=4, + random_refpoints_xy=False, + keep_query_pos=False, + num_patterns=0, + normalize_before=False, + sine_position_embedding_scale=None, + initializer_bias_prior_prob=None, + **kwargs, + ): + if query_dim != 4: + raise ValueError("The query dimensions has to be 4.") + + # We default to values which were previously hard-coded in the model. This enables configurability of the config + # while keeping the default behavior the same. + if use_timm_backbone and backbone_kwargs is None: + backbone_kwargs = {} + if dilation: + backbone_kwargs["output_stride"] = 16 + backbone_kwargs["out_indices"] = [1, 2, 3, 4] + backbone_kwargs["in_chans"] = 3 # num_channels + # Backwards compatibility + elif not use_timm_backbone and backbone in (None, "resnet50"): + if backbone_config is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + backbone = None + # set timm attributes to None + dilation = None + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.use_timm_backbone = use_timm_backbone + self.backbone_config = backbone_config + self.num_queries = num_queries + self.hidden_size = hidden_size + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.num_hidden_layers = encoder_layers + self.auxiliary_loss = auxiliary_loss + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.backbone_kwargs = backbone_kwargs + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.cls_loss_coefficient = cls_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.focal_alpha = focal_alpha + self.query_dim = query_dim + self.random_refpoints_xy = random_refpoints_xy + self.keep_query_pos = keep_query_pos + self.num_patterns = num_patterns + self.normalize_before = normalize_before + self.temperature_width = temperature_width + self.temperature_height = temperature_height + self.sine_position_embedding_scale = sine_position_embedding_scale + self.initializer_bias_prior_prob = initializer_bias_prior_prob + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + +__all__ = ["DabDetrConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dab_detr/modeling_dab_detr.py b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/modeling_dab_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7a27e7663b62fb1fd663595e9c6b4b16e9db9e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dab_detr/modeling_dab_detr.py @@ -0,0 +1,1599 @@ +# coding=utf-8 +# Copyright 2024 IDEA Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DAB-DETR model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + auto_docstring, + logging, +) +from ...utils.backbone_utils import load_backbone +from .configuration_dab_detr import DabDetrConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Conditional DETR decoder. This class adds one attribute to + BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output + of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary + decoding losses. + """ +) +# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points) +class DabDetrDecoderOutput(BaseModelOutputWithCrossAttentions): + r""" + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to + Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder + layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding + losses. + """ +) +# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrModelOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points) +class DabDetrModelOutput(Seq2SeqModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. 
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`DabDetrForObjectDetection`]. + """ +) +# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->DabDetr +class DabDetrObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DabDetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DabDetr +class DabDetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. 
+ """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DabDetr +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `DabDetrFrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = DabDetrFrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +# Modified from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->DabDetr +class DabDetrConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by DabDetrFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config: DabDetrConfig): + super().__init__() + + self.config = config + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DabDetr +class DabDetrConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. 
+ """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrSinePositionEmbedding with ConditionalDetr->DabDetr +class DabDetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__(self, config: DabDetrConfig): + super().__init__() + self.config = config + self.embedding_dim = config.hidden_size / 2 + self.temperature_height = config.temperature_height + self.temperature_width = config.temperature_width + scale = config.sine_position_embedding_scale + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + # We use float32 to ensure reproducibility of the original implementation + dim_tx = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + # Modifying dim_tx in place to avoid extra memory allocation -> dim_tx = self.temperature_width ** (2 * (dim_tx // 2) / self.embedding_dim) + dim_tx //= 2 + dim_tx.mul_(2 / self.embedding_dim) + dim_tx.copy_(self.temperature_width**dim_tx) + pos_x = x_embed[:, :, :, None] / dim_tx + + # We use float32 to ensure reproducibility of the original implementation + dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + # Modifying dim_ty in place to avoid extra memory allocation -> dim_ty = self.temperature_height ** (2 * (dim_ty // 2) / self.embedding_dim) + dim_ty //= 2 + dim_ty.mul_(2 / self.embedding_dim) + dim_ty.copy_(self.temperature_height**dim_ty) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# function to generate sine positional embedding for 4d coordinates +def gen_sine_position_embeddings(pos_tensor, hidden_size=256): + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. 
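+
+     Illustrative usage sketch (shapes assume the default `hidden_size=256`; the import path mirrors this file):
+
+     ```python
+     >>> import torch
+     >>> from transformers.models.dab_detr.modeling_dab_detr import gen_sine_position_embeddings
+
+     >>> boxes = torch.rand(2, 300, 4)  # (batch_size, num_queries, 4), normalized (x, y, w, h)
+     >>> embeddings = gen_sine_position_embeddings(boxes, hidden_size=256)
+     >>> list(embeddings.shape)
+     [2, 300, 512]
+     ```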
+ """ + scale = 2 * math.pi + dim = hidden_size // 2 + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + return pos.to(pos_tensor.dtype) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# Modified from transformers.models.detr.modeling_detr.DetrAttention +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + """ + + def __init__( + self, + config: DabDetrConfig, + bias: bool = True, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.encoder_attention_heads + self.attention_dropout = config.attention_dropout + self.head_dim = self.hidden_size // self.num_heads + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if object_queries is not None: + hidden_states_original = hidden_states + hidden_states = hidden_states + object_queries + + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states_original) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(batch_size, q_len, embed_dim) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrAttention with ConditionalDetr->DABDETR,Conditional DETR->DabDetr +class DabDetrAttention(nn.Module): + """ + Cross-Attention used in DAB-DETR 'DAB-DETR for Fast Training Convergence' paper. + + The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be + different to v. + """ + + def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = False): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size * 2 if is_cross else config.hidden_size + self.output_dim = config.hidden_size + self.attention_heads = config.decoder_attention_heads + self.attention_dropout = config.attention_dropout + self.attention_head_dim = self.embed_dim // self.attention_heads + if self.attention_head_dim * self.attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `attention_heads`:" + f" {self.attention_heads})." 
+ ) + # head dimension of values + self.values_head_dim = self.output_dim // self.attention_heads + if self.values_head_dim * self.attention_heads != self.output_dim: + raise ValueError( + f"output_dim must be divisible by attention_heads (got `output_dim`: {self.output_dim} and `attention_heads`: {self.attention_heads})." + ) + self.scaling = self.attention_head_dim**-0.5 + self.output_proj = nn.Linear(self.output_dim, self.output_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + # scaling query and refactor key-, value states + query_states = hidden_states * self.scaling + query_states = query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) + + if attn_output.size() != (batch_size, self.attention_heads, q_len, self.values_head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.attention_heads, q_len, self.values_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(batch_size, q_len, self.output_dim) + attn_output = self.output_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DabDetrDecoderLayerSelfAttention(nn.Module): + def __init__(self, config: DabDetrConfig): + super().__init__() + self.dropout = config.dropout + self.self_attn_query_content_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_query_pos_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_key_content_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_key_pos_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_value_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn = DabDetrAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + query_position_embeddings: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + residual = hidden_states + query_content = self.self_attn_query_content_proj(hidden_states) + query_pos = self.self_attn_query_pos_proj(query_position_embeddings) + key_content = self.self_attn_key_content_proj(hidden_states) + key_pos = self.self_attn_key_pos_proj(query_position_embeddings) + value = self.self_attn_value_proj(hidden_states) + + query = query_content + query_pos + key = key_content + key_pos + 
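+ # Note: queries and keys carry content plus the projected anchor position embeddings,
+ # while the values are projected from the content (hidden states) only.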
+ hidden_states, attn_weights = self.self_attn( + hidden_states=query, + attention_mask=attention_mask, + key_states=key, + value_states=value, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + return hidden_states, attn_weights + + +class DabDetrDecoderLayerCrossAttention(nn.Module): + def __init__(self, config: DabDetrConfig, is_first: bool = False): + super().__init__() + hidden_size = config.hidden_size + self.cross_attn_query_content_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_pos_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_key_content_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_key_pos_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_value_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_pos_sine_proj = nn.Linear(hidden_size, hidden_size) + self.decoder_attention_heads = config.decoder_attention_heads + self.cross_attn_layer_norm = nn.LayerNorm(hidden_size) + self.cross_attn = DabDetrAttention(config, is_cross=True) + + self.keep_query_pos = config.keep_query_pos + + if not self.keep_query_pos and not is_first: + self.cross_attn_query_pos_proj = None + + self.is_first = is_first + self.dropout = config.dropout + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + query_sine_embed: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + query_content = self.cross_attn_query_content_proj(hidden_states) + key_content = self.cross_attn_key_content_proj(encoder_hidden_states) + value = self.cross_attn_value_proj(encoder_hidden_states) + + batch_size, num_queries, n_model = query_content.shape + _, height_width, _ = key_content.shape + + key_pos = self.cross_attn_key_pos_proj(object_queries) + + # For the first decoder layer, we add the positional embedding predicted from + # the object query (the positional embedding) into the original query (key) in DETR. 
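+ # In DAB-DETR the positional information is not only added: below, the sine embedding of the anchor box
+ # (query_sine_embed) is concatenated to the content query per attention head, and the key position embedding
+ # is concatenated to the content key, which is why this cross-attention operates on vectors of size
+ # 2 * hidden_size.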
+ if self.is_first or self.keep_query_pos: + query_pos = self.cross_attn_query_pos_proj(query_position_embeddings) + query = query_content + query_pos + key = key_content + key_pos + else: + query = query_content + key = key_content + + query = query.view( + batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + query_sine_embed = self.cross_attn_query_pos_sine_proj(query_sine_embed) + query_sine_embed = query_sine_embed.view( + batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + query = torch.cat([query, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2) + key = key.view(batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads) + key_pos = key_pos.view( + batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + key = torch.cat([key, key_pos], dim=3).view(batch_size, height_width, n_model * 2) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.cross_attn( + hidden_states=query, + attention_mask=encoder_attention_mask, + key_states=key, + value_states=value, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.cross_attn_layer_norm(hidden_states) + + return hidden_states, cross_attn_weights + + +class DabDetrDecoderLayerFFN(nn.Module): + def __init__(self, config: DabDetrConfig): + super().__init__() + hidden_size = config.hidden_size + self.final_layer_norm = nn.LayerNorm(hidden_size) + self.fc1 = nn.Linear(hidden_size, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, hidden_size) + self.activation_fn = ACT2FN[config.activation_function] + self.dropout = config.dropout + self.activation_dropout = config.activation_dropout + self.keep_query_pos = config.keep_query_pos + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig +class DabDetrEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DabDetrConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = DetrAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = nn.Linear(self.hidden_size, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.hidden_size) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + object_queries: torch.Tensor, + output_attentions: Optional[bool] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, 
embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr +class DabDetrDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DabDetrConfig, is_first: bool = False): + super().__init__() + self.self_attn = DabDetrDecoderLayerSelfAttention(config) + self.cross_attn = DabDetrDecoderLayerCrossAttention(config, is_first) + self.mlp = DabDetrDecoderLayerFFN(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + query_sine_embed: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ + """ + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + query_position_embeddings=query_position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_position_embeddings=query_position_embeddings, + object_queries=object_queries, + encoder_attention_mask=encoder_attention_mask, + query_sine_embed=query_sine_embed, + output_attentions=output_attentions, + ) + + hidden_states = self.mlp(hidden_states=hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Modified from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->DabDetrMLP +class DabDetrMLP(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, input_tensor): + for i, layer in enumerate(self.layers): + input_tensor = nn.functional.relu(layer(input_tensor)) if i < self.num_layers - 1 else layer(input_tensor) + return input_tensor + + +# Modified from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->DabDetr +@auto_docstring +class DabDetrPreTrainedModel(PreTrainedModel): + config: DabDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [r"DabDetrConvEncoder", r"DabDetrEncoderLayer", r"DabDetrDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + xavier_std = self.config.init_xavier_std + + if isinstance(module, DabDetrMHAttentionMap): + nn.init.zeros_(module.k_linear.bias) + nn.init.zeros_(module.q_linear.bias) + nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DabDetrForObjectDetection): + nn.init.constant_(module.bbox_predictor.layers[-1].weight.data, 0) + nn.init.constant_(module.bbox_predictor.layers[-1].bias.data, 0) + + # init prior_prob setting for focal loss + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias_value = -math.log((1 - prior_prob) / prior_prob) + module.class_embed.bias.data.fill_(bias_value) + elif isinstance(module, nn.PReLU): + module.reset_parameters() + + +# Modified from transformers.models.detr.modeling_detr.DetrEncoder with 
Detr->DabDetr,DETR->ConditionalDETR +class DabDetrEncoder(DabDetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`DabDetrEncoderLayer`]. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweak for DAB-DETR: + + - object_queries are added to the forward pass. + + Args: + config: DabDetrConfig + """ + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.query_scale = DabDetrMLP(config.hidden_size, config.hidden_size, config.hidden_size, 2) + self.layers = nn.ModuleList([DabDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.norm = nn.LayerNorm(config.hidden_size) if config.normalize_before else None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds, + attention_mask, + object_queries, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + Object queries that are added to the queries in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # pos scaler + pos_scales = self.query_scale(hidden_states) + # we add object_queries * pos_scaler as extra input to the encoder_layer + scaled_object_queries = object_queries * pos_scales + + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + object_queries=scaled_object_queries, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if self.norm: + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoder with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR +class DabDetrDecoder(DabDetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DabDetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for DAB-DETR: + + - object_queries and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. 
+ + Args: + config: DabDetrConfig + """ + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + self.config = config + self.dropout = config.dropout + self.num_layers = config.decoder_layers + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [DabDetrDecoderLayer(config, is_first=(layer_id == 0)) for layer_id in range(config.decoder_layers)] + ) + # in DAB-DETR, the decoder uses layernorm after the last decoder layer output + self.hidden_size = config.hidden_size + self.layernorm = nn.LayerNorm(self.hidden_size) + + # Default cond-elewise + self.query_scale = DabDetrMLP(self.hidden_size, self.hidden_size, self.hidden_size, 2) + + self.ref_point_head = DabDetrMLP( + config.query_dim // 2 * self.hidden_size, self.hidden_size, self.hidden_size, 2 + ) + + self.bbox_embed = None + + # Default decoder_modulate_hw_attn is True + self.ref_anchor_head = DabDetrMLP(self.hidden_size, self.hidden_size, 2, 2) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds, + encoder_hidden_states, + memory_key_padding_mask, + object_queries, + query_position_embeddings, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(encoder_sequence_length, batch_size, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + memory_key_padding_mask (`torch.Tensor.bool` of shape `(batch_size, sequence_length)`): + The memory_key_padding_mask indicates which positions in the memory (encoder outputs) should be ignored during the attention computation, + ensuring padding tokens do not influence the attention mechanism. + object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, number_of_anchor_points)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
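+
+         The 4-D `query_position_embeddings` are treated as anchor boxes. A condensed sketch of how they are
+         consumed in the body of this method:
+
+         ```python
+         reference_points = query_position_embeddings.sigmoid()  # normalized (x, y, w, h) anchors
+         obj_center = reference_points[..., : self.config.query_dim]
+         query_sine_embed = gen_sine_position_embeddings(obj_center, self.hidden_size)
+         query_pos = self.ref_point_head(query_sine_embed)  # positional queries fed to each decoder layer
+         ```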
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + intermediate = [] + reference_points = query_position_embeddings.sigmoid() + ref_points = [reference_points] + + # expand encoder attention mask + if encoder_hidden_states is not None and memory_key_padding_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + memory_key_padding_mask = _prepare_4d_attention_mask( + memory_key_padding_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + for layer_id, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + obj_center = reference_points[..., : self.config.query_dim] + query_sine_embed = gen_sine_position_embeddings(obj_center, self.hidden_size) + query_pos = self.ref_point_head(query_sine_embed) + + # For the first decoder layer, we do not apply transformation over p_s + pos_transformation = 1 if layer_id == 0 else self.query_scale(hidden_states) + + # apply transformation + query_sine_embed = query_sine_embed[..., : self.hidden_size] * pos_transformation + + # modulated Height Width attentions + reference_anchor_size = self.ref_anchor_head(hidden_states).sigmoid() # nq, bs, 2 + query_sine_embed[..., self.hidden_size // 2 :] *= ( + reference_anchor_size[..., 0] / obj_center[..., 2] + ).unsqueeze(-1) + query_sine_embed[..., : self.hidden_size // 2] *= ( + reference_anchor_size[..., 1] / obj_center[..., 3] + ).unsqueeze(-1) + + layer_outputs = decoder_layer( + hidden_states, + None, # attention_mask + object_queries, + query_pos, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=memory_key_padding_mask, + output_attentions=output_attentions, + ) + + # iter update + hidden_states = layer_outputs[0] + + if self.bbox_embed is not None: + new_reference_points = self.bbox_embed(hidden_states) + + new_reference_points[..., : self.config.query_dim] += inverse_sigmoid(reference_points) + new_reference_points = new_reference_points[..., : self.config.query_dim].sigmoid() + if layer_id != self.num_layers - 1: + ref_points.append(new_reference_points) + reference_points = new_reference_points.detach() + + intermediate.append(self.layernorm(hidden_states)) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Layer normalization on hidden states + hidden_states = self.layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output_intermediate_hidden_states = torch.stack(intermediate) + output_reference_points = torch.stack(ref_points) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attns, + all_cross_attentions, + output_intermediate_hidden_states, + output_reference_points, + ] + if v is not None + ) + return 
DabDetrDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=output_intermediate_hidden_states, + reference_points=output_reference_points, + ) + + +@auto_docstring( + custom_intro=""" + The bare DAB-DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw + hidden-states, intermediate hidden states, reference points, output coordinates without any specific head on top. + """ +) +class DabDetrModel(DabDetrPreTrainedModel): + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.auxiliary_loss = config.auxiliary_loss + + # Create backbone + positional encoding + self.backbone = DabDetrConvEncoder(config) + object_queries = DabDetrSinePositionEmbedding(config) + + self.query_refpoint_embeddings = nn.Embedding(config.num_queries, config.query_dim) + self.random_refpoints_xy = config.random_refpoints_xy + if self.random_refpoints_xy: + self.query_refpoint_embeddings.weight.data[:, :2].uniform_(0, 1) + self.query_refpoint_embeddings.weight.data[:, :2] = inverse_sigmoid( + self.query_refpoint_embeddings.weight.data[:, :2] + ) + self.query_refpoint_embeddings.weight.data[:, :2].requires_grad = False + + # Create projection layer + self.input_projection = nn.Conv2d( + self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1 + ) + self.backbone = DabDetrConvModel(self.backbone, object_queries) + + self.encoder = DabDetrEncoder(config) + self.decoder = DabDetrDecoder(config) + + # decoder related variables + self.hidden_size = config.hidden_size + self.num_queries = config.num_queries + + self.num_patterns = config.num_patterns + if not isinstance(self.num_patterns, int): + logger.warning(f"num_patterns should be int but {type(self.num_patterns)}") + self.num_patterns = 0 + if self.num_patterns > 0: + self.patterns = nn.Embedding(self.num_patterns, self.hidden_size) + + self.aux_loss = config.auxiliary_loss + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], DabDetrModelOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. 
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab_detr-base") + >>> model = AutoModel.from_pretrained("IDEA-Research/dab_detr-base") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> # the last hidden states are the final query embeddings of the Transformer decoder + >>> # these are of shape (batch_size, num_queries, hidden_size) + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + # First, sent pixel_values + pixel_mask through Backbone to obtain the features + # pixel_values should be of shape (batch_size, num_channels, height, width) + # pixel_mask should be of shape (batch_size, height, width) + features, object_queries_list = self.backbone(pixel_values, pixel_mask) + + # get final feature map and downsampled mask + feature_map, mask = features[-1] + + if mask is None: + raise ValueError("Backbone does not return downsampled pixel mask") + + flattened_mask = mask.flatten(1) + + # Second, apply 1x1 convolution to reduce the channel dimension to hidden_size (256 by default) + projected_feature_map = self.input_projection(feature_map) + + # Third, flatten the feature map + object_queries of shape NxCxHxW to HWxNxC, and permute it to NxHWxC + # In other words, turn their shape into ( sequence_length, batch_size, hidden_size) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) + object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) + reference_position_embeddings = self.query_refpoint_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + + # Fourth, sent flattened_features + flattened_mask + object_queries through encoder + # flattened_features is a Tensor of shape (height*width, batch_size, hidden_size) + # flattened_mask is a Tensor of shape (batch_size, height*width) + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=flattened_features, + attention_mask=flattened_mask, + object_queries=object_queries, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if 
len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output) + num_queries = reference_position_embeddings.shape[1] + if self.num_patterns == 0: + queries = torch.zeros(batch_size, num_queries, self.hidden_size, device=device) + else: + queries = ( + self.patterns.weight[:, None, None, :] + .repeat(1, self.num_queries, batch_size, 1) + .flatten(0, 1) + .permute(1, 0, 2) + ) # bs, n_q*n_pat, hidden_size + reference_position_embeddings = reference_position_embeddings.repeat( + 1, self.num_patterns, 1 + ) # bs, n_q*n_pat, hidden_size + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + query_position_embeddings=reference_position_embeddings, + object_queries=object_queries, + encoder_hidden_states=encoder_outputs[0], + memory_key_padding_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + # last_hidden_state + output = (decoder_outputs[0],) + reference_points = decoder_outputs[-1] + intermediate_hidden_states = decoder_outputs[-2] + + # it has to follow the order of DABDETRModelOutput that is based on ModelOutput + # If we only use one of the variables then the indexing will change. + # E.g: if we return everything then 'decoder_attentions' is decoder_outputs[2], if we only use output_attentions then its decoder_outputs[1] + if output_hidden_states and output_attentions: + output += ( + decoder_outputs[1], + decoder_outputs[2], + decoder_outputs[3], + encoder_outputs[0], + encoder_outputs[1], + encoder_outputs[2], + ) + elif output_hidden_states: + # decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states + output += ( + decoder_outputs[1], + encoder_outputs[0], + encoder_outputs[1], + ) + elif output_attentions: + # decoder_self_attention, decoder_cross_attention, encoder_attentions + output += ( + decoder_outputs[1], + decoder_outputs[2], + encoder_outputs[1], + ) + + output += (intermediate_hidden_states, reference_points) + + return output + + reference_points = decoder_outputs.reference_points + intermediate_hidden_states = decoder_outputs.intermediate_hidden_states + + return DabDetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states if output_hidden_states else None, + decoder_attentions=decoder_outputs.attentions if output_attentions else None, + cross_attentions=decoder_outputs.cross_attentions if output_attentions else None, + encoder_last_hidden_state=encoder_outputs.last_hidden_state if output_hidden_states else None, + encoder_hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + encoder_attentions=encoder_outputs.attentions if output_attentions else None, + intermediate_hidden_states=intermediate_hidden_states, + reference_points=reference_points, + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->DabDetr +class DabDetrMHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = 
nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: Optional[Tensor] = None): + q = self.q_linear(q) + k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) + + if mask is not None: + weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min) + weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size()) + weights = self.dropout(weights) + return weights + + +@auto_docstring( + custom_intro=""" + DAB_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """ +) +class DabDetrForObjectDetection(DabDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = [ + r"bbox_predictor\.layers\.\d+\.(weight|bias)", + r"model\.decoder\.bbox_embed\.layers\.\d+\.(weight|bias)", + ] + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.config = config + self.auxiliary_loss = config.auxiliary_loss + self.query_dim = config.query_dim + # DAB-DETR encoder-decoder model + self.model = DabDetrModel(config) + + _bbox_embed = DabDetrMLP(config.hidden_size, config.hidden_size, 4, 3) + # Object detection heads + self.class_embed = nn.Linear(config.hidden_size, config.num_labels) + + # Default bbox_embed_diff_each_layer is False + self.bbox_predictor = _bbox_embed + + # Default iter_update is True + self.model.decoder.bbox_embed = self.bbox_predictor + + # Initialize weights and apply final processing + self.post_init() + + # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/dab_detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50") + >>> model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([(image.height, image.width)]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0] + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + Detected remote with confidence 0.833 at location [38.31, 72.1, 177.63, 118.45] + Detected cat with confidence 0.831 at location [9.2, 51.38, 321.13, 469.0] + Detected cat with confidence 0.804 at location [340.3, 16.85, 642.93, 370.95] + Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01] + Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through DAB_DETR base model to obtain encoder + decoder outputs + model_outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + reference_points = model_outputs.reference_points if return_dict else model_outputs[-1] + intermediate_hidden_states = model_outputs.intermediate_hidden_states if return_dict else model_outputs[-2] + + # class logits + predicted bounding boxes + logits = self.class_embed(intermediate_hidden_states[-1]) + + reference_before_sigmoid = inverse_sigmoid(reference_points) + bbox_with_refinement = self.bbox_predictor(intermediate_hidden_states) + bbox_with_refinement[..., : self.query_dim] += reference_before_sigmoid + outputs_coord = bbox_with_refinement.sigmoid() + + pred_boxes = outputs_coord[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class = None + if self.config.auxiliary_loss: + outputs_class = self.class_embed(intermediate_hidden_states) + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + auxiliary_outputs + model_outputs + else: + output = (logits, pred_boxes) + model_outputs + # Since DabDetrObjectDetectionOutput doesn't have reference points + intermedieate_hidden_states we cut down. 
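+            # (the tuple from self.model ends with intermediate_hidden_states and reference_points; when no loss is
+            # computed, output[:-2] below drops those two trailing entries so the returned tuple matches the fields of
+            # DabDetrObjectDetectionOutput, whereas the loss branch prepends (loss, loss_dict) and keeps the full tuple)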
+ return ((loss, loss_dict) + output) if loss is not None else output[:-2] + + return DabDetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=model_outputs.last_hidden_state, + decoder_hidden_states=model_outputs.decoder_hidden_states if output_hidden_states else None, + decoder_attentions=model_outputs.decoder_attentions if output_attentions else None, + cross_attentions=model_outputs.cross_attentions if output_attentions else None, + encoder_last_hidden_state=model_outputs.encoder_last_hidden_state if output_hidden_states else None, + encoder_hidden_states=model_outputs.encoder_hidden_states if output_hidden_states else None, + encoder_attentions=model_outputs.encoder_attentions if output_attentions else None, + ) + + +__all__ = [ + "DabDetrForObjectDetection", + "DabDetrModel", + "DabDetrPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7000ac3d353bf4eba157b07e350b9ac5f7552a98 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_data2vec_audio import * + from .configuration_data2vec_text import * + from .configuration_data2vec_vision import * + from .modeling_data2vec_audio import * + from .modeling_data2vec_text import * + from .modeling_data2vec_vision import * + from .modeling_tf_data2vec_vision import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..3d88a9de6543afaa60c6d5353f2c32d871d5ee21 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py @@ -0,0 +1,288 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecAudio configuration"""
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecAudioConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Data2VecAudioModel`]. It is used to instantiate
+    a Data2VecAudio model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Data2VecAudio
+    [facebook/data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the Data2VecAudio model. Defines the number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`Data2VecAudioModel`] or [`TFData2VecAudioModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`Data2VecAudioForCTC`].
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more
+            details.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for output of the feature encoder.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 5):
+            Number of stacked 1D convolutional positional embedding layers.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of 1D convolutional positional embeddings layer.
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespective of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+            mask_time_min_masks''
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks
+            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
+            overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
+            is True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespective of `mask_feature_prob`. Only relevant if
+            ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
+        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`Data2VecAudioForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly + occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance + of [`Data2VecAudioForCTC`]. + use_weighted_layer_sum (`bool`, *optional*, defaults to `False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of [`Data2VecAudioForSequenceClassification`]. + classifier_proj_size (`int`, *optional*, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. + tdnn_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN* + module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers. + tdnn_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the + *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*. + tdnn_dilation (`tuple[int]` or `list[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the + *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*. + xvector_output_dim (`int`, *optional*, defaults to 512): + Dimensionality of the *XVector* embedding vectors. + add_adapter (`bool`, *optional*, defaults to `False`): + Whether a convolutional network should be stacked on top of the Data2VecAudio Encoder. Can be very useful + for warm-starting Data2VecAudio for SpeechEncoderDecoder models. + adapter_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adapter_stride (`int`, *optional*, defaults to 2): + Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + num_adapter_layers (`int`, *optional*, defaults to 3): + Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is + True`. + output_hidden_size (`int`, *optional*): + Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant + if `add_adapter is True`. 
+ + Example: + + ```python + >>> from transformers import Data2VecAudioConfig, Data2VecAudioModel + + >>> # Initializing a Data2VecAudio facebook/data2vec-audio-base-960h style configuration + >>> configuration = Data2VecAudioConfig() + + >>> # Initializing a model (with random weights) from the facebook/data2vec-audio-base-960h style configuration + >>> model = Data2VecAudioModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "data2vec-audio" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.0, + final_dropout=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embedding_groups=16, + conv_pos_kernel_size=19, + num_conv_pos_embeddings=5, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + ctc_loss_reduction="sum", + ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + add_adapter=False, + adapter_kernel_size=3, + adapter_stride=2, + num_adapter_layers=3, + output_hidden_size=None, + **kwargs, + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.conv_pos_kernel_size = conv_pos_kernel_size + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.use_weighted_layer_sum = use_weighted_layer_sum + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." 
+ ) + + # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779 + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # adapter + self.add_adapter = add_adapter + self.adapter_kernel_size = adapter_kernel_size + self.adapter_stride = adapter_stride + self.num_adapter_layers = num_adapter_layers + self.output_hidden_size = output_hidden_size or hidden_size + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. + self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) + + +__all__ = ["Data2VecAudioConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_text.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_text.py new file mode 100644 index 0000000000000000000000000000000000000000..f9518d67bf665f01a2cfb46cfdb6f529f5f22bce --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_text.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data2VecText configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Data2VecTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Data2VecTextModel`] and [`Data2VecTextModel`]. It + is used to instantiate a Data2VecText model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Data2VecText + [facebook/data2vec-text-base](https://huggingface.co/facebook/data2vec-text-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the DATA2VEC model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Data2VecModel`]. 
+ hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`Data2VecModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. 
+ + Examples: + + ```python + >>> from transformers import Data2VecTextConfig, Data2VecTextModel + + >>> # Initializing a Data2VecText facebook/data2vec-text-base style configuration + >>> configuration = Data2VecTextConfig() + + >>> # Initializing a model (with random weights) from the facebook/data2vec-text-base style configuration + >>> model = Data2VecTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "data2vec-text" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class Data2VecTextOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["Data2VecTextConfig", "Data2VecTextOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..2de256f9d7d7abda3b065472f116a371c03ee4da --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Data2VecVision model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Data2VecVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate + an Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Data2VecVision + [facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. 
+ out_indices (`list[int]`, *optional*, defaults to `[3, 5, 7, 11]`): + Indices of the feature maps to use for semantic segmentation. + pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + + Example: + + ```python + >>> from transformers import Data2VecVisionConfig, Data2VecVisionModel + + >>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration + >>> configuration = Data2VecVisionConfig() + + >>> # Initializing a model (with random weights) from the data2vec_vision-base-patch16-224-in22k style configuration + >>> model = Data2VecVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "data2vec-vision" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + out_indices=[3, 5, 7, 11], + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + semantic_loss_ignore_index=255, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling + # decode head attributes (semantic segmentation) + self.out_indices = out_indices + self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + 
self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class Data2VecVisionOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["Data2VecVisionConfig", "Data2VecVisionOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b3f01f42d4ee528c11e12865d0e020ed5f0cf7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py @@ -0,0 +1,1397 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_data2vec_audio.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import warnings +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available +from .configuration_data2vec_audio import Data2VecAudioConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +class Data2VecAudioConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class Data2VecAudioPositionalConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.conv_pos_kernel_size, + padding=config.conv_pos_kernel_size // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size) + self.activation = ACT2FN[config.feat_extract_activation] + # no learnable parameters + self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList( + [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)] + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + for layer in 
self.layers: + hidden_states = layer(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Data2VecAudioFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + self.conv_layers = nn.ModuleList( + [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + ) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class Data2VecAudioFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Data2VecAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[Data2VecAudioConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, None + + +class Data2VecAudioFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class Data2VecAudioEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention = Data2VecAudioAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + config=config, + ) + + self.dropout = 
nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = Data2VecAudioFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Data2VecAudioEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = self.training and dropout_probability < self.config.layerdrop + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = 
attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class Data2VecAudioAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +class Data2VecAudioAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +@auto_docstring +class Data2VecAudioPreTrainedModel(PreTrainedModel): + config: Data2VecAudioConfig + base_model_prefix = "data2vec_audio" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Data2VecAudioFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, Data2VecAudioPositionalConvLayer): + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def 
_get_feat_extract_output_lengths(
+        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+    ):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+    ):
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+        output_lengths = output_lengths.to(torch.long)
+
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output lengths idxs are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
+
+def _compute_mask_indices(
+    shape: tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    attention_mask: Optional[torch.LongTensor] = None,
+    min_masks: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+    CPU as part of the preprocessing during training.
+
+    Args:
+        shape: The shape for which to compute masks. This should be a tuple of size 2 where
+            the first element is the batch size and the second element is the length of the axis to span.
+        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+            independently generated mask spans of length `mask_length` is computed by
+            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+            actual percentage will be smaller.
+        mask_length: size of the mask
+        min_masks: minimum number of masked spans
+        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+            each batch dimension.
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@auto_docstring
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
+    def __init__(self, config: Data2VecAudioConfig):
+        super().__init__(config)
+        self.config = config
+        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+        self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+        # model only needs masking vector if mask prob is > 0.0
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+        self.encoder = Data2VecAudioEncoder(config)
+
+        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+        not be updated during training.
+        """
+        self.feature_extractor._freeze_parameters()
+
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Data2VecAudioBaseModelOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Data2VecAudioBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +_HIDDEN_STATES_START_POSITION = 2 + + +@auto_docstring( + custom_intro=""" + Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). + """ +) +class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel): + def __init__(self, config): + r""" + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`Data2VecAudioForCTC`] with adapters. Uses 'eng' by + default. + """ + super().__init__(config) + + self.data2vec_audio = Data2VecAudioModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
+ """ + self.data2vec_audio.feature_extractor._freeze_parameters() + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.data2vec_audio( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring( + custom_intro=""" + Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
+ """ +) +class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)" + ) + self.data2vec_audio = Data2VecAudioModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.data2vec_audio.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.data2vec_audio.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion + into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.data2vec_audio( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)" + ) + self.data2vec_audio = Data2VecAudioModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.num_labels = config.num_labels + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.data2vec_audio.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.data2vec_audio.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion + into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.data2vec_audio( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super().__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +class TDNNLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = 
config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if is_peft_available(): + from peft.tuners.lora import LoraLayer + + if is_peft_available(): + if isinstance(self.kernel, LoraLayer): + warnings.warn( + "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " + "You should exclude TDNNLayer from LoRA's target modules.", + ) + + # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up + hidden_states = hidden_states.transpose(1, 2) + weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2) + hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@auto_docstring( + custom_intro=""" + Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """ +) +class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.data2vec_audio = Data2VecAudioModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.data2vec_audio.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.data2vec_audio.parameters(): + param.requires_grad = False + + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[tuple, XVectorOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion + into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.data2vec_audio( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Data2VecAudioForAudioFrameClassification", + "Data2VecAudioForCTC", + "Data2VecAudioForSequenceClassification", + "Data2VecAudioForXVector", + "Data2VecAudioModel", + "Data2VecAudioPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_text.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_text.py new file mode 100644 index 0000000000000000000000000000000000000000..f866dd9144a627b938de67ed3d3f79816e849722 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_text.py @@ -0,0 +1,1378 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Data2VecText model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_data2vec_text import Data2VecTextConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText +class Data2VecTextForTextEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        # Setting the token_type_ids to the registered buffer from the constructor, where it is all zeros, helps
+        # users when tracing the model without passing token_type_ids (it is usually auto-generated); this
+        # solves issue #5664
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
+class Data2VecTextSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+        self.is_decoder = config.is_decoder
+        self.layer_idx = layer_idx
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states:
torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
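+        # shapes: (batch, heads, q_len, head_dim) @ (batch, heads, head_dim, k_len) -> (batch, heads, q_len, k_len)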
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
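+        # Editorial aside on the relative_key / relative_key_query branch above, not part of
+        # the upstream module: `distance` is a (query_len, key_len) grid of signed offsets,
+        # shifted by `max_position_embeddings - 1` so it indexes the non-negative range of
+        # `distance_embedding`. A minimal sketch, assuming 3 positions and
+        # max_position_embeddings = 4 (illustrative numbers only):
+        #
+        #   >>> import torch
+        #   >>> l = torch.arange(3).view(-1, 1)  # query positions
+        #   >>> r = torch.arange(3).view(1, -1)  # key positions
+        #   >>> (l - r) + 4 - 1
+        #   tensor([[3, 2, 1],
+        #           [4, 3, 2],
+        #           [5, 4, 3]])
+        #
+        # The dropout call that follows belongs to the original comment directly above.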
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class Data2VecTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = { + "eager": Data2VecTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT +class Data2VecTextAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = Data2VecTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class Data2VecTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = 
nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class Data2VecTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText +class Data2VecTextLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Data2VecTextAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = Data2VecTextAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) + self.intermediate = Data2VecTextIntermediate(config) + self.output = Data2VecTextOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + 
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText +class Data2VecTextEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class Data2VecTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class Data2VecTextPreTrainedModel(PreTrainedModel): + config: Data2VecTextConfig + base_model_prefix = "data2vec_text" + supports_gradient_checkpointing = True + _no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(1.0) + + +@auto_docstring +class Data2VecTextModel(Data2VecTextPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. 
+ + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762 + + """ + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = Data2VecTextForTextEmbeddings(config) + self.encoder = Data2VecTextEncoder(config) + + self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if 
hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Data2VecText Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `Data2VecTextLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False) + self.lm_head = Data2VecTextLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, Data2VecTextForCausalLM, Data2VecTextConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base") + >>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base") + >>> config.is_decoder = True + >>> model = Data2VecTextForCausalLM.from_pretrained("facebook/data2vec-text-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.data2vec_text( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else 
output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `Data2VecTextForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False) + self.lm_head = Data2VecTextLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.data2vec_text( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(prediction_scores.device) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Data2VecText +class Data2VecTextLMHead(nn.Module): + """Data2VecText Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@auto_docstring( + custom_intro=""" + Data2VecText Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """ +) +class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False) + self.classifier = Data2VecTextClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.data2vec_text( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.data2vec_text = Data2VecTextModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + 
labels: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.data2vec_text( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(reshaped_logits.device) + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.data2vec_text( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Data2VecText +class Data2VecTextClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring +class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.data2vec_text( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on 
multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "Data2VecTextForCausalLM", + "Data2VecTextForMaskedLM", + "Data2VecTextForMultipleChoice", + "Data2VecTextForQuestionAnswering", + "Data2VecTextForSequenceClassification", + "Data2VecTextForTokenClassification", + "Data2VecTextModel", + "Data2VecTextPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..f214f8eb6a0bcefe5eb6a7072864dfbbe4b0b22a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py @@ -0,0 +1,1348 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Data2VecVision model.""" + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + SemanticSegmenterOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging, torch_int +from .configuration_data2vec_vision import Data2VecVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`Data2VecVisionModel`]. + """ +) +# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision +class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + """ + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision +class Data2VecVisionDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision +class Data2VecVisionEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
+ + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = Data2VecVisionPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + ) -> torch.Tensor: + if self.position_embeddings is not None and interpolate_pos_encoding is not None: + warnings.warn( + "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always " + "interpolated to the input image size. The argument will be removed in transformers v4.51.0." 
+ ) + + _, _, height, width = pixel_values.shape + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings, (patch_height, patch_width) + + +# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision +class Data2VecVisionPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + embeddings = self.projection(pixel_values) + patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, (patch_height, patch_width) + + +# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision +class Data2VecVisionSelfAttention(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.has_relative_position_bias = bool(window_size) + if self.has_relative_position_bias: + self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Add relative position bias if present. + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attention_scores = attention_scores + self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
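+        # Editorial aside, not part of the upstream module: the relative position bias added
+        # above typically carries a singleton leading dimension, e.g.
+        # (1, num_heads, seq_len, seq_len), so adding it to `attention_scores` of shape
+        # (batch, num_heads, seq_len, seq_len) relies on plain broadcasting over the batch
+        # axis. A minimal sketch with illustrative sizes (12 heads, 197 tokens for a
+        # 224x224 image with 16x16 patches):
+        #
+        #   >>> import torch
+        #   >>> scores = torch.zeros(2, 12, 197, 197)
+        #   >>> bias = torch.randn(1, 12, 197, 197)
+        #   >>> (scores + bias).shape
+        #   torch.Size([2, 12, 197, 197])
+        #
+        # The dropout call that follows belongs to the original comment directly above.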
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision +class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + attn_bias = None + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attn_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. 
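+        # Editorial note, not part of the upstream module: both bias terms are merged into
+        # `attn_bias` below and passed to torch.nn.functional.scaled_dot_product_attention as
+        # an additive floating-point `attn_mask`. The explicit `scale=1 / sqrt(head_dim)`
+        # keyword reproduces the eager path's scaling, and `dropout_p` is only applied while
+        # the module is in training mode.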
+ if relative_position_bias is not None: + if attn_bias is None: + attn_bias = relative_position_bias + else: + attn_bias += relative_position_bias + + scaling = 1 / math.sqrt(self.attention_head_size) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_bias, + dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=scaling, + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer, None + + +# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision +class Data2VecVisionSelfOutput(nn.Module): + """ + The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +DATA2VEC_VISION_SELF_ATTENTION_CLASSES = { + "eager": Data2VecVisionSelfAttention, + "sdpa": Data2VecVisionSdpaSelfAttention, +} + + +# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION +class Data2VecVisionAttention(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, window_size=window_size + ) + self.output = Data2VecVisionSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return 
outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision +class Data2VecVisionIntermediate(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision +class Data2VecVisionOutput(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision +class Data2VecVisionLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__( + self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0 + ) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Data2VecVisionAttention(config, window_size=window_size) + self.intermediate = Data2VecVisionIntermediate(config) + self.output = Data2VecVisionOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + if init_values > 0: + self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + else: + self.lambda_1, self.lambda_2 = None, None + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int, int]] = None, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Data2VecVision, layernorm is also applied after self-attention + layer_output 
= self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision +class Data2VecVisionRelativePositionBias(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None: + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, config.num_attention_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + @compile_compatible_method_lru_cache(maxsize=10) + def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor: + """ + This method creates the relative position index, modified to support arbitrary window sizes, + as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460). + """ + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = window_size[0] * window_size[1] + grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij") + coords = torch.stack(grid) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor: + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. 
+ """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = nn.functional.interpolate( + old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear" + ) + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]] + ) + + relative_position_index = self.generate_relative_position_index(window_size) + relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)] + + # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads + relative_position_bias = relative_position_bias.view( + window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1 + ) + # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias.unsqueeze(0) + + +# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision +class Data2VecVisionEncoder(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + self.has_relative_position_bias = config.use_shared_relative_position_bias + if self.has_relative_position_bias: + self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")] + self.layer = nn.ModuleList( + [ + Data2VecVisionLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + ) + for i in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, + resolution: Optional[tuple[int, int]] = None, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.has_relative_position_bias: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + relative_position_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + else: + relative_position_bias = 
None + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision +class Data2VecVisionPreTrainedModel(PreTrainedModel): + config: Data2VecVisionConfig + base_model_prefix = "data2vec_vision" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Data2VecVisionLayer"] + _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Data2VecVisionEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, Data2VecVisionRelativePositionBias): + module.relative_position_bias_table.data.zero_() + elif isinstance(module, Data2VecVisionLayer): + if module.lambda_1 is not None: + module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) + + +@auto_docstring +# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False +class Data2VecVisionModel(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None: + r""" + add_pooling_layer (bool, *optional*, defaults to `False`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = Data2VecVisionEmbeddings(config) + self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape) + + self.layernorm = ( + nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def 
_prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + resolution = pixel_values.shape[2:] + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + resolution=resolution, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return Data2VecVisionModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision +class Data2VecVisionPooler(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.layernorm = ( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(patch_tokens.mean(1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + +@auto_docstring( + custom_intro=""" + Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of + the final hidden states of the patch tokens) e.g. for ImageNet. 
+ """ +) +# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision +class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision +class Data2VecVisionConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple[int, int]], + padding: Union[int, tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision +class Data2VecVisionPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + Data2VecVisionConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision +class Data2VecVisionPyramidPoolingModule(nn.Module): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = Data2VecVisionPyramidPoolingBlock( + pool_scale=pool_scale, in_channels=in_channels, channels=channels + ) + self.blocks.append(block) + self.add_module(str(i), block) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + ppm_outs = [] + for ppm in self.blocks: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision +class Data2VecVisionUperHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://huggingface.co/papers/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. 
[768, 768, 768, 768] + self.channels = config.hidden_size + self.align_corners = False + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + # PSP Module + self.psp_modules = Data2VecVisionPyramidPoolingModule( + self.pool_scales, + self.in_channels[-1], + self.channels, + align_corners=self.align_corners, + ) + self.bottleneck = Data2VecVisionConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1) + fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = Data2VecVisionConvModule( + len(self.in_channels) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # build laterals + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate( + laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners + ) + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = nn.functional.interpolate( + fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision +class Data2VecVisionFCNHead(nn.Module): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented of + [FCNNet](https://huggingface.co/papers/1411.4038>). + + Args: + config (Data2VecVisionConfig): Configuration. + in_channels + kernel_size (int): The kernel size for convs in the head. Default: 3. + dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
+ """ + + def __init__( + self, + config: Data2VecVisionConfig, + in_index: int = 2, + kernel_size: int = 3, + dilation: Union[int, tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + Data2VecVisionConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + Data2VecVisionConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = Data2VecVisionConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@auto_docstring +# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision +class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False) + + # FPNs + if len(self.config.out_indices) != 4: + raise ValueError( + "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, " + "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of " + "a base-sized architecture." 
+ ) + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # Semantic segmentation head(s) + self.decode_head = Data2VecVisionUperHead(config) + self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + main_loss = loss_fct(upsampled_logits, labels) + loss = main_loss + if auxiliary_logits is not None: + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss += self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base") + >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] + batch_size = pixel_values.shape[0] + patch_resolution = self.config.image_size // self.config.patch_size + features = [ + x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features + ] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + logits = self.decode_head(features) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Data2VecVisionForImageClassification", + "Data2VecVisionForSemanticSegmentation", + "Data2VecVisionModel", + "Data2VecVisionPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa0fe1f811ee1748956c2481d28b0ad1ef516e0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -0,0 +1,1723 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 Data2Vec Vision model.""" + +from __future__ import annotations + +import collections.abc +import math +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFSemanticSegmenterOutput, + TFSequenceClassifierOutput, +) +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_data2vec_vision import Data2VecVisionConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Data2VecVisionConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k" +_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote" + + +@dataclass +class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling): + """ + Class for outputs of [`TFData2VecVisionModel`]. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +class TFData2VecVisionDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFData2VecVisionEmbeddings(keras.layers.Layer): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + + """ + + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings") + self.num_patches = self.patch_embeddings.num_patches + self.config = config + + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + + def build(self, input_shape=None): + self.cls_token = self.add_weight( + shape=(1, 1, self.config.hidden_size), + initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), + trainable=True, + name="cls_token", + ) + if self.config.use_mask_token: + self.mask_token = self.add_weight( + shape=(1, 1, self.config.hidden_size), + initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), + trainable=True, + name="mask_token", + ) + else: + self.mask_token = None + + if self.config.use_absolute_position_embeddings: + self.position_embeddings = self.add_weight( + shape=(1, self.num_patches + 1, self.config.hidden_size), + initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), + trainable=True, + name="position_embeddings", + ) + else: + self.position_embeddings = None + + if self.built: + return + self.built = True + if getattr(self, "patch_embeddings", None) is not None: + with tf.name_scope(self.patch_embeddings.name): + self.patch_embeddings.build(None) + + def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor: + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, projection_dim = shape_list(embeddings) + + cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1)) + + if bool_masked_pos is not None: + mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim)) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos[..., None] + w = tf.cast(w, mask_tokens.dtype) + # since TF doesn't support eager tensor assignment + embeddings = embeddings * (1 - w) + mask_tokens * w + + embeddings = tf.concat([cls_tokens, embeddings], axis=1) + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +class TFData2VecVisionPatchEmbeddings(keras.layers.Layer): + """ + Image to Patch Embedding. 
+ """ + + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.patch_shape = patch_shape + self.num_channels = num_channels + + self.projection = keras.layers.Conv2D( + filters=hidden_size, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + data_format="channels_last", + kernel_initializer="glorot_uniform", # following torch.nn.Linear + bias_initializer="zeros", + name="projection", + ) + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + if tf.executing_eagerly(): + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the" + " configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + projection = self.projection(pixel_values) + + # Change the 2D spatial dimensions to a single temporal dimension. 
+ # shape = (batch_size, num_patches, out_channels=embed_dim) + num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) + + return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFData2VecVisionSelfAttention(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="key", + use_bias=False, + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + if window_size: + self.relative_position_bias = TFData2VecVisionRelativePositionBias( + config, window_size=window_size, name="relative_position_bias" + ) + else: + self.relative_position_bias = None + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, + training: bool = False, + ) -> tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + mixed_key_layer = self.key(inputs=hidden_states) + mixed_value_layer = self.value(inputs=hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + attention_scores = attention_scores / self.sqrt_att_head_size + + # Add relative position bias if present. 
+ if self.relative_position_bias is not None: + # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras + # might complain about `Layer.call()` not being invoked properly. In this case this input + # i.e., 0.0 is not going to be used in any calculations so we're safe. + attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...] + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "relative_position_bias", None) is not None: + with tf.name_scope(self.relative_position_bias.name): + self.relative_position_bias.build(None) + + +class TFData2VecVisionSelfOutput(keras.layers.Layer): + """ + The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due + to the layernorm applied before each block. 
+ """ + + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFData2VecVisionAttention(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): + super().__init__(**kwargs) + + self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention") + self.dense_output = TFData2VecVisionSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.attention( + hidden_states=input_tensor, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision +class TFData2VecVisionIntermediate(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFData2VecVisionOutput(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + 
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +class TFData2VecVisionLayer(keras.layers.Layer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__( + self, config: Data2VecVisionConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0, **kwargs + ): + super().__init__(**kwargs) + self.config = config + + self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention") + self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate") + self.data2vec_output = TFData2VecVisionOutput(config, name="output") + + self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + # Using `layers.Activation` instead of `tf.identity` to better control `training` + # behaviour. + self.drop_path = ( + TFData2VecVisionDropPath(drop_path_rate, name="drop_path") + if drop_path_rate > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + self.init_values = config.layer_scale_init_value + + def build(self, input_shape: tf.TensorShape = None): + if self.init_values > 0: + self.lambda_1 = self.add_weight( + shape=(self.config.hidden_size), + initializer="ones", + trainable=True, + name="lambda_1", + ) + self.lambda_2 = self.add_weight( + shape=(self.config.hidden_size), + initializer="ones", + trainable=True, + name="lambda_2", + ) + self.lambda_1.assign(self.init_values * tf.ones(self.config.hidden_size)) + self.lambda_2.assign(self.init_values * tf.ones(self.config.hidden_size)) + else: + self.lambda_1, self.lambda_2 = None, None + + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "data2vec_output", None) is not None: + with tf.name_scope(self.data2vec_output.name): + self.data2vec_output.build(None) + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.config.hidden_size]) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.config.hidden_size]) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_attention_outputs = self.attention( + # in Data2VecVision, layernorm is applied before self-attention + input_tensor=self.layernorm_before(inputs=hidden_states), + 
head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + training=training, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Data2VecVision, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.data2vec_output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Taken and modified from here: +# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28 +class TFData2VecVisionRelativePositionBias(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + self.window_size = window_size + # +3 for cls_token_pos_len + # window_size can be something like (14, 14) + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + + self.relative_position_index = self.get_position_index() + + def build(self, input_shape): + self.relative_position_bias_table = self.add_weight( + shape=(self.num_relative_distance, self.config.num_attention_heads), + initializer="zeros", + trainable=True, + name="relative_position_bias_table", + ) # [2*Wh-1 * 2*Ww-1, nH] + # cls to token & token 2 cls & cls to cls + + super().build(input_shape) + + def get_position_index(self): + # get pair-wise relative position index for each token inside the window + xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1])) + coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww] + coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww] + + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Wh*Ww, Wh*Ww] + relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0]) # [Wh*Ww, Wh*Ww, 2] + + xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + yy = relative_coords[:, :, 1] + self.window_size[1] - 1 + relative_coords = tf.stack([xx, yy], axis=-1) + + relative_position_index = tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww] + + top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * ( + self.num_relative_distance - 3 + ) + left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * ( + self.num_relative_distance - 2 + ) + corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1) + + left_corner = tf.concat([corner, left], axis=0) + relative_position_index = tf.concat([top, relative_position_index], axis=0) + relative_position_index = tf.concat([left_corner, relative_position_index], axis=1) # [Wh*Ww + 1, Wh*Ww + 1] + return relative_position_index + + def call(self, inputs=None) -> tf.Tensor: + relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0) + return tf.transpose(relative_position_bias, [2, 0, 1]) + + +class 
TFData2VecVisionEncoder(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + if config.use_shared_relative_position_bias: + self.relative_position_bias = TFData2VecVisionRelativePositionBias( + config, window_size=window_size, name="relative_position_bias" + ) + else: + self.relative_position_bias = None + + # stochastic depth decay rule + dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers)) + self.layer = [ + TFData2VecVisionLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + name=f"layer_._{i}", + ) + for i in range(config.num_hidden_layers) + ] + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> tuple | TFBaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras + # might complain about `Layer.call()` not being invoked properly. In this case this input + # i.e., 0.0 is not going to be used in any calculations so we're safe. + relative_position_bias = ( + self.relative_position_bias(0.0) if self.relative_position_bias is not None else None + ) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "relative_position_bias", None) is not None: + with tf.name_scope(self.relative_position_bias.name): + self.relative_position_bias.build(None) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFData2VecVisionMainLayer(keras.layers.Layer): + config_class = Data2VecVisionConfig + + def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.add_pooling_layer = add_pooling_layer + + self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings") + self.encoder = TFData2VecVisionEncoder( + config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder" + ) + self.layernorm = ( + tf.identity + if config.use_mean_pooling + else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + ) + + # We are setting the `data_format` like so because from here on we will revert to the + # NCHW output format + self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> 
keras.layers.Layer: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple | TFData2VecVisionModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return TFData2VecVisionModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + if hasattr(self.layernorm, "name"): + with tf.name_scope(self.layernorm.name): + self.layernorm.build((None, self.config.hidden_size)) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +class TFData2VecVisionPooler(keras.layers.Layer): + def __init__(self, config: Data2VecVisionConfig, **kwargs): + super().__init__(**kwargs) + self.layernorm = ( + keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + if config.use_mean_pooling + else None + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, 
axis=1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layernorm", None) is not None: + if hasattr(self.layernorm, "name"): + with tf.name_scope(self.layernorm.name): + self.layernorm.build((None, self.config.hidden_size)) + + +class TFData2VecVisionPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Data2VecVisionConfig + base_model_prefix = "data2vec_vision" + main_input_name = "pixel_values" + _keys_to_ignore_on_load_unexpected = [r"relative_position_index"] + + +DATA2VEC_VISION_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.). + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DATA2VEC_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. 
See + [`BeitImageProcessor.__call__`] for details. + + head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used + in eager mode, in graph mode the value will always be set to True. + + training (`bool`, *optional*, defaults to `False``): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.", + DATA2VEC_VISION_START_DOCSTRING, +) +class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.config = config + + self.data2vec_vision = TFData2VecVisionMainLayer( + config, add_pooling_layer=add_pooling_layer, name="data2vec_vision" + ) + + def get_input_embeddings(self): + return self.data2vec_vision.get_input_embeddings() + + @unpack_inputs + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFData2VecVisionModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: TFModelInputType | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple | TFData2VecVisionModelOutputWithPooling: + r""" + bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + outputs = self.data2vec_vision( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "data2vec_vision", None) is not None: + with tf.name_scope(self.data2vec_vision.name): + self.data2vec_vision.build(None) + + +@add_start_docstrings( + """ + Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of + the final hidden states of the patch tokens) e.g. for ImageNet. 
+ """, + DATA2VEC_VISION_START_DOCSTRING, +) +class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision") + + # Classifier head + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: TFModelInputType | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.data2vec_vision( + pixel_values=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + logits = self.classifier(pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "data2vec_vision", None) is not None: + with tf.name_scope(self.data2vec_vision.name): + self.data2vec_vision.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +class TFData2VecVisionConvModule(keras.layers.Layer): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
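+
+    The `call` pass applies the convolution, then batch normalization, then a ReLU
+    activation, in that order, on channels-last (NHWC) inputs.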
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + padding: str = "valid", + bias: bool = False, + dilation: int | tuple[int, int] = 1, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.conv = keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + padding=padding, + use_bias=bias, + dilation_rate=dilation, + name="conv", + ) + self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5) + self.activation = tf.nn.relu + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, input: tf.Tensor) -> tf.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, None, self.in_channels]) + if getattr(self, "bn", None) is not None: + with tf.name_scope(self.bn.name): + self.bn.build((None, None, None, self.out_channels)) + + +class TFAdaptiveAvgPool2D(keras.layers.Layer): + def __init__(self, output_dims: tuple[int, int], input_ordering: str = "NHWC", **kwargs): + super().__init__(**kwargs) + self.output_dims = output_dims + self.input_ordering = input_ordering + if input_ordering not in ("NCHW", "NHWC"): + raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!") + self.h_axis = input_ordering.index("H") + self.w_axis = input_ordering.index("W") + + def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool): + # Figure out which axis we're pooling on + if h_pooling: + axis = self.h_axis + output_dim = self.output_dims[0] + else: + axis = self.w_axis + output_dim = self.output_dims[1] + input_dim = inputs.shape[axis] + + # Figure out the potential pooling windows + # This is the key idea - the torch op always uses only two + # consecutive pooling window sizes, like 3 and 4. Therefore, + # if we pool with both possible sizes, we simply need to gather + # the 'correct' pool at each position to reimplement the torch op. 
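+        # Illustrative example (not in the original comment): with input_dim=5 and
+        # output_dim=3, the torch op averages the index ranges [0, 2), [1, 4) and [3, 5),
+        # i.e. window sizes 2, 3 and 2. Here small_window = ceil(5 / 3) = 2 and
+        # big_window = 3, so pooling once with each size and then picking the right pool
+        # per output position reproduces those averages exactly.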
+ small_window = math.ceil(input_dim / output_dim) + big_window = small_window + 1 + if h_pooling: + output_dim = self.output_dims[0] + small_window_shape = (small_window, 1) + big_window_shape = (big_window, 1) + else: + output_dim = self.output_dims[1] + small_window_shape = (1, small_window) + big_window_shape = (1, big_window) + + # For resizes to 1, or integer resizes, we can take quick shortcuts + if output_dim == input_dim: + return inputs + elif output_dim == 1: + return tf.reduce_mean(inputs, axis=axis, keepdims=True) + elif input_dim % output_dim == 0: + return tf.nn.avg_pool2d( + inputs, + ksize=small_window_shape, + strides=small_window_shape, + padding="VALID", + data_format=self.input_ordering, + ) + # When upscaling by an integer factor we can also take a quick shortcut + elif output_dim > input_dim and output_dim % input_dim == 0: + return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis) + + # For non-integer resizes, we pool with both possible window sizes and concatenate them + if output_dim < input_dim: + small_pool = tf.nn.avg_pool2d( + inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering + ) + big_pool = tf.nn.avg_pool2d( + inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering + ) + both_pool = tf.concat([small_pool, big_pool], axis=axis) + else: + # When we're actually upscaling instead, then we build the pools a bit differently + small_pool = inputs + big_pool = tf.nn.avg_pool2d( + inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering + ) + both_pool = tf.concat([small_pool, big_pool], axis=axis) + + # We compute vectors of the start and end positions for each pooling window + # Each (start, end) pair here corresponds to a single output position + window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim) + window_starts = tf.cast(window_starts, tf.int64) + window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim) + window_ends = tf.cast(window_ends, tf.int64) + + # pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position + # has a big receptive field and 0 indicates that that output position has a small receptive field + pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool) + + # Since we concatenated the small and big pools, we need to do a bit of + # pointer arithmetic to get the indices of the big pools + small_indices = window_starts + big_indices = window_starts + small_pool.shape[axis] + + # Finally, we use the pool_selector to generate a list of indices, one per output position + gather_indices = tf.where(pool_selector, big_indices, small_indices) + + # Gathering from those indices yields the final, correct pooling + return tf.gather(both_pool, gather_indices, axis=axis) + + def call(self, inputs: tf.Tensor): + if self.input_ordering == "NHWC": + input_shape = inputs.shape[1:3] + else: + input_shape = inputs.shape[2:] + + # We break the task down into each possible case + # Firstly, if we're resizing down to 1, it's just tf.reduce_mean + if self.output_dims[0] == self.output_dims[1] == 1: + if self.input_ordering == "NHWC": + reduce_dims = [1, 2] + else: + reduce_dims = [2, 3] + return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True) + # Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut + elif input_shape[0] % self.output_dims[0] == 0 and 
input_shape[1] % self.output_dims[1] == 0: + h_resize = int(input_shape[0] // self.output_dims[0]) + w_resize = int(input_shape[1] // self.output_dims[1]) + return tf.nn.avg_pool2d( + inputs, + ksize=(h_resize, w_resize), + strides=(h_resize, w_resize), + padding="VALID", + data_format=self.input_ordering, + ) + else: + # Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut + # for dimensions where an integer resize is possible. It can also handle upscaling. + h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True) + return self.pseudo_1d_pool(h_pooled, h_pooling=False) + + +class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + channels (int): Channels after modules, before conv_seg. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales: tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None: + super().__init__(**kwargs) + self.pool_scales = pool_scales + self.in_channels = in_channels + self.out_channels = out_channels + + self.layer_list = [] + for idx, pool_scale in enumerate(pool_scales): + pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale) + self.layer_list.append( + [ + TFAdaptiveAvgPool2D(output_dims=pool_scale), + TFData2VecVisionConvModule( + in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1" + ), + ] + ) + + def call(self, x: tf.Tensor) -> list[tf.Tensor]: + ppm_outs = [] + inputs = x + + for ppm in self.layer_list: + for layer_module in ppm: + ppm_out = layer_module(x) + x = ppm_out + + upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear") + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + def build(self, input_shape=None): + for layer in self.layer_list: + for layer_module in layer: + with tf.name_scope(layer_module.name): + layer_module.build(None) + + +class TFData2VecVisionUperHead(keras.layers.Layer): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://huggingface.co/papers/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None: + super().__init__(**kwargs) + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. 
[768, 768, 768, 768] + self.channels = config.hidden_size + self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier") + + # PSP Module + self.psp_modules = TFData2VecVisionPyramidPoolingModule( + self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules" + ) + self.bottleneck = TFData2VecVisionConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + kernel_size=3, + padding="same", + name="bottleneck", + ) + # FPN Module + self.lateral_convs = [] + self.fpn_convs = [] + for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer + l_conv = TFData2VecVisionConvModule( + in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}" + ) + fpn_conv = TFData2VecVisionConvModule( + in_channels=self.channels, + out_channels=self.channels, + kernel_size=3, + padding="same", + name=f"fpn_convs.{idx}", + ) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = TFData2VecVisionConvModule( + in_channels=len(self.in_channels) * self.channels, + out_channels=self.channels, + kernel_size=3, + padding="same", + name="fpn_bottleneck", + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = tf.concat(psp_outs, axis=-1) + output = self.bottleneck(psp_outs) + + return output + + def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: + # build laterals + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = shape_list(laterals[i - 1])[1:-1] + laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear") + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear") + fpn_outs = tf.concat(fpn_outs, axis=-1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, None, self.channels]) + if getattr(self, "psp_modules", None) is not None: + with tf.name_scope(self.psp_modules.name): + self.psp_modules.build(None) + if getattr(self, "bottleneck", None) is not None: + with tf.name_scope(self.bottleneck.name): + self.bottleneck.build(None) + if getattr(self, "fpn_bottleneck", None) is not None: + with tf.name_scope(self.fpn_bottleneck.name): + self.fpn_bottleneck.build(None) + for layer in self.lateral_convs: + with tf.name_scope(layer.name): + layer.build(None) + for layer in self.fpn_convs: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFData2VecVisionFCNHead(keras.layers.Layer): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented from + [FCNNet](https://huggingface.co/papers/1411.4038). + + Args: + config (Data2VecVisionConfig): Configuration. + kernel_size (int): The kernel size for convs in the head. Default: 3. 
+ dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + config: Data2VecVisionConfig, + in_index: int = 2, + kernel_size: int = 3, + dilation: int | tuple[int, int] = 1, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + convs = [] + convs.append( + TFData2VecVisionConvModule( + in_channels=self.in_channels, + out_channels=self.channels, + kernel_size=kernel_size, + padding="same", + dilation=dilation, + name="convs.0", + ) + ) + for i in range(self.num_convs - 1): + convs.append( + TFData2VecVisionConvModule( + in_channels=self.channels, + out_channels=self.channels, + kernel_size=kernel_size, + padding="same", + dilation=dilation, + name=f"conv_module_{i + 2}", + ) + ) + if self.num_convs == 0: + self.convs = [tf.identity] + else: + self.convs = convs + if self.concat_input: + self.conv_cat = TFData2VecVisionConvModule( + self.in_channels + self.channels, + out_channels=self.channels, + kernel_size=kernel_size, + padding="same", + name="conv_cat", + ) + + self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier") + + def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = hidden_states + for layer_module in self.convs: + output = layer_module(output) + if self.concat_input: + output = self.conv_cat(tf.concat([hidden_states, output], axis=-1)) + output = self.classifier(output) + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, None, self.channels]) + if getattr(self, "conv_cat", None) is not None: + with tf.name_scope(self.conv_cat.name): + self.conv_cat.build(None) + + +@add_start_docstrings( + """ + Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. 
+ """, + DATA2VEC_VISION_START_DOCSTRING, +) +class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision") + + # FPNs + self.fpn1 = [ + keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"), + keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5), + keras.layers.Activation("gelu"), + keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"), + ] + self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")] + + self.fpn3 = tf.identity + self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2) + + # Semantic segmentation head(s) + self.decode_head = TFData2VecVisionUperHead(config, name="decode_head") + self.auxiliary_head = ( + TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None + ) + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + if len(shape_list(labels)) > 3: + label_interp_shape = shape_list(labels)[1:-1] + else: + label_interp_shape = shape_list(labels)[-2:] + + upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") + if auxiliary_logits is not None: + upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear") + # compute weighted loss + loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") + + # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics. + # Utility to mask the index to ignore during computing the loss. + def masked_loss(real, pred): + mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index)) + loss_ = loss_fct(real, pred) + mask = tf.cast(mask, dtype=loss_.dtype) + loss_ *= mask + reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask) + return tf.reshape(reduced_masked_loss, (1,)) + + main_loss = masked_loss(labels, upsampled_logits) + auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @unpack_inputs + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | TFSemanticSegmenterOutput: + r""" + labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base") + >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] + patch_resolution = self.config.image_size // self.config.patch_size + + def reshape_features(x): + # We do it this way so TF can always infer the non-batch dims at compile time + x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size)) + return x + + features = [reshape_features(x[:, 1:, :]) for x in features] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for module in ops[0]: + features[0] = module(features[0]) + features[1] = ops[1][0](features[1]) + for i in range(len(features[2:])): + features[i + 2] = ops[i + 2](features[i + 2]) + + logits = self.decode_head(features) + # Transpose the logits to maintain consistency in the output formats. 
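+        # (batch_size, height, width, num_labels) -> (batch_size, num_labels, height, width),
+        # i.e. the channels-first layout documented in the example above.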
+ transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2]) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSemanticSegmenterOutput( + loss=loss, + logits=transposed_logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "data2vec_vision", None) is not None: + with tf.name_scope(self.data2vec_vision.name): + self.data2vec_vision.build(None) + if getattr(self, "decode_head", None) is not None: + with tf.name_scope(self.decode_head.name): + self.decode_head.build(None) + if getattr(self, "auxiliary_head", None) is not None: + with tf.name_scope(self.auxiliary_head.name): + self.auxiliary_head.build(None) + if getattr(self, "fpn1", None) is not None: + with tf.name_scope(self.fpn1[0].name): + self.fpn1[0].build([None, None, None, self.config.hidden_size]) + with tf.name_scope(self.fpn1[1].name): + self.fpn1[1].build((None, None, None, self.config.hidden_size)) + with tf.name_scope(self.fpn1[3].name): + self.fpn1[3].build([None, None, None, self.config.hidden_size]) + if getattr(self, "fpn2", None) is not None: + with tf.name_scope(self.fpn2[0].name): + self.fpn2[0].build([None, None, None, self.config.hidden_size]) + + +__all__ = [ + "TFData2VecVisionForImageClassification", + "TFData2VecVisionForSemanticSegmentation", + "TFData2VecVisionModel", + "TFData2VecVisionPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/data2vec/modular_data2vec_audio.py b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modular_data2vec_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..91cb04730e4aec01edff0f35701eb12c3ee5af3d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/data2vec/modular_data2vec_audio.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Data2VecText model.""" + +import math + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import Wav2Vec2BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Adapter, + Wav2Vec2Encoder, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, + Wav2Vec2SamePadLayer, +) +from .configuration_data2vec_audio import Data2VecAudioConfig + + +class Data2VecAudioConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer): + pass + + +class Data2VecAudioPositionalConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.conv_pos_kernel_size, + padding=config.conv_pos_kernel_size // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size) + self.activation = ACT2FN[config.feat_extract_activation] + # no learnable parameters + self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList( + [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)] + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + for layer in self.layers: + hidden_states = layer(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder): + def __init__(self, config): + nn.Module.__init__(self) + self.conv_layers = nn.ModuleList( + [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + ) + self.gradient_checkpointing = False + self._requires_grad = True + + +class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class Data2VecAudioEncoder(Wav2Vec2Encoder): + pass + + +class Data2VecAudioAdapter(Wav2Vec2Adapter): + pass + + +class Data2VecAudioPreTrainedModel(PreTrainedModel, 
Wav2Vec2PreTrainedModel): + config: Data2VecAudioConfig + base_model_prefix = "data2vec_audio" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Data2VecAudioFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, Data2VecAudioPositionalConvLayer): + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_adapters(self): + raise AttributeError("Not needed for Data2VecAudio") + + def init_adapter_layers(self): + raise AttributeError("Not needed for Data2VecAudio") + + def load_adapter(self): + raise AttributeError("Not needed for Data2VecAudio") + + +Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput + + +class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: Data2VecAudioConfig): + Data2VecAudioPreTrainedModel.__init__(self, config) + self.config = config + self.feature_extractor = Data2VecAudioFeatureEncoder(config) + self.feature_projection = Data2VecAudioFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = Data2VecAudioEncoder(config) + + self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Data2VecAudio") + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.feature_extractor._freeze_parameters() + + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC): + def __init__(self, config): + Data2VecAudioPreTrainedModel.__init__(self, config) + + self.data2vec_audio = Data2VecAudioModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." 
+ ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_base_model(self): + raise AttributeError("Not needed for Data2VecAudio") + + def tie_weights(self): + raise AttributeError("Not needed for Data2VecAudio") + + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification): + pass + + +class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification): + pass + + +class Data2VecAudioForXVector(Wav2Vec2ForXVector): + pass + + +__all__ = [ + "Data2VecAudioForAudioFrameClassification", + "Data2VecAudioForCTC", + "Data2VecAudioForSequenceClassification", + "Data2VecAudioForXVector", + "Data2VecAudioModel", + "Data2VecAudioPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dbrx/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dbrx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cce0f34c778d21a81d03940e7a7951707d898c86 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dbrx/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dbrx import * + from .modeling_dbrx import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/dbrx/configuration_dbrx.py b/venv/lib/python3.13/site-packages/transformers/models/dbrx/configuration_dbrx.py new file mode 100644 index 0000000000000000000000000000000000000000..17b6b2a368cc5f13783ae7b333dc7af78aee9d05 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dbrx/configuration_dbrx.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""DBRX model configuration""" + +from typing import Any, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*): + If set, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (`int`, *optional*, defaults to 1): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (`float`, *optional*, defaults to 10000.0): The base frequency for rope. + """ + + base_config_key = "attn_config" + + def __init__( + self, + attn_pdrop: float = 0.0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of the activation function along with + any additional keyword arguments. If `None`, then set to `{"name": "silu"}`. + ffn_hidden_size (`int`, *optional*, defaults to 3584): The hidden size of the feedforward network. + moe_num_experts (`int`, *optional*, defaults to 4): The number of experts in the mixture of experts layer. + moe_top_k (`int`, *optional*, defaults to 1): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture of experts layer. + moe_loss_weight (`float`, *optional*, defaults to 0.01): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights. 
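+
+    Example (an illustrative construction; the values below are arbitrary, not defaults):
+
+    ```python
+    >>> from transformers.models.dbrx.configuration_dbrx import DbrxFFNConfig
+
+    >>> # A small mixture-of-experts FFN: 8 experts, each token routed to the top 2
+    >>> ffn_config = DbrxFFNConfig(ffn_hidden_size=1024, moe_num_experts=8, moe_top_k=2)
+    ```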
+ """ + + base_config_key = "ffn_config" + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1.0, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + + for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + +class DbrxConfig(PretrainedConfig): + r""" + + This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a different configuration to that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + d_model (`int`, *optional*, defaults to 2048): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 2048): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details. 
+ + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128) + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig} + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.num_key_value_heads = self.attn_config.kv_n_heads + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError("tie_word_embeddings is not supported for DBRX models.") + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["DbrxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dbrx/modeling_dbrx.py b/venv/lib/python3.13/site-packages/transformers/models/dbrx/modeling_dbrx.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3a423213cba3ef11b95a36a3efa5aa01da69c2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dbrx/modeling_dbrx.py @@ -0,0 +1,1248 @@ +# coding=utf-8 +# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
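+# Configuration sketch (only public behavior of `DbrxConfig` from `configuration_dbrx.py`
+# above is assumed): the sub-configs may be passed as plain dicts, and
+# `attn_config.kv_n_heads` is mirrored onto `num_key_value_heads`, e.g.
+#
+#     config = DbrxConfig(attn_config={"kv_n_heads": 2}, ffn_config={"moe_top_k": 2})
+#     assert config.num_key_value_heads == 2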
+"""PyTorch DBRX model.""" + +import math +from typing import Any, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_dbrx import DbrxConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class DbrxRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
+            Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def load_balancing_loss_func(
+    gate_probabilities: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    attention_mask: Optional[torch.Tensor],
+) -> torch.Tensor:
+    r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
+
+    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_probabilities (Union[`torch.Tensor`, tuple[torch.Tensor]]):
+            Routing probabilities from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts (`int`):
+            Number of experts.
+        top_k (`int`):
+            The number of experts each token is routed to.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
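+
+    Example (a shape-only sketch with random routing probabilities; two layers, eight
+    tokens, and four experts are arbitrary illustrative choices):
+
+    ```python
+    >>> import torch
+
+    >>> gate_probs = tuple(torch.rand(8, 4).softmax(dim=-1) for _ in range(2))
+    >>> load_balancing_loss_func(gate_probs, num_experts=4, top_k=2, attention_mask=None).shape
+    torch.Size([])
+    ```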
+    """
+    if gate_probabilities is None or not isinstance(gate_probabilities, tuple):
+        return torch.tensor(0.0)
+
+    if isinstance(gate_probabilities, tuple):
+        compute_device = gate_probabilities[0].device
+        routing_weights = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_probabilities], dim=0)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+class DbrxAttention(nn.Module):
+    """Multi-head self attention."""
+
+    def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.d_model
+        self.num_heads = config.n_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_seq_len
+        self.block_idx = block_idx
+        if block_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
+                + "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
+                + "when creating this class."
+            )
+
+        attn_config = config.attn_config
+        self.attn_pdrop = attn_config.attn_pdrop
+        self.clip_qkv = attn_config.clip_qkv
+        self.num_key_value_heads = attn_config.kv_n_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.rope_theta = attn_config.rope_theta
+        self.is_causal = True
+
+        self.Wqkv = nn.Linear(
+            self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = DbrxRotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
+        # Only clamp when `clip_qkv` is set: `torch.Tensor.clamp` requires at least one bound,
+        # so calling it with `min=None, max=None` would raise at runtime. This also matches the
+        # flash-attention and SDPA variants below.
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                + f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights
+
+
+class DbrxFlashAttention2(DbrxAttention):
+    """Dbrx flash attention module.
+
+    This module inherits from `DbrxAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it
+    calls the public API of flash attention.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+            )
+        logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout
+        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
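+        # (Concretely: `past_key_values.update` above consumes and returns states laid out as
+        # [batch, num_heads, seq, head_dim], while the flash attention call below expects
+        # [batch, seq, num_heads, head_dim].)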
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attn_pdrop if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+        input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = query_states.dtype
+
+            logger.warning_once(
+                "The input hidden states seem to be silently cast to float32; this might be "
+                + "related to the fact you have upcasted embedding or layer norm layers in "
+                + f"float32. We will cast back the input in {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights
+
+
+class DbrxSdpaAttention(DbrxAttention):
+    """
+    Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `DbrxAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+    the SDPA API.
+    """
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query_states, key_states, value_states = qkv_states.split( + [ + self.hidden_size, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + dim=2, + ) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
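+        # (`is_causal=True` is only safe when no explicit mask is supplied and more than one
+        # query token is processed; single-token decoding attends to the whole cache anyway.)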
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attn_pdrop if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +DBRX_ATTENTION_CLASSES = { + "eager": DbrxAttention, + "flash_attention_2": DbrxFlashAttention2, + "sdpa": DbrxSdpaAttention, +} + + +class DbrxNormAttentionNorm(nn.Module): + def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None): + super().__init__() + self.block_idx = block_idx + self.resid_pdrop = config.resid_pdrop + self.norm_1 = nn.LayerNorm(config.d_model, bias=False) + self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation]( + config=config, + block_idx=block_idx, + ) + self.norm_2 = nn.LayerNorm(config.d_model, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + residual_states = hidden_states + hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training) + hidden_states = hidden_states + residual_states + + residual_states = hidden_states + hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype) + + return residual_states, hidden_states, attn_weights + + +class DbrxRouter(nn.Module): + def __init__( + self, + hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + moe_jitter_eps: Optional[float], + moe_normalize_expert_weights: Optional[float], + ): + super().__init__() + self.hidden_size = hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_normalize_expert_weights = moe_normalize_expert_weights + + self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + if self.training and self.moe_jitter_eps is not None: + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps + ) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + + top_weights_scale = ( + torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) + if self.moe_normalize_expert_weights is not None + else 1.0 + ) + top_weights = top_weights / top_weights_scale + + weights = weights.to(hidden_states.dtype) + top_weights = top_weights.to(hidden_states.dtype) + return weights, 
top_weights, top_experts
+
+
+class DbrxExpertGLU(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+
+        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+
+        act_fn_name = ffn_act_fn.get("name", "silu")
+        self.activation_fn = ACT2FN[act_fn_name]
+
+    def forward(
+        self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
+    ) -> torch.Tensor:
+        gate_proj = x.matmul(expert_w1.t())
+        up_proj = x.matmul(expert_v1.t())
+        gate_proj = self.activation_fn(gate_proj)
+        intermediate_states = gate_proj * up_proj
+        down_proj = intermediate_states.matmul(expert_w2)
+        return down_proj
+
+
+class DbrxExperts(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+        super().__init__()
+        self.moe_num_experts = moe_num_experts
+        self.mlp = DbrxExpertGLU(
+            hidden_size=hidden_size,
+            ffn_hidden_size=ffn_hidden_size,
+            moe_num_experts=moe_num_experts,
+            ffn_act_fn=ffn_act_fn,
+        )
+
+    def forward(
+        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+    ) -> torch.Tensor:
+        bsz, q_len, hidden_size = x.shape
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
+        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        # Chunk experts at once to avoid storing the full parameter multiple times in autograd
+        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+        for expert_idx in range(0, self.moe_num_experts):
+            # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
+            # (setting torch._dynamo.config.capture_dynamic_output_shape_ops = True may help, but this is untested)
+            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+
+            token_list = token_idx
+            topk_list = topk_idx
+
+            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+            expert_out = (
+                self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+                * top_weights[token_list, topk_list, None]
+            )
+
+            out.index_add_(0, token_idx, expert_out)
+
+        out = out.reshape(bsz, q_len, hidden_size)
+        return out
+
+
+class DbrxFFN(nn.Module):
+    def __init__(self, config: DbrxConfig):
+        super().__init__()
+
+        ffn_config = config.ffn_config
+        self.router = DbrxRouter(
+            hidden_size=config.d_model,
+            moe_num_experts=ffn_config.moe_num_experts,
+            moe_top_k=ffn_config.moe_top_k,
+            moe_jitter_eps=ffn_config.moe_jitter_eps,
+            moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
+        )
+
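+        # Routing sketch: `self.router` maps each token to softmax weights over the experts
+        # and keeps the top `moe_top_k`; `self.experts` below then sums the selected experts'
+        # GLU outputs scaled by those weights (see `DbrxExperts.forward` above).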
+ self.experts = DbrxExperts( + hidden_size=config.d_model, + ffn_hidden_size=ffn_config.ffn_hidden_size, + moe_num_experts=ffn_config.moe_num_experts, + ffn_act_fn=ffn_config.ffn_act_fn, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + weights, top_weights, top_experts = self.router(x) + out = self.experts(x, weights, top_weights, top_experts) + return out, weights + + +class DbrxBlock(GradientCheckpointingLayer): + def __init__(self, config: DbrxConfig, block_idx: int): + super().__init__() + self.hidden_size = config.d_model + self.resid_pdrop = config.resid_pdrop + self.block_idx = block_idx + self.norm_attn_norm = DbrxNormAttentionNorm( + config=config, + block_idx=block_idx, + ) + self.ffn = DbrxFFN(config=config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> Union[ + tuple[torch.Tensor], + tuple[torch.Tensor, Optional[torch.Tensor]], + tuple[torch.Tensor, Optional[Cache]], + tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]], + tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + tuple[torch.Tensor, Optional[Cache], Optional[torch.Tensor]], + tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]], + ]: + """Forward function for DbrxBlock. + + Args: + hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)` + attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length) + if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length) + if default attention is used. + past_key_values (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all + attention layers. See `attentions` under returned tensors for more detail. + output_router_logits (`bool`, *optional*): Whether or not to return the router logits. + use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are + returned and can be used to speed up decoding (see `past_key_values`). 
+            cache_position (`torch.LongTensor`, *optional*): position indices of the input tokens in the cache
+        """
+
+        # Norm + Attention + Norm
+        resid_states, hidden_states, self_attn_weights = self.norm_attn_norm(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        # Fully Connected
+        hidden_states, router_logits = self.ffn(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
+        hidden_states = resid_states + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
+
+
+@auto_docstring
+class DbrxPreTrainedModel(PreTrainedModel):
+    config: DbrxConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DbrxBlock"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+
+    def _init_weights(self, module: nn.Module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, DbrxExpertGLU):
+            module.w1.data.normal_(mean=0.0, std=std)
+            module.v1.data.normal_(mean=0.0, std=std)
+            module.w2.data.normal_(mean=0.0, std=std)
+
+
+@auto_docstring
+class DbrxModel(DbrxPreTrainedModel):
+    """Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DbrxBlock`].
+
+    Args:
+        config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
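+
+    Example (an illustrative sketch; weights are randomly initialized, so only shapes
+    are meaningful):
+
+    ```python
+    >>> import torch
+    >>> from transformers import DbrxConfig, DbrxModel
+
+    >>> model = DbrxModel(DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128))
+    >>> model(input_ids=torch.randint(0, 128, (1, 4))).last_hidden_state.shape
+    torch.Size([1, 4, 256])
+    ```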
+ """ + + def __init__(self, config: DbrxConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.emb_pdrop = config.emb_pdrop + + self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)]) + self.norm_f = nn.LayerNorm(config.d_model, bias=False) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.wte + + def set_input_embeddings(self, value: nn.Embedding): + self.wte = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + block_outputs = block( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = block_outputs[0] + + if output_attentions: + all_self_attns += (block_outputs[1],) + + if output_router_logits: + all_router_logits += (block_outputs[-1],) + + hidden_states = self.norm_f(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
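+        # (For example, a 2D left-padding mask [[0, 1, 1]] is expanded further below into a
+        # 4D additive mask of shape [batch, 1, query_len, key_len] whose masked positions
+        # hold torch.finfo(dtype).min.)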
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype = input_tensor.dtype
+        sequence_length = input_tensor.shape[1]
+        if using_compilable_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The DBRX Model transformer for causal language modeling.
+    """
+)
+class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
+    def __init__(self, config: DbrxConfig):
+        super().__init__(config)
+        self.transformer = DbrxModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.moe_loss_weight = config.ffn_config.moe_loss_weight
+        self.num_experts = config.ffn_config.moe_num_experts
+        self.num_experts_per_tok = config.ffn_config.moe_top_k
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.transformer.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Embedding):
+        self.transformer.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder: DbrxModel):
+        self.transformer = decoder
+
+    def get_decoder(self) -> DbrxModel:
+        return self.transformer
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ) -> Union[tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DbrxForCausalLM
+
+        >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.transformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        # No upscaling to float was ever done for Dbrx
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None and loss is not None:
+                loss += self.moe_loss_weight * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+
+__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70972237964208b461eee8141a4f8b1377154f6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
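+# Lazy-import note: at runtime this module replaces itself with a `_LazyModule`, so e.g.
+# `from transformers.models.deberta import DebertaConfig` only imports the concrete
+# submodule on first attribute access.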
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_deberta import *
+    from .modeling_deberta import *
+    from .modeling_tf_deberta import *
+    from .tokenization_deberta import *
+    from .tokenization_deberta_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta/configuration_deberta.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/configuration_deberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e23a73a8c38a1b2c08a6392f8cd2a8c228799ac
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/configuration_deberta.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DeBERTa model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+    from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
+
+
+logger = logging.get_logger(__name__)
+
+
+class DebertaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DebertaModel`] or a [`TFDebertaModel`]. It is
+    used to instantiate a DeBERTa model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa
+    [microsoft/deberta-base](https://huggingface.co/microsoft/deberta-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the DeBERTa model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"`
+            are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 0):
+            The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-7):
+            The epsilon used by the layer normalization layers.
+        relative_attention (`bool`, *optional*, defaults to `False`):
+            Whether to use relative position encoding.
+        max_relative_positions (`int`, *optional*, defaults to -1):
+            The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
+            as `max_position_embeddings`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The value used to pad input_ids.
+        position_biased_input (`bool`, *optional*, defaults to `True`):
+            Whether to add absolute position embeddings to the content embeddings.
+        pos_att_type (`list[str]`, *optional*):
+            The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
+            `["p2c", "c2p"]`.
+        legacy (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should use the legacy `LegacyDebertaOnlyMLMHead`, which does not work properly
+            for mask infilling tasks.
+ + Example: + + ```python + >>> from transformers import DebertaConfig, DebertaModel + + >>> # Initializing a DeBERTa microsoft/deberta-base style configuration + >>> configuration = DebertaConfig() + + >>> # Initializing a model (with random weights) from the microsoft/deberta-base style configuration + >>> model = DebertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deberta" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=0, + initializer_range=0.02, + layer_norm_eps=1e-7, + relative_attention=False, + max_relative_positions=-1, + pad_token_id=0, + position_biased_input=True, + pos_att_type=None, + pooler_dropout=0, + pooler_hidden_act="gelu", + legacy=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.pad_token_id = pad_token_id + self.position_biased_input = position_biased_input + + # Backwards compatibility + if isinstance(pos_att_type, str): + pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")] + + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act + self.legacy = legacy + + +# Copied from transformers.models.deberta_v2.configuration_deberta_v2.DebertaV2OnnxConfig +class DebertaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + if self._config.type_vocab_size > 0: + return OrderedDict( + [("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)] + ) + else: + return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)]) + + @property + def default_onnx_opset(self) -> int: + return 12 + + def generate_dummy_inputs( + self, + preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], + batch_size: int = -1, + seq_length: int = -1, + num_choices: int = -1, + is_pair: bool = False, + framework: Optional["TensorType"] = None, + num_channels: int = 3, + image_width: int = 40, + image_height: int = 40, + tokenizer: "PreTrainedTokenizerBase" = None, + ) -> Mapping[str, Any]: + dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework) + if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs: + del dummy_inputs["token_type_ids"] + return dummy_inputs + + +__all__ = ["DebertaConfig", "DebertaOnnxConfig"] diff --git 
a/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_deberta.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_deberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..461572b47677fbc097cac1a7e66d8d758db90108
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_deberta.py
@@ -0,0 +1,1204 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DeBERTa model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_deberta import DebertaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DebertaLayerNorm(nn.Module):
+    """LayerNorm module in the TF style (epsilon inside the square root)."""
+
+    def __init__(self, size, eps=1e-12):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(size))
+        self.bias = nn.Parameter(torch.zeros(size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_type = hidden_states.dtype
+        hidden_states = hidden_states.float()
+        mean = hidden_states.mean(-1, keepdim=True)
+        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+        hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states.to(input_type)
+        y = self.weight * hidden_states + self.bias
+        return y
+
+
+class DebertaSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
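`DebertaLayerNorm` computes the statistics in float32 and adds the epsilon inside the square root, which is the same formula `torch.nn.LayerNorm` uses. A quick numerical check of that equivalence (illustrative only, toy sizes, default affine parameters):

```python
# Illustrative check: DebertaLayerNorm's math matches nn.LayerNorm for the
# same eps, since freshly initialized weight/bias are ones/zeros.
import torch
from torch import nn

size, eps = 8, 1e-7
x = torch.randn(2, 4, size)

ref = nn.LayerNorm(size, eps=eps)  # (x - mean) / sqrt(var + eps), then affine

mean = x.mean(-1, keepdim=True)
var = (x - mean).pow(2).mean(-1, keepdim=True)
ours = (x - mean) / torch.sqrt(var + eps)  # weight=1, bias=0 as initialized

print(torch.allclose(ref(x), ours, atol=1e-6))  # expected: True
```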
+@torch.jit.script
+def build_relative_position(query_layer, key_layer):
+    """
+    Build relative position according to the query and key
+
+    We assume the absolute position of query \\(P_q\\) ranges from (0, query_size) and the absolute position of key
+    \\(P_k\\) ranges from (0, key_size). The relative positions from query to key are \\(R_{q \\rightarrow k} = P_q -
+    P_k\\)
+
+    Args:
+        query_layer (`torch.Tensor`): the query tensor; only its length (the size of dim -2) is used
+        key_layer (`torch.Tensor`): the key tensor; only its length (the size of dim -2) is used
+
+    Return:
+        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+
+    """
+
+    query_size = query_layer.size(-2)
+    key_size = key_layer.size(-2)
+
+    q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
+    k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
+    rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
+    rel_pos_ids = rel_pos_ids[:query_size, :]
+    rel_pos_ids = rel_pos_ids.unsqueeze(0)
+    return rel_pos_ids
+
+
+@torch.jit.script
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+    return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+###### To support a general trace, we have to define these operations as they use python objects (sizes) #############
+# which are not supported by torch.jit.trace.
+# Full credits to @Szustarol
+@torch.jit.script
+def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
+    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+
+
+@torch.jit.script
+def build_rpos(query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
+    if query_layer.size(-2) != key_layer.size(-2):
+        return build_relative_position(query_layer, key_layer)
+    else:
+        return relative_pos
+
+
+@torch.jit.script
+def compute_attention_span(query_layer: torch.Tensor, key_layer: torch.Tensor, max_relative_positions: int):
+    return torch.tensor(min(max(query_layer.size(-2), key_layer.size(-2)), max_relative_positions))
+
+
+@torch.jit.script
+def uneven_size_corrected(p2c_att, query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
+    if query_layer.size(-2) != key_layer.size(-2):
+        pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+        return torch.gather(p2c_att, dim=2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
+    else:
+        return p2c_att
+
+
+########################################################################################################################
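For intuition, the relative-position matrix built above is simply the outer difference of query and key position indices. A tiny standalone illustration (not library code):

```python
# Illustrative only: what build_relative_position computes for small sizes.
import torch

q_ids = torch.arange(3)  # query positions 0..2
k_ids = torch.arange(4)  # key positions 0..3
rel = q_ids[:, None] - k_ids[None, :]  # entry [i, j] = i - j
print(rel)
# tensor([[ 0, -1, -2, -3],
#         [ 1,  0, -1, -2],
#         [ 2,  1,  0, -1]])
```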
+class DisentangledSelfAttention(nn.Module):
+    """
+    Disentangled self-attention module
+
+    Parameters:
+        config (`DebertaConfig`):
+            A model config class instance with the configuration to build a new model. The schema is similar to
+            *BertConfig*; for more details, please refer to [`DebertaConfig`].
+
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
+        self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+        self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
+
+        self.relative_attention = getattr(config, "relative_attention", False)
+        self.talking_head = getattr(config, "talking_head", False)
+
+        if self.talking_head:
+            self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
+            self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
+        else:
+            self.head_logits_proj = None
+            self.head_weights_proj = None
+
+        if self.relative_attention:
+            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+            if self.max_relative_positions < 1:
+                self.max_relative_positions = config.max_position_embeddings
+            self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+            if "c2p" in self.pos_att_type:
+                self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+            if "p2c" in self.pos_att_type:
+                self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
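A standalone shape walkthrough of the `transpose_for_scores` reshape (toy sizes, illustrative only):

```python
# Illustrative: transpose_for_scores turns [batch, seq, hidden] into
# [batch, heads, seq, head_dim] so attention is computed per head.
import torch

batch, seq, heads, head_dim = 2, 5, 12, 64
x = torch.randn(batch, seq, heads * head_dim)  # [2, 5, 768]

x = x.view(batch, seq, heads, head_dim)        # [2, 5, 12, 64]
x = x.permute(0, 2, 1, 3)                      # [2, 12, 5, 64]
print(x.shape)  # torch.Size([2, 12, 5, 64])
```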
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        output_attentions: bool = False,
+        query_states: Optional[torch.Tensor] = None,
+        relative_pos: Optional[torch.Tensor] = None,
+        rel_embeddings: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Call the module
+
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input states to the module, usually the output of the previous layer; used as the Q, K and V in
+                *Attention(Q,K,V)*
+
+            attention_mask (`torch.BoolTensor`):
+                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size and *N* is the maximum
+                sequence length, in which element [i,j] = *1* means the *i* th token in the input can attend to the
+                *j* th token.
+
+            output_attentions (`bool`, *optional*):
+                Whether to return the attention matrix.
+
+            query_states (`torch.FloatTensor`, *optional*):
+                The *Q* state in *Attention(Q,K,V)*.
+
+            relative_pos (`torch.LongTensor`):
+                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*]
+                with values ranging in [*-max_relative_positions*, *max_relative_positions*].
+
+            rel_embeddings (`torch.FloatTensor`):
+                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
+                \\text{max_relative_positions}\\), *hidden_size*].
+
+        """
+        if query_states is None:
+            qp = self.in_proj(hidden_states)  # .split(self.all_head_size, dim=-1)
+            query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
+        else:
+            ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
+            qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
+            q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype))
+            k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype))
+            v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
+            query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
+
+        query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
+        value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
+
+        rel_att: int = 0
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        scale_factor = 1 + len(self.pos_att_type)
+        scale = scaled_size_sqrt(query_layer, scale_factor)
+        query_layer = query_layer / scale.to(dtype=query_layer.dtype)
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
+            rel_embeddings = self.pos_dropout(rel_embeddings)
+            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
+
+        if rel_att is not None:
+            attention_scores = attention_scores + rel_att
+
+        # bxhxlxd
+        if self.head_logits_proj is not None:
+            attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+        attention_mask = attention_mask.bool()
+        attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
+        # bsz x height x length x dimension
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        attention_probs = self.dropout(attention_probs)
+        if self.head_weights_proj is not None:
+            attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        if not output_attentions:
+            return (context_layer, None)
+        return (context_layer, attention_probs)
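The bucket indexing used in `disentangled_att_bias` below clamps signed relative distances into the `[0, 2 * att_span - 1]` index range of the relative-embedding table. A toy illustration of that arithmetic (illustrative only):

```python
# Illustrative: mapping signed relative distances to embedding-table buckets.
import torch

att_span = 4  # toy attention span
rel = torch.tensor([-6, -3, 0, 3, 6])
c2p_pos = torch.clamp(rel + att_span, 0, att_span * 2 - 1)
print(c2p_pos)  # tensor([0, 1, 4, 7, 7]): distances beyond the span saturate
```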
+    def disentangled_att_bias(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        relative_pos: torch.Tensor,
+        rel_embeddings: torch.Tensor,
+        scale_factor: int,
+    ):
+        if relative_pos is None:
+            relative_pos = build_relative_position(query_layer, key_layer)
+        if relative_pos.dim() == 2:
+            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+        elif relative_pos.dim() == 3:
+            relative_pos = relative_pos.unsqueeze(1)
+        # bxhxqxk
+        elif relative_pos.dim() != 4:
+            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
+
+        att_span = compute_attention_span(query_layer, key_layer, self.max_relative_positions)
+        relative_pos = relative_pos.long()
+        rel_embeddings = rel_embeddings[
+            self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
+        ].unsqueeze(0)
+
+        score = 0
+
+        # content->position
+        if "c2p" in self.pos_att_type:
+            pos_key_layer = self.pos_proj(rel_embeddings)
+            pos_key_layer = self.transpose_for_scores(pos_key_layer)
+            c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
+            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
+            score += c2p_att
+
+        # position->content
+        if "p2c" in self.pos_att_type:
+            pos_query_layer = self.pos_q_proj(rel_embeddings)
+            pos_query_layer = self.transpose_for_scores(pos_query_layer)
+            pos_query_layer /= scaled_size_sqrt(pos_query_layer, scale_factor)
+            r_pos = build_rpos(
+                query_layer,
+                key_layer,
+                relative_pos,
+            )
+            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
+            p2c_att = torch.gather(
+                p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
+            ).transpose(-1, -2)
+
+            p2c_att = uneven_size_corrected(p2c_att, query_layer, key_layer, relative_pos)
+            score += p2c_att
+
+        return score
+
+
+class DebertaEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        pad_token_id = getattr(config, "pad_token_id", 0)
+        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
+
+        self.position_biased_input = getattr(config, "position_biased_input", True)
+        if not self.position_biased_input:
+            self.position_embeddings = None
+        else:
+            self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+        if config.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+        else:
+            self.token_type_embeddings = None
+
+        if self.embedding_size != config.hidden_size:
+            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+        else:
+            self.embed_proj = None
+
+        self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.position_embeddings is not None:
+            position_embeddings = self.position_embeddings(position_ids.long())
+        else:
+            position_embeddings = torch.zeros_like(inputs_embeds)
+
+        embeddings
= inputs_embeds + if self.position_biased_input: + embeddings = embeddings + position_embeddings + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +class DebertaAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaSelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions: bool = False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + self_output, att_matrix = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return (attention_output, None) + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta +class DebertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class DebertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DebertaLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention = DebertaAttention(config) + self.intermediate = DebertaIntermediate(config) + self.output = DebertaOutput(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + attention_output, att_matrix = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + + if output_attentions: + return (layer_output, att_matrix) + else: + return (layer_output, None) + + +class 
DebertaEncoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) + self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + if query_states is not None: + relative_pos = build_relative_position(query_states, hidden_states) + else: + relative_pos = build_relative_position(hidden_states, hidden_states) + return relative_pos + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: bool = True, + output_attentions: bool = False, + query_states=None, + relative_pos=None, + return_dict: bool = True, + ): + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states: Optional[tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None + all_attentions = () if output_attentions else None + + next_kv = hidden_states + + rel_embeddings = self.get_rel_embedding() + for i, layer_module in enumerate(self.layer): + hidden_states, att_m = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if query_states is not None: + query_states = hidden_states + else: + next_kv = hidden_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +@auto_docstring +class DebertaPreTrainedModel(PreTrainedModel): + config: DebertaConfig + base_model_prefix = "deberta" + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 
+ if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, DisentangledSelfAttention): + module.q_bias.data.zero_() + module.v_bias.data.zero_() + elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)): + module.bias.data.zero_() + + +@auto_docstring +class DebertaModel(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaEmbeddings(config) + self.encoder = DebertaEncoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + 
query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +class LegacyDebertaPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.dense = nn.Linear(config.hidden_size, self.embedding_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class LegacyDebertaLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = LegacyDebertaPredictionHeadTransform(config) + + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LegacyDeberta +class LegacyDebertaOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = LegacyDebertaLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class DebertaLMPredictionHead(nn.Module): + """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # note that the input embeddings must be passed as an argument + def forward(self, hidden_states, word_embeddings): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm( + hidden_states + ) # original used MaskedLayerNorm, but passed no mask. This is equivalent. 
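+        # The vocabulary projection below reuses the transposed input embedding matrix
+        # (weight tying): logits = hidden_states @ E^T + bias, so the non-legacy head
+        # stores no separate decoder weight matrix.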
+ hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias + return hidden_states + + +class DebertaOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.lm_head = DebertaLMPredictionHead(config) + + # note that the input embeddings must be passed as an argument + def forward(self, sequence_output, word_embeddings): + prediction_scores = self.lm_head(sequence_output, word_embeddings) + return prediction_scores + + +@auto_docstring +class DebertaForMaskedLM(DebertaPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + self.legacy = config.legacy + self.deberta = DebertaModel(config) + if self.legacy: + self.cls = LegacyDebertaOnlyMLMHead(config) + else: + self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"] + self.lm_predictions = DebertaOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + if self.legacy: + return self.cls.predictions.decoder + else: + return self.lm_predictions.lm_head.dense + + def set_output_embeddings(self, new_embeddings): + if self.legacy: + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + else: + self.lm_predictions.lm_head.dense = new_embeddings + self.lm_predictions.lm_head.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + if self.legacy: + prediction_scores = self.cls(sequence_output) + else: + prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = nn.Dropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +@auto_docstring( + custom_intro=""" + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """ +) +class DebertaForSequenceClassification(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaModel(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = nn.Dropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + 
outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring +class DebertaForTokenClassification(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring +class DebertaForQuestionAnswering(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = 
self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "DebertaForMaskedLM", + "DebertaForQuestionAnswering", + "DebertaForSequenceClassification", + "DebertaForTokenClassification", + "DebertaModel", + "DebertaPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_tf_deberta.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_tf_deberta.py new file mode 100644 index 0000000000000000000000000000000000000000..40d23fc28b9475623dc8a94d5539c5a530b3f376 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/modeling_tf_deberta.py @@ -0,0 +1,1652 @@ +# coding=utf-8 +# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
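Before the TF port of the model begins, a usage sketch for the PyTorch `DebertaForQuestionAnswering` defined above. This is hedged: `microsoft/deberta-base` is a base checkpoint whose span head is untrained, so substitute a QA-finetuned DeBERTa checkpoint in practice; the question/context strings are made up for illustration.

```python
# Sketch: decoding start/end logits from DebertaForQuestionAnswering.
import torch
from transformers import AutoTokenizer, DebertaForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForQuestionAnswering.from_pretrained("microsoft/deberta-base")

question = "Who proposed DeBERTa?"
context = "DeBERTa was proposed by Microsoft researchers."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: argmax over per-token start and end logits.
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)
```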
+"""TF 2.0 DeBERTa model."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Sequence
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFMaskedLMOutput,
+    TFQuestionAnsweringModelOutput,
+    TFSequenceClassifierOutput,
+    TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+    TFMaskedLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    TFQuestionAnsweringLoss,
+    TFSequenceClassificationLoss,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras,
+    unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_deberta import DebertaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "DebertaConfig"
+_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-base"
+
+
+class TFDebertaContextPooler(keras.layers.Layer):
+    def __init__(self, config: DebertaConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
+        self.dropout = TFDebertaStableDropout(config.pooler_dropout, name="dropout")
+        self.config = config
+
+    def call(self, hidden_states, training: bool = False):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        context_token = hidden_states[:, 0]
+        context_token = self.dropout(context_token, training=training)
+        pooled_output = self.dense(context_token)
+        pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output)
+        return pooled_output
+
+    @property
+    def output_dim(self) -> int:
+        return self.config.hidden_size
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.pooler_hidden_size])
+        if getattr(self, "dropout", None) is not None:
+            with tf.name_scope(self.dropout.name):
+                self.dropout.build(None)
+
+
+class TFDebertaXSoftmax(keras.layers.Layer):
+    """
+    Masked Softmax which is optimized for saving memory
+
+    Args:
+        input (`tf.Tensor`): The input tensor that will apply softmax.
+        mask (`tf.Tensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax
+            calculation.
+        dim (int): The dimension along which softmax will be applied
+    """
+
+    def __init__(self, axis=-1, **kwargs):
+        super().__init__(**kwargs)
+        self.axis = axis
+
+    def call(self, inputs: tf.Tensor, mask: tf.Tensor):
+        rmask = tf.logical_not(tf.cast(mask, tf.bool))
+        output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)
+        output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)
+        output = tf.where(rmask, 0.0, output)
+        return output
+
+
+class TFDebertaStableDropout(keras.layers.Layer):
+    """
+    Optimized dropout module for stabilizing the training
+
+    Args:
+        drop_prob (float): the dropout probability
+    """
+
+    def __init__(self, drop_prob, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by
+        1/(1 - drop_prob).
+ """ + mask = tf.cast( + 1 + - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)), + tf.bool, + ) + scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype) + if self.drop_prob > 0: + inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale + + def grad(upstream): + if self.drop_prob > 0: + return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale + else: + return upstream + + return inputs, grad + + def call(self, inputs: tf.Tensor, training: tf.Tensor = False): + if training: + return self.xdropout(inputs) + return inputs + + +class TFDebertaLayerNorm(keras.layers.Layer): + """LayerNorm module in the TF style (epsilon inside the square root).""" + + def __init__(self, size, eps=1e-12, **kwargs): + super().__init__(**kwargs) + self.size = size + self.eps = eps + + def build(self, input_shape): + self.gamma = self.add_weight(shape=[self.size], initializer=tf.ones_initializer(), name="weight") + self.beta = self.add_weight(shape=[self.size], initializer=tf.zeros_initializer(), name="bias") + return super().build(input_shape) + + def call(self, x: tf.Tensor) -> tf.Tensor: + mean = tf.reduce_mean(x, axis=[-1], keepdims=True) + variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True) + std = tf.math.sqrt(variance + self.eps) + return self.gamma * (x - mean) / std + self.beta + + +class TFDebertaSelfOutput(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(config.hidden_size, name="dense") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def call(self, hidden_states, input_tensor, training: bool = False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +class TFDebertaAttention(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + self.self = TFDebertaDisentangledSelfAttention(config, name="self") + self.dense_output = TFDebertaSelfOutput(config, name="output") + self.config = config + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + rel_embeddings: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.self( + hidden_states=input_tensor, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + ) + if query_states is None: + query_states = input_tensor + attention_output = self.dense_output( + 
hidden_states=self_outputs[0], input_tensor=query_states, training=training + ) + + output = (attention_output,) + self_outputs[1:] + + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFDebertaIntermediate(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFDebertaOutput(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +class TFDebertaLayer(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFDebertaAttention(config, name="attention") + self.intermediate = TFDebertaIntermediate(config, name="intermediate") + self.bert_output = TFDebertaOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + rel_embeddings: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor]: + attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + ) + 
attention_output = attention_outputs[0] + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + + +class TFDebertaEncoder(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.layer = [TFDebertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.relative_attention = getattr(config, "relative_attention", False) + self.config = config + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.relative_attention: + self.rel_embeddings = self.add_weight( + name="rel_embeddings.weight", + shape=[self.max_relative_positions * 2, self.config.hidden_size], + initializer=get_initializer(self.config.initializer_range), + ) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings if self.relative_attention else None + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if len(shape_list(attention_mask)) <= 2: + extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2) + attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1) + attention_mask = tf.cast(attention_mask, tf.uint8) + elif len(shape_list(attention_mask)) == 3: + attention_mask = tf.expand_dims(attention_mask, 1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2] + relative_pos = build_relative_position(q, shape_list(hidden_states)[-2]) + return relative_pos + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + + rel_embeddings = self.get_rel_embedding() + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + 
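+                # Record the hidden states fed into this layer; the output of the final
+                # layer is appended once more after the loop completes.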
all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=next_kv, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if query_states is not None: + query_states = hidden_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = hidden_states + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def build_relative_position(query_size, key_size): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + + Return: + `tf.Tensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = tf.range(query_size, dtype=tf.int32) + k_ids = tf.range(key_size, dtype=tf.int32) + rel_pos_ids = q_ids[:, None] - tf.tile(tf.reshape(k_ids, [1, -1]), [query_size, 1]) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0) + return tf.cast(rel_pos_ids, tf.int64) + + +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + shapes = [ + shape_list(query_layer)[0], + shape_list(query_layer)[1], + shape_list(query_layer)[2], + shape_list(relative_pos)[-1], + ] + return tf.broadcast_to(c2p_pos, shapes) + + +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + shapes = [ + shape_list(query_layer)[0], + shape_list(query_layer)[1], + shape_list(key_layer)[-2], + shape_list(key_layer)[-2], + ] + return tf.broadcast_to(c2p_pos, shapes) + + +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]] + return tf.broadcast_to(pos_index, shapes) + + +def torch_gather(x, indices, gather_axis): + if gather_axis < 0: + gather_axis = tf.rank(x) + gather_axis + + if gather_axis != tf.rank(x) - 1: + pre_roll = tf.rank(x) - 1 - gather_axis + permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0) + x = tf.transpose(x, perm=permutation) + indices = tf.transpose(indices, perm=permutation) + else: + pre_roll = 0 + + flat_x = tf.reshape(x, (-1, tf.shape(x)[-1])) + flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1])) + gathered = tf.gather(flat_x, flat_indices, batch_dims=1) + gathered = tf.reshape(gathered, tf.shape(indices)) + + if pre_roll != 0: + permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0) + gathered = tf.transpose(gathered, perm=permutation) + + return gathered + + +class TFDebertaDisentangledSelfAttention(keras.layers.Layer): + """ + Disentangled self-attention module + + Parameters: + config (`str`): + A model config class instance with the configuration to build a new model. 
The schema is similar to + *BertConfig*, for more details, please refer [`DebertaConfig`] + + """ + + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.in_proj = keras.layers.Dense( + self.all_head_size * 3, + kernel_initializer=get_initializer(config.initializer_range), + name="in_proj", + use_bias=False, + ) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + + self.relative_attention = getattr(config, "relative_attention", False) + self.talking_head = getattr(config, "talking_head", False) + + if self.talking_head: + self.head_logits_proj = keras.layers.Dense( + self.num_attention_heads, + kernel_initializer=get_initializer(config.initializer_range), + name="head_logits_proj", + use_bias=False, + ) + self.head_weights_proj = keras.layers.Dense( + self.num_attention_heads, + kernel_initializer=get_initializer(config.initializer_range), + name="head_weights_proj", + use_bias=False, + ) + + self.softmax = TFDebertaXSoftmax(axis=-1) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="pos_dropout") + if "c2p" in self.pos_att_type: + self.pos_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="pos_proj", + use_bias=False, + ) + if "p2c" in self.pos_att_type: + self.pos_q_proj = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="pos_q_proj" + ) + + self.dropout = TFDebertaStableDropout(config.attention_probs_dropout_prob, name="dropout") + self.config = config + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.q_bias = self.add_weight( + name="q_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros() + ) + self.v_bias = self.add_weight( + name="v_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros() + ) + if getattr(self, "in_proj", None) is not None: + with tf.name_scope(self.in_proj.name): + self.in_proj.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "head_logits_proj", None) is not None: + with tf.name_scope(self.head_logits_proj.name): + self.head_logits_proj.build(None) + if getattr(self, "head_weights_proj", None) is not None: + with tf.name_scope(self.head_weights_proj.name): + self.head_weights_proj.build(None) + if getattr(self, "pos_dropout", None) is not None: + with tf.name_scope(self.pos_dropout.name): + self.pos_dropout.build(None) + if getattr(self, "pos_proj", None) is not None: + with tf.name_scope(self.pos_proj.name): + self.pos_proj.build([self.config.hidden_size]) + if getattr(self, "pos_q_proj", None) is not None: + with tf.name_scope(self.pos_q_proj.name): + self.pos_q_proj.build([self.config.hidden_size]) + + 
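+    # Disentangled attention sums up to three score components per head: content->content,
+    # plus the optional content->position ("c2p") and position->content ("p2c") terms
+    # selected via config.pos_att_type. Relative positions come from
+    # build_relative_position(q, k), whose entry [0, i, j] is simply i - j; for example,
+    # build_relative_position(3, 3) yields [[0, -1, -2], [1, 0, -1], [2, 1, 0]].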
+    def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
+        shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1]
+        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+        tensor = tf.reshape(tensor=tensor, shape=shape)
+
+        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+        return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        query_states: tf.Tensor | None = None,
+        relative_pos: tf.Tensor | None = None,
+        rel_embeddings: tf.Tensor | None = None,
+        output_attentions: bool = False,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        """
+        Call the module
+
+        Args:
+            hidden_states (`tf.Tensor`):
+                Input states to the module, usually the output from the previous layer; these serve as Q, K and V in
+                *Attention(Q,K,V)*.
+
+            attention_mask (`tf.Tensor`):
+                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size and *N* is the maximum
+                sequence length, in which element [i,j] = *1* means the *i*-th token in the input can attend to the
+                *j*-th token.
+
+            output_attentions (`bool`, *optional*):
+                Whether to return the attention matrix.
+
+            query_states (`tf.Tensor`, *optional*):
+                The *Q* state in *Attention(Q,K,V)*.
+
+            relative_pos (`tf.Tensor`):
+                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
+                values ranging in [*-max_relative_positions*, *max_relative_positions*].
+
+            rel_embeddings (`tf.Tensor`):
+                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
+                \\text{max_relative_positions}\\), *hidden_size*].
+
+
+        """
+        if query_states is None:
+            qp = self.in_proj(hidden_states)  # .split(self.all_head_size, dim=-1)
+            query_layer, key_layer, value_layer = tf.split(
+                self.transpose_for_scores(qp), num_or_size_splits=3, axis=-1
+            )
+        else:
+
+            def linear(w, b, x):
+                out = tf.matmul(x, w, transpose_b=True)
+                if b is not None:
+                    out += tf.transpose(b)
+                return out
+
+            ws = tf.split(
+                tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
+            )
+            qkvw = tf.TensorArray(dtype=self.dtype, size=3)
+            for k in tf.range(3):
+                qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads)
+                for i in tf.range(self.num_attention_heads):
+                    qkvw_inside = qkvw_inside.write(i, ws[i * 3 + k])
+                qkvw = qkvw.write(k, qkvw_inside.concat())
+            qkvb = [None] * 3
+
+            q = linear(qkvw[0], qkvb[0], query_states)
+            k = linear(qkvw[1], qkvb[1], hidden_states)
+            v = linear(qkvw[2], qkvb[2], hidden_states)
+            query_layer = self.transpose_for_scores(q)
+            key_layer = self.transpose_for_scores(k)
+            value_layer = self.transpose_for_scores(v)
+
+        query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
+        value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
+
+        rel_att = None
+        # Take the dot product between "query" and "key" to get the raw attention scores.
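+        # scale_factor counts the score components being summed: 1 for content->content plus
+        # one per enabled positional term ("c2p", "p2c"), extending the usual 1/sqrt(d)
+        # attention scaling across all components.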
+ scale_factor = 1 + len(self.pos_att_type) + scale = math.sqrt(shape_list(query_layer)[-1] * scale_factor) + query_layer = query_layer / scale + + attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 1, 3, 2])) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings, training=training) + rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + + if self.talking_head: + attention_scores = tf.transpose( + self.head_logits_proj(tf.transpose(attention_scores, [0, 2, 3, 1])), [0, 3, 1, 2] + ) + + attention_probs = self.softmax(attention_scores, attention_mask) + attention_probs = self.dropout(attention_probs, training=training) + if self.talking_head: + attention_probs = tf.transpose( + self.head_weights_proj(tf.transpose(attention_probs, [0, 2, 3, 1])), [0, 3, 1, 2] + ) + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) + context_layer_shape = shape_list(context_layer) + # Set the final dimension here explicitly. + # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing + # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput + # requires final input dimension to be defined + new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = shape_list(query_layer)[-2] + relative_pos = build_relative_position(q, shape_list(key_layer)[-2]) + shape_list_pos = shape_list(relative_pos) + if len(shape_list_pos) == 2: + relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) + elif len(shape_list_pos) == 3: + relative_pos = tf.expand_dims(relative_pos, 1) + # bxhxqxk + elif len(shape_list_pos) != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{len(shape_list_pos)}") + + att_span = tf.cast( + tf.minimum( + tf.maximum(shape_list(query_layer)[-2], shape_list(key_layer)[-2]), self.max_relative_positions + ), + tf.int64, + ) + rel_embeddings = tf.expand_dims( + rel_embeddings[self.max_relative_positions - att_span : self.max_relative_positions + att_span, :], 0 + ) + + score = 0 + + # content->position + if "c2p" in self.pos_att_type: + pos_key_layer = self.pos_proj(rel_embeddings) + pos_key_layer = self.transpose_for_scores(pos_key_layer) + c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 1, 3, 2])) + c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch_gather(c2p_att, c2p_dynamic_expand(c2p_pos, query_layer, relative_pos), -1) + score += c2p_att + + # position->content + if "p2c" in self.pos_att_type: + pos_query_layer = self.pos_q_proj(rel_embeddings) + pos_query_layer = self.transpose_for_scores(pos_query_layer) + pos_query_layer /= tf.math.sqrt( + tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype) + ) + if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]: + r_pos = build_relative_position(shape_list(key_layer)[-2], shape_list(key_layer)[-2]) + else: + r_pos = relative_pos + p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 1, 3, 2])) + p2c_att = tf.transpose( + torch_gather(p2c_att, p2c_dynamic_expand(p2c_pos, query_layer, key_layer), -1), [0, 1, 3, 2] + ) + if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]: + pos_index = tf.expand_dims(relative_pos[:, :, :, 0], -1) + p2c_att = torch_gather(p2c_att, pos_dynamic_expand(pos_index, p2c_att, key_layer), -2) + score += p2c_att + + return score + + +class TFDebertaEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.position_biased_input = getattr(config, "position_biased_input", True) + self.initializer_range = config.initializer_range + if self.embedding_size != config.hidden_size: + self.embed_proj = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="embed_proj", + use_bias=False, + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + if self.config.type_vocab_size > 0: + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + else: + self.token_type_embeddings = None + + with tf.name_scope("position_embeddings"): + if self.position_biased_input: + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + else: + 
self.position_embeddings = None + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "embed_proj", None) is not None: + with tf.name_scope(self.embed_proj.name): + self.embed_proj.build([None, None, self.embedding_size]) + + def call( + self, + input_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + mask: tf.Tensor | None = None, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError("Need to provide either `input_ids` or `input_embeds`.") + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + + final_embeddings = inputs_embeds + if self.position_biased_input: + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + final_embeddings += position_embeds + if self.config.type_vocab_size > 0: + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings += token_type_embeds + + if self.embedding_size != self.hidden_size: + final_embeddings = self.embed_proj(final_embeddings) + + final_embeddings = self.LayerNorm(final_embeddings) + + if mask is not None: + if len(shape_list(mask)) != len(shape_list(final_embeddings)): + if len(shape_list(mask)) == 4: + mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1) + mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype) + + final_embeddings = final_embeddings * mask + + final_embeddings = self.dropout(final_embeddings, training=training) + + return final_embeddings + + +class TFDebertaPredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.dense = keras.layers.Dense( + units=self.embedding_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with 
tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.embedding_size]) + + +class TFDebertaLMPredictionHead(keras.layers.Layer): + def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.transform = TFDebertaPredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.input_embeddings = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "transform", None) is not None: + with tf.name_scope(self.transform.name): + self.transform.build(None) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self) -> dict[str, tf.Variable]: + return {"bias": self.bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.transform(hidden_states=hidden_states) + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +class TFDebertaOnlyMLMHead(keras.layers.Layer): + def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + self.predictions = TFDebertaLMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +# @keras_serializable +class TFDebertaMainLayer(keras.layers.Layer): + config_class = DebertaConfig + + def __init__(self, config: DebertaConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFDebertaEmbeddings(config, name="embeddings") + self.encoder = TFDebertaEncoder(config, name="encoder") + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+        class PreTrainedModel.
+        """
+        raise NotImplementedError
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=input_shape, value=1)
+
+        if token_type_ids is None:
+            token_type_ids = tf.fill(dims=input_shape, value=0)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            mask=attention_mask,
+            training=training,
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return TFBaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+
+
+class TFDebertaPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DebertaConfig
+    base_model_prefix = "deberta"
+
+
+DEBERTA_START_DOCSTRING = r"""
+    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
+    Attention](https://huggingface.co/papers/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's
+    built on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With
+    those two improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Parameters:
+        config ([`DebertaConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    DEBERTA_START_DOCSTRING,
+)
+class TFDebertaModel(TFDebertaPreTrainedModel):
+    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.deberta = TFDebertaMainLayer(config, name="deberta")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool | None = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        outputs = self.deberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deberta", None) is not None:
+            with tf.name_scope(self.deberta.name):
+                self.deberta.build(None)
+
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
+class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLoss):
+    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `TFDebertaForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+ ) + + self.deberta = TFDebertaMainLayer(config, name="deberta") + self.mlm = TFDebertaOnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """, + DEBERTA_START_DOCSTRING, +) +class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: DebertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaMainLayer(config, name="deberta") + self.pooler = TFDebertaContextPooler(config, name="pooler") + + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = TFDebertaStableDropout(drop_out, name="cls_dropout") + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.output_dim = self.pooler.output_dim + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + pooled_output = self.pooler(sequence_output, training=training) + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.output_dim]) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config: DebertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaMainLayer(config, name="deberta") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config: DebertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaMainLayer(config, name="deberta") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFDebertaForMaskedLM", + "TFDebertaForQuestionAnswering", + "TFDebertaForSequenceClassification", + "TFDebertaForTokenClassification", + "TFDebertaModel", + "TFDebertaPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta.py new file mode 100644 index 0000000000000000000000000000000000000000..74e958c8030b4f535bc1448f4abb8f119ef5f926 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model DeBERTa.""" + +import json +import os +from typing import Optional + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} + + +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. 
When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class DebertaTokenizer(PreTrainedTokenizer): + """ + Construct a DeBERTa tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import DebertaTokenizer + + >>> tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base") + >>> tokenizer("Hello world")["input_ids"] + [1, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [1, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. 
+        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows one to treat the leading word just as any
+            other word. (The DeBERTa tokenizer detects the beginning of words by the preceding space.)
+        add_bos_token (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial <|endoftext|> to the input. This allows one to treat the leading word just
+            as any other word.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        bos_token="[CLS]",
+        eos_token="[SEP]",
+        sep_token="[SEP]",
+        cls_token="[CLS]",
+        unk_token="[UNK]",
+        pad_token="[PAD]",
+        mask_token="[MASK]",
+        add_prefix_space=False,
+        add_bos_token=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+
+        # The mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+        self.add_bos_token = add_bos_token
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_merges = merges_handle.read().split("\n")[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.add_prefix_space = add_prefix_space
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            add_bos_token=add_bos_token,
+            **kwargs,
+        )
+
+    @property
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.vocab_size
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
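+        # Greedy BPE: repeatedly locate the adjacent symbol pair with the lowest merge rank
+        # and fuse it, until no remaining pair appears in the learned merge table.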
while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
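+
+        Example (an illustrative sketch; `tokenizer` is assumed to be an instance of this class and the ids are
+        placeholders rather than real vocabulary entries):
+
+        ```python
+        >>> tokenizer.get_special_tokens_mask([10, 11], [12, 13, 14])
+        [1, 0, 0, 1, 0, 0, 0, 1]
+        ```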
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" 
+ ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + +__all__ = ["DebertaTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f2e6552d9dfc8c5d3d48f3c396f816f82158bc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta/tokenization_deberta_fast.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization class for model DeBERTa.""" + +from typing import Optional + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_deberta import DebertaTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class DebertaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" DeBERTa tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level + Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import DebertaTokenizerFast + + >>> tokenizer = DebertaTokenizerFast.from_pretrained("microsoft/deberta-base") + >>> tokenizer("Hello world")["input_ids"] + [1, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [1, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since + the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + merges_file (`str`, *optional*): + Path to the merges file. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. 
See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated like any
+            other word. (The DeBERTa tokenizer detects the beginning of a word by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+    slow_tokenizer_class = DebertaTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        errors="replace",
+        bos_token="[CLS]",
+        eos_token="[SEP]",
+        sep_token="[SEP]",
+        cls_token="[CLS]",
+        unk_token="[UNK]",
+        pad_token="[PAD]",
+        mask_token="[MASK]",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+        self.add_bos_token = kwargs.pop("add_bos_token", False)
+
+    @property
+    def mask_token(self) -> str:
+        """
+        `str`: Mask token, to use when training a model with masked-language modeling. Logs an error if used while
+        not yet set.
+
+        The DeBERTa tokenizer has a special mask token to be used in the fill-mask pipeline. The mask token will
+        greedily include the space before the *[MASK]*.
+        """
+        if self._mask_token is None:
+            if self.verbose:
+                logger.error("Using mask_token, but it is not set yet.")
+            return None
+        return str(self._mask_token)
+
+    @mask_token.setter
+    def mask_token(self, value):
+        """
+        Overriding the default behavior of the mask token to have it eat the space before it.
+        """
+        # The mask token behaves like a normal word, i.e. 
include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._batch_encode_plus + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._batch_encode_plus(*args, **kwargs) + + # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._encode_plus + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._encode_plus(*args, **kwargs) + + # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["DebertaTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c42c9c502862416b8942b9150567a33b6ec9301 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
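+# The module is assembled lazily: at import time `sys.modules[__name__]` is swapped
+# for a `_LazyModule`, which only imports a submodule the first time one of its
+# attributes is accessed. A minimal usage sketch (assuming a standard transformers
+# install, where `DebertaV2Config` lives in configuration_deberta_v2.py):
+#
+#     from transformers.models.deberta_v2 import DebertaV2Config
+#
+#     config = DebertaV2Config()  # first access triggers the lazy submodule import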
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deberta_v2 import * + from .modeling_deberta_v2 import * + from .modeling_tf_deberta_v2 import * + from .tokenization_deberta_v2 import * + from .tokenization_deberta_v2_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/configuration_deberta_v2.py b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/configuration_deberta_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5189cfd53ae7bb36b589c85ab02a4776266e160d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2020, Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DeBERTa-v2 model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +if TYPE_CHECKING: + from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType + + +logger = logging.get_logger(__name__) + + +class DebertaV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a + DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DeBERTa + [microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 128100): + Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DebertaV2Model`]. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"` are
+            supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 0):
+            The vocabulary size of the `token_type_ids` passed when calling [`DebertaV2Model`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-7):
+            The epsilon used by the layer normalization layers.
+        relative_attention (`bool`, *optional*, defaults to `False`):
+            Whether to use relative position encoding.
+        max_relative_positions (`int`, *optional*, defaults to -1):
+            The range of relative positions `[-max_relative_positions, max_relative_positions]`. The default of -1
+            means the value of `max_position_embeddings` is used.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The value used to pad input_ids.
+        position_biased_input (`bool`, *optional*, defaults to `True`):
+            Whether to add absolute position embeddings to the content embeddings.
+        pos_att_type (`list[str]`, *optional*):
+            The type of relative position attention. It can be a combination of `"p2c"` and `"c2p"`, e.g. `["p2c"]`
+            or `["p2c", "c2p"]`.
+        legacy (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should use the legacy `LegacyDebertaV2OnlyMLMHead`, which does not work properly
+            for mask infilling tasks.
+ + Example: + + ```python + >>> from transformers import DebertaV2Config, DebertaV2Model + + >>> # Initializing a DeBERTa-v2 microsoft/deberta-v2-xlarge style configuration + >>> configuration = DebertaV2Config() + + >>> # Initializing a model (with random weights) from the microsoft/deberta-v2-xlarge style configuration + >>> model = DebertaV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deberta-v2" + + def __init__( + self, + vocab_size=128100, + hidden_size=1536, + num_hidden_layers=24, + num_attention_heads=24, + intermediate_size=6144, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=0, + initializer_range=0.02, + layer_norm_eps=1e-7, + relative_attention=False, + max_relative_positions=-1, + pad_token_id=0, + position_biased_input=True, + pos_att_type=None, + pooler_dropout=0, + pooler_hidden_act="gelu", + legacy=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.pad_token_id = pad_token_id + self.position_biased_input = position_biased_input + + # Backwards compatibility + if isinstance(pos_att_type, str): + pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")] + + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act + self.legacy = legacy + + +class DebertaV2OnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + if self._config.type_vocab_size > 0: + return OrderedDict( + [("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)] + ) + else: + return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)]) + + @property + def default_onnx_opset(self) -> int: + return 12 + + def generate_dummy_inputs( + self, + preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], + batch_size: int = -1, + seq_length: int = -1, + num_choices: int = -1, + is_pair: bool = False, + framework: Optional["TensorType"] = None, + num_channels: int = 3, + image_width: int = 40, + image_height: int = 40, + tokenizer: "PreTrainedTokenizerBase" = None, + ) -> Mapping[str, Any]: + dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework) + if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs: + del dummy_inputs["token_type_ids"] + return dummy_inputs + + +__all__ = ["DebertaV2Config", "DebertaV2OnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py 
b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd04dd947738d0546c5d36a6f2d6d51fd06a0dc9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -0,0 +1,1376 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DeBERTa-v2 model."""
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_deberta_v2 import DebertaV2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2SelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+@torch.jit.script
+def make_log_bucket_position(relative_pos, bucket_size: int, max_position: int):
+    sign = torch.sign(relative_pos)
+    mid = bucket_size // 2
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
+    return bucket_pos
+
+
+def build_relative_position(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1):
+    """
+    Build relative position according to the query and key
+
+    We assume the absolute position of the query \\(P_q\\) ranges over [0, query_size) and the absolute position of
+    the key \\(P_k\\) ranges over [0, key_size). The relative position from query to key is \\(R_{q \\rightarrow k} =
+    P_q - P_k\\).
+
+    Args:
+        query_layer (`torch.Tensor`): the query tensor; its second-to-last dimension gives the query length
+        key_layer (`torch.Tensor`): the key tensor; its second-to-last dimension gives the key length
+        bucket_size (int): the size of position bucket
+        max_position (int): the maximum allowed absolute position
+ + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + """ + query_size = query_layer.size(-2) + key_size = key_layer.size(-2) + + q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device) + k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device) + rel_pos_ids = q_ids[:, None] - k_ids[None, :] + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = rel_pos_ids.to(torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +@torch.jit.script +def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int): + return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + + +@torch.jit.script +def build_rpos(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int): + if key_layer.size(-2) != query_layer.size(-2): + return build_relative_position( + key_layer, + key_layer, + bucket_size=position_buckets, + max_position=max_relative_positions, + ) + else: + return relative_pos + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. 
The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = nn.Dropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, attention_heads) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.BoolTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, *optional*): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, *optional*): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. 
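+
+        Returns:
+            `tuple[torch.Tensor, Optional[torch.Tensor]]`: The attention output (context layer) and, when
+            `output_attentions=True`, the attention probabilities (otherwise `None`).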
+ + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = scaled_size_sqrt(query_layer, scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) + # bsz x height x length x dimension + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + relative_pos = build_relative_position( + query_layer, + key_layer, + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}")
+
+        att_span = self.pos_ebd_size
+        relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
+
+        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
+        if self.share_att_key:
+            pos_query_layer = self.transpose_for_scores(
+                self.query_proj(rel_embeddings), self.num_attention_heads
+            ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
+            pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
+                query_layer.size(0) // self.num_attention_heads, 1, 1
+            )
+        else:
+            if "c2p" in self.pos_att_type:
+                pos_key_layer = self.transpose_for_scores(
+                    self.pos_key_proj(rel_embeddings), self.num_attention_heads
+                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)  # .split(self.all_head_size, dim=-1)
+            if "p2c" in self.pos_att_type:
+                pos_query_layer = self.transpose_for_scores(
+                    self.pos_query_proj(rel_embeddings), self.num_attention_heads
+                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)  # .split(self.all_head_size, dim=-1)
+
+        score = 0
+        # content->position
+        if "c2p" in self.pos_att_type:
+            scale = scaled_size_sqrt(pos_key_layer, scale_factor)
+            c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
+            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+            c2p_att = torch.gather(
+                c2p_att,
+                dim=-1,
+                index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
+            )
+            score += c2p_att / scale.to(dtype=c2p_att.dtype)
+
+        # position->content
+        if "p2c" in self.pos_att_type:
+            scale = scaled_size_sqrt(pos_query_layer, scale_factor)
+            # Pass the arguments in the order declared by `build_rpos`:
+            # (..., position_buckets, max_relative_positions)
+            r_pos = build_rpos(
+                query_layer,
+                key_layer,
+                relative_pos,
+                self.position_buckets,
+                self.max_relative_positions,
+            )
+            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.gather(
+                p2c_att,
+                dim=-1,
+                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
+            ).transpose(-1, -2)
+            score += p2c_att / scale.to(dtype=p2c_att.dtype)
+
+        return score
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
+class DebertaV2Attention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.self = DisentangledSelfAttention(config)
+        self.output = DebertaV2SelfOutput(config)
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        output_attentions: bool = False,
+        query_states=None,
+        relative_pos=None,
+        rel_embeddings=None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        self_output, att_matrix = self.self(
+            hidden_states,
+            attention_mask,
+            output_attentions,
+            query_states=query_states,
+            relative_pos=relative_pos,
+            rel_embeddings=rel_embeddings,
+        )
+        if query_states is None:
+            query_states = hidden_states
+        attention_output = self.output(self_output, query_states)
+
+        if output_attentions:
+            return (attention_output, att_matrix)
+        else:
+            return (attention_output, None)
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
+class DebertaV2Intermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + attention_output, att_matrix = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + + if output_attentions: + return (layer_output, att_matrix) + else: + return (layer_output, None) + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2 +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, 
self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + else: + self.token_type_embeddings = None + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + else: + self.embed_proj = None + + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings = embeddings + position_embeddings + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + 
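+        # Sizing note for the relative embeddings above: with bucketed relative
+        # positions the table needs 2 * position_buckets rows; otherwise it covers
+        # the full window of 2 * max_relative_positions signed offsets. When
+        # `norm_rel_ebd` contains "layer_norm", those embeddings are additionally
+        # passed through the LayerNorm defined above (see `get_rel_embedding`).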
self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + if query_states is not None: + relative_pos = build_relative_position( + query_states, + hidden_states, + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + else: + relative_pos = build_relative_position( + hidden_states, + hidden_states, + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = attention_mask.sum(-2) > 0 + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states: Optional[tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None + all_attentions = () if output_attentions else None + + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + for i, layer_module in enumerate(self.layer): + output_states, attn_weights = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attn_weights,) + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +@auto_docstring +class DebertaV2PreTrainedModel(PreTrainedModel): + config: DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if 
module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)): + module.bias.data.zero_() + + +@auto_docstring +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + 
relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.deberta.modeling_deberta.LegacyDebertaPredictionHeadTransform with Deberta->DebertaV2 +class LegacyDebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.dense = nn.Linear(config.hidden_size, self.embedding_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class LegacyDebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = LegacyDebertaV2PredictionHeadTransform(config) + + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class LegacyDebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = LegacyDebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class DebertaV2LMPredictionHead(nn.Module): + """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # note that the input embeddings must be passed as an argument + def forward(self, hidden_states, word_embeddings): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias + return hidden_states + + +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.lm_head = DebertaV2LMPredictionHead(config) + + # note that the input embeddings must be passed as an 
argument + def forward(self, sequence_output, word_embeddings): + prediction_scores = self.lm_head(sequence_output, word_embeddings) + return prediction_scores + + +@auto_docstring +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _keys_to_ignore_on_load_unexpected = [r"mask_predictions.*"] + + def __init__(self, config): + super().__init__(config) + self.legacy = config.legacy + self.deberta = DebertaV2Model(config) + if self.legacy: + self.cls = LegacyDebertaV2OnlyMLMHead(config) + else: + self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"] + self.lm_predictions = DebertaV2OnlyMLMHead(config) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + if self.legacy: + return self.cls.predictions.decoder + else: + return self.lm_predictions.lm_head.dense + + def set_output_embeddings(self, new_embeddings): + if self.legacy: + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + else: + self.lm_predictions.lm_head.dense = new_embeddings + self.lm_predictions.lm_head.bias = new_embeddings.bias + + @auto_docstring + # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2 + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + if self.legacy: + prediction_scores = self.cls(sequence_output) + else: + prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = nn.Dropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +@auto_docstring( + custom_intro=""" + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """ +) +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = nn.Dropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @auto_docstring + # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2 + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2 + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + 
end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = nn.Dropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "DebertaV2ForMaskedLM", + "DebertaV2ForMultipleChoice", + "DebertaV2ForQuestionAnswering", + "DebertaV2ForSequenceClassification", + "DebertaV2ForTokenClassification", + "DebertaV2Model", + 
"DebertaV2PreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_tf_deberta_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..d71891ac19c00a8bc3ddb63e547b404570059a7a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/modeling_tf_deberta_v2.py @@ -0,0 +1,1879 @@ +# coding=utf-8 +# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 DeBERTa-v2 model.""" + +from __future__ import annotations + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_deberta_v2 import DebertaV2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge" + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaContextPooler with Deberta->DebertaV2 +class TFDebertaV2ContextPooler(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense") + self.dropout = TFDebertaV2StableDropout(config.pooler_dropout, name="dropout") + self.config = config + + def call(self, hidden_states, training: bool = False): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ context_token = hidden_states[:, 0] + context_token = self.dropout(context_token, training=training) + pooled_output = self.dense(context_token) + pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output) + return pooled_output + + @property + def output_dim(self) -> int: + return self.config.hidden_size + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.pooler_hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2 +class TFDebertaV2XSoftmax(keras.layers.Layer): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`tf.Tensor`): The input tensor that will apply softmax. + mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + """ + + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call(self, inputs: tf.Tensor, mask: tf.Tensor): + rmask = tf.logical_not(tf.cast(mask, tf.bool)) + output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs) + output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis) + output = tf.where(rmask, 0.0, output) + return output + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2 +class TFDebertaV2StableDropout(keras.layers.Layer): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prob + + @tf.custom_gradient + def xdropout(self, inputs): + """ + Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob. 
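        Worked example (added commentary, not from the original source): with `drop_prob = 0.1`, the
        elements that survive are multiplied by `1.0 / (1 - 0.1) ≈ 1.11`, which is exactly the `scale`
        tensor computed below, and the custom gradient applies the same mask and scale to the upstream
        gradient, so training matches standard inverted dropout.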
+ """ + mask = tf.cast( + 1 + - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)), + tf.bool, + ) + scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype) + if self.drop_prob > 0: + inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale + + def grad(upstream): + if self.drop_prob > 0: + return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale + else: + return upstream + + return inputs, grad + + def call(self, inputs: tf.Tensor, training: tf.Tensor = False): + if training: + return self.xdropout(inputs) + return inputs + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2 +class TFDebertaV2SelfOutput(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(config.hidden_size, name="dense") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def call(self, hidden_states, input_tensor, training: bool = False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2 +class TFDebertaV2Attention(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + self.self = TFDebertaV2DisentangledSelfAttention(config, name="self") + self.dense_output = TFDebertaV2SelfOutput(config, name="output") + self.config = config + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + rel_embeddings: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.self( + hidden_states=input_tensor, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + ) + if query_states is None: + query_states = input_tensor + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=query_states, training=training + ) + + output = (attention_output,) + self_outputs[1:] + + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate 
with Deberta->DebertaV2 +class TFDebertaV2Intermediate(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2 +class TFDebertaV2Output(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2 +class TFDebertaV2Layer(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.attention = TFDebertaV2Attention(config, name="attention") + self.intermediate = TFDebertaV2Intermediate(config, name="intermediate") + self.bert_output = TFDebertaV2Output(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + rel_embeddings: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor]: + attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + ) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, 
training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + + +class TFDebertaV2ConvLayer(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.kernel_size = getattr(config, "conv_kernel_size", 3) + # groups = getattr(config, "conv_groups", 1) + self.conv_act = get_tf_activation(getattr(config, "conv_act", "tanh")) + self.padding = (self.kernel_size - 1) // 2 + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def build(self, input_shape=None): + if self.built: + return + self.built = True + with tf.name_scope("conv"): + self.conv_kernel = self.add_weight( + name="kernel", + shape=[self.kernel_size, self.config.hidden_size, self.config.hidden_size], + initializer=get_initializer(self.config.initializer_range), + ) + self.conv_bias = self.add_weight( + name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer() + ) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + def call( + self, hidden_states: tf.Tensor, residual_states: tf.Tensor, input_mask: tf.Tensor, training: bool = False + ) -> tf.Tensor: + out = tf.nn.conv2d( + tf.expand_dims(hidden_states, 1), + tf.expand_dims(self.conv_kernel, 0), + strides=1, + padding=[[0, 0], [0, 0], [self.padding, self.padding], [0, 0]], + ) + out = tf.squeeze(tf.nn.bias_add(out, self.conv_bias), 1) + rmask = tf.cast(1 - input_mask, tf.bool) + out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out) + out = self.dropout(out, training=training) + out = self.conv_act(out) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)): + if len(shape_list(input_mask)) == 4: + input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1) + input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype) + + output_states = output * input_mask + + return output_states + + +class TFDebertaV2Encoder(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.layer = [TFDebertaV2Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.relative_attention = getattr(config, "relative_attention", False) + self.config = config + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) 
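            # Illustrative note (added commentary, not from the original source): when
            # position_buckets > 0, relative distances are first mapped into log-spaced buckets
            # (see make_log_bucket_position further below), so the relative-position embedding table
            # created in build() only needs 2 * position_buckets rows instead of
            # 2 * max_relative_positions rows; e.g. position_buckets = 256 gives a 512-row table.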
+ self.pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets * 2 + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + + self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.relative_attention: + self.rel_embeddings = self.add_weight( + name="rel_embeddings.weight", + shape=[self.pos_ebd_size, self.config.hidden_size], + initializer=get_initializer(self.config.initializer_range), + ) + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, self.config.hidden_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if len(shape_list(attention_mask)) <= 2: + extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2) + attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1) + attention_mask = tf.cast(attention_mask, tf.uint8) + elif len(shape_list(attention_mask)) == 3: + attention_mask = tf.expand_dims(attention_mask, 1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2] + relative_pos = build_relative_position( + q, + shape_list(hidden_states)[-2], + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + return relative_pos + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + if len(shape_list(attention_mask)) <= 2: + input_mask = attention_mask + else: + input_mask = tf.cast(tf.math.reduce_sum(attention_mask, axis=-2) > 0, dtype=tf.uint8) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + next_kv = hidden_states + + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + layer_outputs = layer_module( + hidden_states=next_kv, + attention_mask=attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + training=training, + 
) + output_states = layer_outputs[0] + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = tf.math.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = tf.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, tf.math.abs(relative_pos)) + log_pos = tf.math.ceil( + tf.cast(tf.math.log(abs_pos / mid), tf.float32) + / tf.cast(tf.math.log((max_position - 1) / mid), tf.float32) + * tf.cast(mid - 1, tf.float32) # in graph mode + ) + tf.cast(mid, tf.float32) + bucket_pos = tf.cast( + tf.where(abs_pos <= mid, tf.cast(relative_pos, tf.float32), log_pos * tf.cast(sign, tf.float32)), tf.int32 + ) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `tf.Tensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = tf.range(query_size, dtype=tf.int32) + k_ids = tf.range(key_size, dtype=tf.int32) + rel_pos_ids = q_ids[:, None] - tf.tile(tf.expand_dims(k_ids, axis=0), [shape_list(q_ids)[0], 1]) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0) + return tf.cast(rel_pos_ids, tf.int64) + + +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + shapes = [ + shape_list(query_layer)[0], + shape_list(query_layer)[1], + shape_list(query_layer)[2], + shape_list(relative_pos)[-1], + ] + return tf.broadcast_to(c2p_pos, shapes) + + +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + shapes = [ + shape_list(query_layer)[0], + shape_list(query_layer)[1], + shape_list(key_layer)[-2], + shape_list(key_layer)[-2], + ] + return tf.broadcast_to(c2p_pos, shapes) + + +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]] + return tf.broadcast_to(pos_index, shapes) + + +def take_along_axis(x, indices): + # Only a valid port of np.take_along_axis when the gather axis is -1 + + # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239 + if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy): + # [B, S, P] -> [B, S, P, D] + one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype) + + # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x) + # 
grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P] + gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x) + + # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls + else: + gathered = tf.gather(x, indices, batch_dims=2) + + return gathered + + +class TFDebertaV2DisentangledSelfAttention(keras.layers.Layer): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="query_proj", + use_bias=True, + ) + self.key_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="key_proj", + use_bias=True, + ) + self.value_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="value_proj", + use_bias=True, + ) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout") + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="pos_proj", + use_bias=True, + ) + if "p2c" in self.pos_att_type: + self.pos_query_proj = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + name="pos_q_proj", + ) + self.softmax = TFDebertaV2XSoftmax(axis=-1) + self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout") + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor: + tensor_shape = shape_list(tensor) + # In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None + shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads] + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=shape) + tensor = tf.transpose(tensor, 
perm=[0, 2, 1, 3]) + x_shape = shape_list(tensor) + tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]]) + return tensor + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + query_states: tf.Tensor | None = None, + relative_pos: tf.Tensor | None = None, + rel_embeddings: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor]: + """ + Call the module + + Args: + hidden_states (`tf.Tensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`tf.Tensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + return_att (`bool`, *optional*): + Whether return the attention matrix. + + query_states (`tf.Tensor`, *optional*): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`tf.Tensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`tf.Tensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype)) + attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = tf.reshape( + attention_scores, + (-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]), + ) + + # bsz x height x length x dimension + attention_probs = self.softmax(attention_scores, attention_mask) + attention_probs = self.dropout(attention_probs, training=training) + context_layer = tf.matmul( + tf.reshape(attention_probs, [-1, shape_list(attention_probs)[-2], shape_list(attention_probs)[-1]]), + value_layer, + ) + context_layer = tf.transpose( + tf.reshape( + context_layer, + [-1, self.num_attention_heads, shape_list(context_layer)[-2], shape_list(context_layer)[-1]], + ), + [0, 2, 1, 3], + ) + # Set the final dimension here explicitly. 
+ # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing + # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput + # requires final input dimension to be defined + context_layer_shape = shape_list(context_layer) + new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = shape_list(query_layer)[-2] + relative_pos = build_relative_position( + q, + shape_list(key_layer)[-2], + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + shape_list_pos = shape_list(relative_pos) + if len(shape_list_pos) == 2: + relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) + elif len(shape_list_pos) == 3: + relative_pos = tf.expand_dims(relative_pos, 1) + # bsz x height x query x key + elif len(shape_list_pos) != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}") + + att_span = self.pos_ebd_size + rel_embeddings = tf.expand_dims( + rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :], 0 + ) + if self.share_att_key: + pos_query_layer = tf.tile( + self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads), + [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], + ) + pos_key_layer = tf.tile( + self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads), + [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = tf.tile( + self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads), + [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = tf.tile( + self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads), + [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype)) + c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 2, 1])) + c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = take_along_axis( + c2p_att, + tf.broadcast_to( + tf.squeeze(c2p_pos, 0), + [shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]], + ), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type: + scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype)) + if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]: + r_pos = build_relative_position( + shape_list(key_layer)[-2], + shape_list(key_layer)[-2], + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ) + r_pos = tf.expand_dims(r_pos, 0) + else: + r_pos = relative_pos + + p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1) + + p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1])) + p2c_att = 
tf.transpose( + take_along_axis( + p2c_att, + tf.broadcast_to( + tf.squeeze(p2c_pos, 0), + [shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]], + ), + ), + [0, 2, 1], + ) + score += p2c_att / scale + + return score + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query_proj", None) is not None: + with tf.name_scope(self.query_proj.name): + self.query_proj.build([None, None, self.config.hidden_size]) + if getattr(self, "key_proj", None) is not None: + with tf.name_scope(self.key_proj.name): + self.key_proj.build([None, None, self.config.hidden_size]) + if getattr(self, "value_proj", None) is not None: + with tf.name_scope(self.value_proj.name): + self.value_proj.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "pos_dropout", None) is not None: + with tf.name_scope(self.pos_dropout.name): + self.pos_dropout.build(None) + if getattr(self, "pos_key_proj", None) is not None: + with tf.name_scope(self.pos_key_proj.name): + self.pos_key_proj.build([None, None, self.config.hidden_size]) + if getattr(self, "pos_query_proj", None) is not None: + with tf.name_scope(self.pos_query_proj.name): + self.pos_query_proj.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2 +class TFDebertaV2Embeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.position_biased_input = getattr(config, "position_biased_input", True) + self.initializer_range = config.initializer_range + if self.embedding_size != config.hidden_size: + self.embed_proj = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="embed_proj", + use_bias=False, + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + if self.config.type_vocab_size > 0: + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + else: + self.token_type_embeddings = None + + with tf.name_scope("position_embeddings"): + if self.position_biased_input: + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + else: + self.position_embeddings = None + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dropout", None) is not None: + 
with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "embed_proj", None) is not None: + with tf.name_scope(self.embed_proj.name): + self.embed_proj.build([None, None, self.embedding_size]) + + def call( + self, + input_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + mask: tf.Tensor | None = None, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError("Need to provide either `input_ids` or `input_embeds`.") + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + + final_embeddings = inputs_embeds + if self.position_biased_input: + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + final_embeddings += position_embeds + if self.config.type_vocab_size > 0: + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings += token_type_embeds + + if self.embedding_size != self.hidden_size: + final_embeddings = self.embed_proj(final_embeddings) + + final_embeddings = self.LayerNorm(final_embeddings) + + if mask is not None: + if len(shape_list(mask)) != len(shape_list(final_embeddings)): + if len(shape_list(mask)) == 4: + mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1) + mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype) + + final_embeddings = final_embeddings * mask + + final_embeddings = self.dropout(final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPredictionHeadTransform with Deberta->DebertaV2 +class TFDebertaV2PredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.dense = keras.layers.Dense( + units=self.embedding_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.embedding_size]) + + +# Copied from 
transformers.models.deberta.modeling_tf_deberta.TFDebertaLMPredictionHead with Deberta->DebertaV2 +class TFDebertaV2LMPredictionHead(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + + self.transform = TFDebertaV2PredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.input_embeddings = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "transform", None) is not None: + with tf.name_scope(self.transform.name): + self.transform.build(None) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self) -> dict[str, tf.Variable]: + return {"bias": self.bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.transform(hidden_states=hidden_states) + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOnlyMLMHead with Deberta->DebertaV2 +class TFDebertaV2OnlyMLMHead(keras.layers.Layer): + def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + self.predictions = TFDebertaV2LMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaMainLayer with Deberta->DebertaV2 +class TFDebertaV2MainLayer(keras.layers.Layer): + config_class = DebertaV2Config + + def __init__(self, config: DebertaV2Config, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFDebertaV2Embeddings(config, name="embeddings") + self.encoder = TFDebertaV2Encoder(config, name="encoder") + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+            class `PreTrainedModel`.
+        """
+        raise NotImplementedError
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=input_shape, value=1)
+
+        if token_type_ids is None:
+            token_type_ids = tf.fill(dims=input_shape, value=0)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            mask=attention_mask,
+            training=training,
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return TFBaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+
+
+# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPreTrainedModel with Deberta->DebertaV2
+class TFDebertaV2PreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DebertaV2Config
+    base_model_prefix = "deberta"
+
+
+DEBERTA_START_DOCSTRING = r"""
+    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
+    Attention](https://huggingface.co/papers/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built
+    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those two
+    improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
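+
+    For illustration, the sketch below shows the same forward pass in both styles (the checkpoint name is only an
+    example; any DeBERTa-v2 checkpoint with TF weights would behave the same way):
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFDebertaV2Model
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
+    >>> model = TFDebertaV2Model.from_pretrained("microsoft/deberta-v3-small")
+    >>> encoded = tokenizer("Hello world", return_tensors="tf")
+
+    >>> # keyword arguments, PyTorch style
+    >>> outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
+    >>> # everything packed into the first positional argument, Keras style
+    >>> outputs = model({"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]})
+    ```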
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Parameters:
+        config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaModel with Deberta->DebertaV2
+class TFDebertaV2Model(TFDebertaV2PreTrainedModel):
+    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool | None = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        outputs = self.deberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deberta", None) is not None:
+            with tf.name_scope(self.deberta.name):
+                self.deberta.build(None)
+
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
+# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForMaskedLM with Deberta->DebertaV2
+class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelingLoss):
+    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `TFDebertaV2ForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+ ) + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.mlm = TFDebertaV2OnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForSequenceClassification with Deberta->DebertaV2 +class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: DebertaV2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.pooler = TFDebertaV2ContextPooler(config, name="pooler") + + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = TFDebertaV2StableDropout(drop_out, name="cls_dropout") + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.output_dim = self.pooler.output_dim + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
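+
+            For example (an illustrative sketch; `model` stands for an already loaded
+            `TFDebertaV2ForSequenceClassification` and `inputs` for a batch encoded with its tokenizer):
+
+            ```python
+            >>> import tensorflow as tf
+
+            >>> labels = tf.constant([1, 0])  # one class id per example in the batch
+            >>> outputs = model(**inputs, labels=labels)
+            >>> loss, logits = outputs.loss, outputs.logits
+            ```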
+ """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + pooled_output = self.pooler(sequence_output, training=training) + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.output_dim]) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForTokenClassification with Deberta->DebertaV2 +class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config: DebertaV2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
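+
+            For example (an illustrative sketch; `model` stands for an already loaded
+            `TFDebertaV2ForTokenClassification` and `inputs` for a batch encoded with its tokenizer):
+
+            ```python
+            >>> import tensorflow as tf
+
+            >>> # one label id per token, shape (batch_size, sequence_length)
+            >>> labels = tf.constant([[0, 1, 2, 0]])
+            >>> outputs = model(**inputs, labels=labels)
+            >>> loss, logits = outputs.loss, outputs.logits
+            ```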
+ """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForQuestionAnswering with Deberta->DebertaV2 +class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config: DebertaV2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. 
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + outputs = self.deberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + # _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] + # _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: DebertaV2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.pooler = TFDebertaV2ContextPooler(config, name="pooler") + self.classifier = keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.output_dim = self.pooler.output_dim + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_position_ids = ( + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.deberta( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + pooled_output = self.pooler(sequence_output, training=training) + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "deberta", None) is not None: + with tf.name_scope(self.deberta.name): + self.deberta.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.output_dim]) + + +__all__ = [ + "TFDebertaV2ForMaskedLM", + "TFDebertaV2ForQuestionAnswering", + "TFDebertaV2ForMultipleChoice", + "TFDebertaV2ForSequenceClassification", + "TFDebertaV2ForTokenClassification", + "TFDebertaV2Model", + "TFDebertaV2PreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2.py b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e192474c3dcdcc184b7041a44953df740540cf76 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2.py @@ -0,0 +1,499 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model DeBERTa.""" + +import os +import unicodedata +from typing import Any, Optional + +import sentencepiece as sp + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "spm.model"} + + +@requires(backends=("sentencepiece",)) +class DebertaV2Tokenizer(PreTrainedTokenizer): + r""" + Constructs a DeBERTa-v2 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. 
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=False, + split_by_punct=False, + bos_token="[CLS]", + eos_token="[SEP]", + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + self._tokenizer = SPMTokenizer( + vocab_file, None, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs + ) + unk_token = AddedToken(unk_token, normalized=True, special=True) if isinstance(unk_token, str) else unk_token + super().__init__( + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + self._tokenizer.special_tokens = self.all_special_tokens + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._tokenizer.vocab + + def get_vocab(self): + vocab = self.vocab.copy() + vocab.update(self.get_added_vocab()) + return vocab + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + if self.do_lower_case: + text = text.lower() + return self._tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self._tokenizer.spm.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + return self._tokenizer.decode(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
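+
+        For example (an illustrative sketch; the token ids are arbitrary, and the `1`/`2` in the output assume the
+        usual DeBERTa-v2 `[CLS]`/`[SEP]` ids as laid out in the SentencePiece vocabulary):
+
+        ```python
+        >>> tokenizer.build_inputs_with_special_tokens([5, 6])
+        [1, 5, 6, 2]
+        >>> tokenizer.build_inputs_with_special_tokens([5, 6], [7, 8])
+        [1, 5, 6, 2, 7, 8, 2]
+        ```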
+ """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", False) + if is_split_into_words or add_prefix_space: + text = " " + text + return (text, kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) + + +class SPMTokenizer: + r""" + Constructs a tokenizer based on [SentencePiece](https://github.com/google/sentencepiece). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. 
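+
+    For example, subword regularization could be enabled like this (a sketch; `"spm.model"` is a placeholder path to
+    an existing SentencePiece model file):
+
+    ```python
+    >>> tok = SPMTokenizer("spm.model", special_tokens=[], sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1})
+    >>> tok.tokenize("New York")  # sampling makes the segmentation stochastic across calls
+    ```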
+ """ + + def __init__( + self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[dict[str, Any]] = None + ): + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + if not os.path.exists(vocab_file): + raise FileNotFoundError(f"{vocab_file} does not exist!") + spm.load(vocab_file) + bpe_vocab_size = spm.GetPieceSize() + # Token map + # 0+1 + # 1+1 + # 2+1 + self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} + self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] + # self.vocab['[PAD]'] = 0 + # self.vocab['[CLS]'] = 1 + # self.vocab['[SEP]'] = 2 + # self.vocab['[UNK]'] = 3 + + self.spm = spm + self.special_tokens = special_tokens + + def __getstate__(self): + state = self.__dict__.copy() + state["spm"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + self.spm.Load(self.vocab_file) + + def tokenize(self, text): + return self._encode_as_pieces(text) + + def convert_ids_to_tokens(self, ids): + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def decode(self, tokens, start=-1, end=-1, raw_text=None): + if raw_text is None: + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.spm.decode_pieces(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.spm.decode_pieces(current_sub_tokens) + return out_string.strip() + else: + words = self.split_to_words(raw_text) + word_tokens = [self.tokenize(w) for w in words] + token2words = [0] * len(tokens) + tid = 0 + for i, w in enumerate(word_tokens): + for k, t in enumerate(w): + token2words[tid] = i + tid += 1 + word_start = token2words[start] + word_end = token2words[end] if end < len(tokens) else len(words) + text = "".join(words[word_start:word_end]) + return text + + # TODO add a deprecation cycle as this can have different behaviour from our API + def add_special_token(self, token): + if token not in self.special_tokens: + self.special_tokens.append(token) + if token not in self.vocab: + self.vocab[token] = len(self.vocab) - 1 + self.ids_to_tokens.append(token) + return self.id(token) + + def part_of_whole_word(self, token, is_bos=False): + logger.warning_once( + "The `DebertaTokenizer.part_of_whole_word` method is deprecated and will be removed in `transformers==4.35`" + ) + if is_bos: + return True + if ( + len(token) == 1 + and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0])) + ) or token in self.special_tokens: + return False + + word_start = b"\xe2\x96\x81".decode("utf-8") + return not token.startswith(word_start) + + def pad(self): + return "[PAD]" + + def bos(self): + return "[CLS]" + + def eos(self): + return "[SEP]" + + def unk(self): + return "[UNK]" + + def mask(self): + return "[MASK]" + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + logger.warning_once( + "The `DebertaTokenizer.id` method is 
deprecated and will be removed in `transformers==4.35`" + ) + return self.vocab.get(sym, 1) + + def _encode_as_pieces(self, text): + text = convert_to_unicode(text) + if self.split_by_punct: + words = self._run_split_on_punc(text) + pieces = [self.spm.encode(w, out_type=str) for w in words] + return [p for w in pieces for p in w] + else: + return self.spm.encode(text, out_type=str) + + def split_to_words(self, text): + pieces = self._encode_as_pieces(text) + word_start = b"\xe2\x96\x81".decode("utf-8") + words = [] + offset = 0 + prev_end = 0 + for i, p in enumerate(pieces): + if p.startswith(word_start): + if offset > prev_end: + words.append(text[prev_end:offset]) + prev_end = offset + w = p.replace(word_start, "") + else: + w = p + try: + s = text.index(w, offset) + pn = "" + k = i + 1 + while k < len(pieces): + pn = pieces[k].replace(word_start, "") + if len(pn) > 0: + break + k += 1 + + if len(pn) > 0 and pn in text[offset:s]: + offset = offset + 1 + else: + offset = s + len(w) + except Exception: + offset = offset + 1 + + if prev_end < offset: + words.append(text[prev_end:offset]) + + return words + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def save_pretrained(self, path: str, filename_prefix: Optional[str] = None): + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + with open(full_path, "wb") as fs: + fs.write(self.spm.serialized_model_proto()) + return (full_path,) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
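+    # The four ranges below are the ASCII blocks ! through / (33-47), : through @ (58-64),
+    # [ through ` (91-96) and { through ~ (123-126).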
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise TypeError(f"Unsupported string type: {type(text)}") + + +__all__ = ["DebertaV2Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..1bad9684dd84a5c33fe8584ebb113e5bc3461e83 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization class for model DeBERTa.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...file_utils import is_sentencepiece_available +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging + + +if is_sentencepiece_available(): + from .tokenization_deberta_v2 import DebertaV2Tokenizer +else: + DebertaV2Tokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spm.model", "tokenizer_file": "tokenizer.json"} + + +class DebertaV2TokenizerFast(PreTrainedTokenizerFast): + r""" + Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. 
two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = DebertaV2Tokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=False, + split_by_punct=False, + bos_token="[CLS]", + eos_token="[SEP]", + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs, + ) -> None: + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + **kwargs, + ) + + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. 
This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["DebertaV2TokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/depth_anything/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7425e37e0399c792155a88488045176fb3b5e7a5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_depth_anything import * + from .modeling_depth_anything import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/depth_anything/configuration_depth_anything.py b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/configuration_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..65884fe67c3cc85afe5b474eb5756665b8b20e11 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/configuration_depth_anything.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DepthAnything model configuration""" + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class DepthAnythingConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DepthAnythingModel`]. It is used to instantiate a DepthAnything + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DepthAnything + [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches to extract from the backbone features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + reassemble_hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the reassemble layers. + reassemble_factors (`list[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`): + The up/downsampling factors of the reassemble layers. + neck_hidden_sizes (`list[str]`, *optional*, defaults to `[48, 96, 192, 384]`): + The hidden sizes to project to for the feature maps of the backbone. + fusion_hidden_size (`int`, *optional*, defaults to 64): + The number of channels before fusion. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the depth estimation head. 
+ head_hidden_size (`int`, *optional*, defaults to 32): + The number of output channels in the second convolution of the depth estimation head. + depth_estimation_type (`str`, *optional*, defaults to `"relative"`): + The type of depth estimation to use. Can be one of `["relative", "metric"]`. + max_depth (`float`, *optional*): + The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models + and 80 for outdoor models. For "relative" depth estimation, this value is ignored. + + Example: + + ```python + >>> from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation + + >>> # Initializing a DepthAnything small style configuration + >>> configuration = DepthAnythingConfig() + + >>> # Initializing a model from the DepthAnything small style configuration + >>> model = DepthAnythingForDepthEstimation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "depth_anything" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + patch_size=14, + initializer_range=0.02, + reassemble_hidden_size=384, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_sizes=[48, 96, 192, 384], + fusion_hidden_size=64, + head_in_index=-1, + head_hidden_size=32, + depth_estimation_type="relative", + max_depth=None, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.") + backbone_config = CONFIG_MAPPING["dinov2"]( + image_size=518, + hidden_size=384, + num_attention_heads=6, + out_indices=[9, 10, 11, 12], + apply_layernorm=True, + reshape_hidden_states=False, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.reassemble_hidden_size = reassemble_hidden_size + self.patch_size = patch_size + self.initializer_range = initializer_range + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.head_hidden_size = head_hidden_size + if depth_estimation_type not in ["relative", "metric"]: + raise ValueError("depth_estimation_type must be one of ['relative', 'metric']") + self.depth_estimation_type = depth_estimation_type + self.max_depth = max_depth if max_depth else 1 + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
Returns: + `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["DepthAnythingConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/depth_anything/modeling_depth_anything.py b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/modeling_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7d741312041fede83c34467f73282b44ab5430 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/depth_anything/modeling_depth_anything.py @@ -0,0 +1,427 @@ +# coding=utf-8 +# Copyright 2024 TikTok and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Depth Anything model.""" + +from typing import Optional, Union + +import torch +from torch import nn + +from ...modeling_outputs import DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.backbone_utils import load_backbone +from .configuration_depth_anything import DepthAnythingConfig + + +logger = logging.get_logger(__name__) + +# General docstring + + +class DepthAnythingReassembleLayer(nn.Module): + def __init__(self, config, channels, factor): + super().__init__() + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) + + # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + + +class DepthAnythingReassembleStage(nn.Module): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Take the patch embeddings and reshape them to image-like feature representations. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. 
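+
+    Example (shapes only, assuming the default small configuration with `reassemble_hidden_size=384`,
+    `neck_hidden_sizes=[48, 96, 192, 384]` and `reassemble_factors=[4, 2, 1, 0.5]`): a backbone hidden
+    state of shape `(batch_size, patch_height * patch_width + 1, 384)` has its class token dropped, is
+    reshaped to `(batch_size, 384, patch_height, patch_width)`, projected to the stage's channel count,
+    and finally up- or downsampled by the corresponding factor.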
+ """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors): + self.layers.append(DepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + + def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]: + """ + Args: + hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (batch_size, num_channels, height, width) + hidden_state = hidden_state[:, 1:] + batch_size, _, num_channels = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +class DepthAnythingPreActResidualLayer(nn.Module): + """ + ResidualConvUnit, pre-activate residual unit. + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.activation1 = nn.ReLU() + self.convolution1 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + self.activation2 = nn.ReLU() + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution1(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + return hidden_state + residual + + +class DepthAnythingFeatureFusionLayer(nn.Module): + """Feature fusion layer, merges feature maps from different stages. + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. 
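+
+    When a `residual` is supplied, it is resized to match `hidden_state` if their shapes differ, passed
+    through the first pre-activation residual unit and added to `hidden_state`; the sum goes through the
+    second residual unit, is upsampled (by a factor of 2, or to `size` when given) with bilinear
+    interpolation, and is finally projected by a 1x1 convolution.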
+ """ + + def __init__(self, config): + super().__init__() + + self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) + + self.residual_layer1 = DepthAnythingPreActResidualLayer(config) + self.residual_layer2 = DepthAnythingPreActResidualLayer(config) + + def forward(self, hidden_state, residual=None, size=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class DepthAnythingFeatureFusionStage(nn.Module): + # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage.__init__ with DPT->DepthAnything + def __init__(self, config: DepthAnythingConfig): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(DepthAnythingFeatureFusionLayer(config)) + + def forward(self, hidden_states, size=None): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None + + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything +# avoiding sdpa and flash_attn_2 support, it's done in the backend +@auto_docstring +class DepthAnythingPreTrainedModel(PreTrainedModel): + config: DepthAnythingConfig + base_model_prefix = "depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DepthAnythingNeck(nn.Module): + """ + DepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as + input and produces another list of tensors as output. For DepthAnything, it includes 2 stages: + + * DepthAnythingReassembleStage + * DepthAnythingFeatureFusionStage. + + Args: + config (dict): config dict. 
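+
+    The neck expects one feature map per entry in `config.neck_hidden_sizes` and returns a list of fused
+    feature maps, each with `config.fusion_hidden_size` channels (64 in the default small configuration).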
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.reassemble_stage = DepthAnythingReassembleStage(config) + + self.convs = nn.ModuleList() + for channel in config.neck_hidden_sizes: + self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + + # fusion + self.fusion_stage = DepthAnythingFeatureFusionStage(config) + + def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]: + """ + Args: + hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features) + + return output + + +class DepthAnythingDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's + supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation + type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining. + """ + + def __init__(self, config): + super().__init__() + + self.head_in_index = config.head_in_index + self.patch_size = config.patch_size + + features = config.fusion_hidden_size + self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) + self.activation1 = nn.ReLU() + self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) + if config.depth_estimation_type == "relative": + self.activation2 = nn.ReLU() + elif config.depth_estimation_type == "metric": + self.activation2 = nn.Sigmoid() + else: + raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") + self.max_depth = config.max_depth + + def forward(self, hidden_states: list[torch.Tensor], patch_height, patch_width) -> torch.Tensor: + hidden_states = hidden_states[self.head_in_index] + + predicted_depth = self.conv1(hidden_states) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (int(patch_height * self.patch_size), int(patch_width * self.patch_size)), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) * self.max_depth + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + + return predicted_depth + + +@auto_docstring( + custom_intro=""" + Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. 
+ """ +) +class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel): + _no_split_modules = ["DPTViTEmbeddings"] + + def __init__(self, config): + super().__init__(config) + + self.backbone = load_backbone(config) + self.neck = DepthAnythingNeck(config) + self.head = DepthAnythingDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. + + Examples: + ```python + >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 255 / predicted_depth.max() + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint8")) + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + hidden_states = self.neck(hidden_states, patch_height, patch_width) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["DepthAnythingForDepthEstimation", "DepthAnythingPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/__init__.py 
b/venv/lib/python3.13/site-packages/transformers/models/detr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8aae5e70381b21772bdec395693c007a9c02b7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_detr import * + from .feature_extraction_detr import * + from .image_processing_detr import * + from .image_processing_detr_fast import * + from .modeling_detr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/configuration_detr.py b/venv/lib/python3.13/site-packages/transformers/models/detr/configuration_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..c9540382927cf2c9849b188ed4b53dea3315f4c4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/configuration_detr.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DETR model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class DetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DetrModel`]. It is used to instantiate a DETR + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DETR + [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which + case it will default to `ResNetConfig()`. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_queries (`int`, *optional*, defaults to 100): + Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can + detect in a single image. For COCO, we recommend 100 queries. + d_model (`int`, *optional*, defaults to 256): + This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556) + for more details. + auxiliary_loss (`bool`, *optional*, defaults to `False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + position_embedding_type (`str`, *optional*, defaults to `"sine"`): + Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. + backbone (`str`, *optional*, defaults to `"resnet50"`): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. 
If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, `True`): + Whether to use pretrained weights for the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when + `use_timm_backbone` = `True`. + class_cost (`float`, *optional*, defaults to 1): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + mask_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Focal loss in the panoptic segmentation loss. + dice_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + + Examples: + + ```python + >>> from transformers import DetrConfig, DetrModel + + >>> # Initializing a DETR facebook/detr-resnet-50 style configuration + >>> configuration = DetrConfig() + + >>> # Initializing a model (with random weights) from the facebook/detr-resnet-50 style configuration + >>> model = DetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "detr" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + use_timm_backbone=True, + backbone_config=None, + num_channels=3, + num_queries=100, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + auxiliary_loss=False, + position_embedding_type="sine", + backbone="resnet50", + use_pretrained_backbone=True, + backbone_kwargs=None, + dilation=False, + class_cost=1, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + **kwargs, + ): + # We default to values which were previously hard-coded in the model. This enables configurability of the config + # while keeping the default behavior the same. 
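+        # Illustration of the defaulting below (a sketch, not additional logic): with the default
+        # arguments (`use_timm_backbone=True`, `backbone_kwargs=None`, `dilation=False`, `num_channels=3`)
+        # the timm branch ends up with
+        #     backbone_kwargs == {"out_indices": [1, 2, 3, 4], "in_chans": 3}
+        # and `dilation=True` additionally sets `backbone_kwargs["output_stride"] = 16`.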
+ if use_timm_backbone and backbone_kwargs is None: + backbone_kwargs = {} + if dilation: + backbone_kwargs["output_stride"] = 16 + backbone_kwargs["out_indices"] = [1, 2, 3, 4] + backbone_kwargs["in_chans"] = num_channels + # Backwards compatibility + elif not use_timm_backbone and backbone in (None, "resnet50"): + if backbone_config is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + backbone = None + # set timm attributes to None + dilation = None + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.use_timm_backbone = use_timm_backbone + self.backbone_config = backbone_config + self.num_channels = num_channels + self.num_queries = num_queries + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.num_hidden_layers = encoder_layers + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.backbone_kwargs = backbone_kwargs + self.dilation = dilation + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.eos_coefficient = eos_coefficient + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + @classmethod + def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. 
+ Returns: + [`DetrConfig`]: An instance of a configuration object + """ + return cls(backbone_config=backbone_config, **kwargs) + + +class DetrOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("pixel_mask", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + @property + def default_onnx_opset(self) -> int: + return 12 + + +__all__ = ["DetrConfig", "DetrOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/feature_extraction_detr.py b/venv/lib/python3.13/site-packages/transformers/models/detr/feature_extraction_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..a81f83c8c313bdb8a904f0b359360c0e100a83d9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/feature_extraction_detr.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for DETR.""" + +import warnings + +from ...image_transforms import rgb_to_id as _rgb_to_id +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_detr import DetrImageProcessor + + +logger = logging.get_logger(__name__) + + +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + +@requires(backends=("vision",)) +class DetrFeatureExtractor(DetrImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class DetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use DetrImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["DetrFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr.py b/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..f29bd48a5934cca8224a91d0cc25ce54f79e169f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr.py @@ -0,0 +1,2049 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for DETR.""" + +import io +import pathlib +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + id_to_rgb, + pad, + rescale, + resize, + rgb_to_id, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + is_flax_available, + is_jax_tensor, + is_scipy_available, + is_tf_available, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.import_utils import requires + + +if is_torch_available(): + import torch + from torch import nn + + +if is_vision_available(): + import PIL + + +if is_scipy_available(): + import scipy.special + import scipy.stats + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +# From the original repo: https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/datasets/transforms.py#L76 +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + +def get_image_size_for_max_height_width( + input_image: np.ndarray, + max_height: int, + max_width: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. 
+ + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + input_image (`np.ndarray`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, tuple[int, int], list[int]], + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. If the desired output size + is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output + image size is computed by keeping the aspect ratio of the input image size. + + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `tuple[int, int]` or `list[int]`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + if isinstance(size, (list, tuple)): + return size + + return get_size_with_aspect_ratio(image_size, size, max_size) + + +def get_numpy_to_framework_fn(arr) -> Callable: + """ + Returns a function that converts a numpy array to the framework of the input array. + + Args: + arr (`np.ndarray`): The array to convert. + """ + if isinstance(arr, np.ndarray): + return np.array + if is_tf_available() and is_tf_tensor(arr): + import tensorflow as tf + + return tf.convert_to_tensor + if is_torch_available() and is_torch_tensor(arr): + import torch + + return torch.tensor + if is_flax_available() and is_jax_tensor(arr): + import jax.numpy as jnp + + return jnp.array + raise ValueError(f"Cannot convert arrays of type {type(arr)}") + + +def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """ + Squeezes an array, but only if the axis specified has dim 1. + """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices +def max_across_indices(values: Iterable[Any]) -> list[Any]: + """ + Return the maximum value across all indices of an iterable of values. 
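+
+    For example, `max_across_indices([(1, 5), (3, 2)])` returns `[3, 5]`.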
+ """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width +def get_max_height_width( + images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> list[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33 +def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: + """ + Convert a COCO polygon annotation to a mask. + + Args: + segmentations (`list[list[float]]`): + List of polygons, each polygon represented by a list of x-y coordinates. + height (`int`): + Height of the mask. + width (`int`): + Width of the mask. + """ + try: + from pycocotools import mask as coco_mask + except ImportError: + raise ImportError("Pycocotools is not installed in your environment.") + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = np.asarray(mask, dtype=np.uint8) + mask = np.any(mask, axis=2) + masks.append(mask) + if masks: + masks = np.stack(masks, axis=0) + else: + masks = np.zeros((0, height, width), dtype=np.uint8) + + return masks + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50 +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by DETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. 
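+    # Annotations flagged `iscrowd` are dropped, and each COCO box `[top_left_x, top_left_y, width, height]`
+    # is converted below to `[x_min, y_min, x_max, y_max]`, clipped to the image, and discarded when it has
+    # no positive height or width. Illustrative values: `[10, 20, 30, 40]` becomes `[10, 20, 40, 60]`.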
+ annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + if return_segmentation_masks: + segmentation_masks = [obj["segmentation"] for obj in annotations] + masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width) + new_target["masks"] = masks[keep] + + return new_target + + +def masks_to_boxes(masks: np.ndarray) -> np.ndarray: + """ + Compute the bounding boxes around the provided panoptic segmentation masks. + + Args: + masks: masks in format `[number_masks, height, width]` where N is the number of masks + + Returns: + boxes: bounding boxes in format `[number_masks, 4]` in xyxy format + """ + if masks.size == 0: + return np.zeros((0, 4)) + + h, w = masks.shape[-2:] + y = np.arange(0, h, dtype=np.float32) + x = np.arange(0, w, dtype=np.float32) + # see https://github.com/pytorch/pytorch/issues/50276 + y, x = np.meshgrid(y, x, indexing="ij") + + x_mask = masks * np.expand_dims(x, axis=0) + x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1) + x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool))) + x_min = x.filled(fill_value=1e8) + x_min = x_min.reshape(x_min.shape[0], -1).min(-1) + + y_mask = masks * np.expand_dims(y, axis=0) + y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1) + y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool))) + y_min = y.filled(fill_value=1e8) + y_min = y_min.reshape(y_min.shape[0], -1).min(-1) + + return np.stack([x_min, y_min, x_max, y_max], 1) + + +def prepare_coco_panoptic_annotation( + image: np.ndarray, + target: dict, + masks_path: Union[str, pathlib.Path], + return_masks: bool = True, + input_data_format: Union[ChannelDimension, str] = None, +) -> dict: + """ + Prepare a coco panoptic annotation for DETR. 
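+
+    When `target` contains `segments_info`, the panoptic PNG referenced by `target["file_name"]` under
+    `masks_path` is decoded into segment ids, from which per-segment binary masks, bounding boxes, class
+    labels, `iscrowd` flags and areas are built.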
+ """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + annotation_path = pathlib.Path(masks_path) / target["file_name"] + + new_target = {} + new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64) + new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64) + new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64) + + if "segments_info" in target: + masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32) + masks = rgb_to_id(masks) + + ids = np.array([segment_info["id"] for segment_info in target["segments_info"]]) + masks = masks == ids[:, None, None] + masks = masks.astype(np.uint8) + if return_masks: + new_target["masks"] = masks + new_target["boxes"] = masks_to_boxes(masks) + new_target["class_labels"] = np.array( + [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["iscrowd"] = np.asarray( + [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["area"] = np.asarray( + [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32 + ) + + return new_target + + +def get_segmentation_image( + masks: np.ndarray, input_size: tuple, target_size: tuple, stuff_equiv_classes, deduplicate=False +): + h, w = input_size + final_h, final_w = target_size + + m_id = scipy.special.softmax(masks.transpose(0, 1), -1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = np.zeros((h, w), dtype=np.int64) + else: + m_id = m_id.argmax(-1).reshape(h, w) + + if deduplicate: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + for eq_id in equiv: + m_id[m_id == eq_id] = equiv[0] + + seg_img = id_to_rgb(m_id) + seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST) + return seg_img + + +def get_mask_area(seg_img: np.ndarray, target_size: tuple[int, int], n_classes: int) -> np.ndarray: + final_h, final_w = target_size + np_seg_img = seg_img.astype(np.uint8) + np_seg_img = np_seg_img.reshape(final_h, final_w, 3) + m_id = rgb_to_id(np_seg_img) + area = [(m_id == i).sum() for i in range(n_classes)] + return area + + +def score_labels_from_class_probabilities(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + probs = scipy.special.softmax(logits, axis=-1) + labels = probs.argmax(-1, keepdims=True) + scores = np.take_along_axis(probs, labels, axis=-1) + scores, labels = scores.squeeze(-1), labels.squeeze(-1) + return scores, labels + + +def post_process_panoptic_sample( + out_logits: np.ndarray, + masks: np.ndarray, + boxes: np.ndarray, + processed_size: tuple[int, int], + target_size: tuple[int, int], + is_thing_map: dict, + threshold=0.85, +) -> dict: + """ + Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample. + + Args: + out_logits (`torch.Tensor`): + The logits for this sample. + masks (`torch.Tensor`): + The predicted segmentation masks for this sample. + boxes (`torch.Tensor`): + The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y, + width, height)` and values between `[0, 1]`, relative to the size the image (disregarding padding). + processed_size (`tuple[int, int]`): + The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size + after data augmentation but before batching. 
+ target_size (`tuple[int, int]`): + The target size of the image, `(height, width)` corresponding to the requested final size of the + prediction. + is_thing_map (`Dict`): + A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not. + threshold (`float`, *optional*, defaults to 0.85): + The threshold used to binarize the segmentation masks. + """ + # we filter empty queries and detection below threshold + scores, labels = score_labels_from_class_probabilities(out_logits) + keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold) + + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_boxes = center_to_corners_format(boxes[keep]) + + if len(cur_boxes) != len(cur_classes): + raise ValueError("Not as many boxes as there are classes") + + cur_masks = masks[keep] + cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR) + cur_masks = safe_squeeze(cur_masks, 1) + b, h, w = cur_masks.shape + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.reshape(b, -1) + stuff_equiv_classes = defaultdict(list) + for k, label in enumerate(cur_classes): + if not is_thing_map[label]: + stuff_equiv_classes[label].append(k) + + seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True) + area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores)) + + # We filter out any mask that is too small + if cur_classes.size() > 0: + # We know filter empty masks as long as we find some + filtered_small = np.array([a <= 4 for a in area], dtype=bool) + while filtered_small.any(): + cur_masks = cur_masks[~filtered_small] + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True) + area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores)) + filtered_small = np.array([a <= 4 for a in area], dtype=bool) + else: + cur_classes = np.ones((1, 1), dtype=np.int64) + + segments_info = [ + {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a} + for i, (cat, a) in enumerate(zip(cur_classes, area)) + ] + del cur_classes + + with io.BytesIO() as out: + PIL.Image.fromarray(seg_img).save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + + return predictions + + +def resize_annotation( + annotation: dict[str, Any], + orig_size: tuple[int, int], + target_size: tuple[int, int], + threshold: float = 0.5, + resample: PILImageResampling = PILImageResampling.NEAREST, +): + """ + Resizes an annotation to a target size. + + Args: + annotation (`dict[str, Any]`): + The annotation dictionary. + orig_size (`tuple[int, int]`): + The original size of the input image. + target_size (`tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`): + The resampling filter to use when resizing the masks. 
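+
+    Example (illustrative values): resizing from `orig_size=(480, 640)` to `target_size=(240, 320)` scales
+    "boxes" by `(0.5, 0.5, 0.5, 0.5)`, scales "area" by `0.25`, and resizes "masks" with `resample` before
+    re-binarizing them against `threshold`.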
+ """ + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size)) + ratio_height, ratio_width = ratios + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = np.array([resize(mask, target_size, resample=resample) for mask in masks]) + masks = masks.astype(np.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + +# TODO - (Amy) make compatible with other frameworks +def binary_mask_to_rle(mask): + """ + Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + mask (`torch.Tensor` or `numpy.array`): + A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target + segment_id or class_id. + Returns: + `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE + format. + """ + if is_torch_tensor(mask): + mask = mask.numpy() + + pixels = mask.flatten() + pixels = np.concatenate([[0], pixels, [0]]) + runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 + runs[1::2] -= runs[::2] + return list(runs) + + +# TODO - (Amy) make compatible with other frameworks +def convert_segmentation_to_rle(segmentation): + """ + Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + segmentation (`torch.Tensor` or `numpy.array`): + A segmentation map of shape `(height, width)` where each value denotes a segment or class id. + Returns: + `list[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. + """ + segment_ids = torch.unique(segmentation) + + run_length_encodings = [] + for idx in segment_ids: + mask = torch.where(segmentation == idx, 1, 0) + rle = binary_mask_to_rle(mask) + run_length_encodings.append(rle) + + return run_length_encodings + + +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. 
+ """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_size: Optional[tuple[int, int]] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: list[dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +@requires(backends=("vision",)) +class DetrImageProcessor(BaseImageProcessor): + r""" + Constructs a Detr image processor. + + Args: + format (`str`, *optional*, defaults to `"coco_detection"`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. 
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to True): + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
+ """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_annotations: Optional[bool] = None, + do_pad: bool = True, + pad_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> None: + if "pad_and_return_pixel_mask" in kwargs: + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None if size is None else 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + size = get_size_dict(size, max_size=max_size, default_to_square=False) + + # Backwards compatibility + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + self._valid_processor_keys = [ + "images", + "annotations", + "return_segmentation_masks", + "masks_path", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "do_convert_annotations", + "image_mean", + "image_std", + "do_pad", + "pad_size", + "format", + "return_tensors", + "data_format", + "input_data_format", + ] + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + + def prepare_annotation( + self, + image: np.ndarray, + target: dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> dict: + """ + Prepare an annotation for feeding into DETR model. 
+ """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + elif format == AnnotationFormat.COCO_PANOPTIC: + return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_panoptic_annotation( + image, + target, + masks_path=masks_path, + return_masks=return_segmentation_masks, + input_data_format=input_data_format, + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." 
+ ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # TODO (Amy) - update to use `rescale_factor` instead of `scale` + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + def _update_annotation_for_padded_image( + self, + annotation: dict, + input_image_size: tuple[int, int], + output_image_size: tuple[int, int], + padding, + update_bboxes, + ) -> dict: + """ + Update the annotation for a padded image. 
+ """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + def _pad_image( + self, + image: np.ndarray, + output_size: tuple[int, int], + annotation: Optional[dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + def pad( + self, + images: list[np.ndarray], + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (list[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. 
If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. 
If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `list[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. + return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + if "pad_and_return_pixel_mask" in kwargs: + logger.warning_once( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, " + "use `do_pad` instead." + ) + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` argument is deprecated and will be removed in a future version, use" + " `size['longest_edge']` instead." + ) + size = kwargs.pop("max_size") + + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=False) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." 
+ ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + if ( + masks_path is not None + and format == AnnotationFormat.COCO_PANOPTIC + and not isinstance(masks_path, (pathlib.Path, str)) + ): + raise ValueError( + "The path to the directory containing the mask PNG files should be provided as a" + f" `pathlib.Path` or string object, but is {type(masks_path)} instead." + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + # transformations + if do_resize: + if annotations is not None: + resized_images, resized_annotations = [], [] + for image, target in zip(images, annotations): + orig_size = get_image_size(image, input_data_format) + resized_image = self.resize( + image, size=size, resample=resample, input_data_format=input_data_format + ) + resized_annotation = self.resize_annotation( + target, orig_size, get_image_size(resized_image, input_data_format) + ) + resized_images.append(resized_image) + resized_annotations.append(resized_annotation) + images = resized_images + annotations = resized_annotations + del resized_images, resized_annotations + else: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + encoded_inputs = self.pad( + images, + annotations=annotations, + return_pixel_mask=True, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=do_convert_annotations, + return_tensors=return_tensors, + pad_size=pad_size, + ) + else: + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation 
in annotations + ] + + return encoded_inputs + + # POSTPROCESSING METHODS - TODO: add support for other frameworks + # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258 + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). For visualization, this should be the image size + after data augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + ) + + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + return results + + def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. + threshold (`float`, *optional*, defaults to 0.9): + Threshold to use to filter out queries. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image + in the batch as predicted by the model. 
+ """ + logger.warning_once( + "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_semantic_segmentation`.", + ) + out_logits, raw_masks = outputs.logits, outputs.pred_masks + empty_label = out_logits.shape[-1] - 1 + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.tolist()) + + for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes): + # we filter empty queries and detection below threshold + cur_scores, cur_labels = cur_logits.softmax(-1).max(-1) + keep = cur_labels.ne(empty_label) & (cur_scores > threshold) + cur_scores = cur_scores[keep] + cur_labels = cur_labels[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1 + + predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks} + preds.append(predictions) + return preds + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218 + def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports + PyTorch. + + Args: + results (`list[Dict]`): + Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added. + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). + max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). + threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an + image in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_instance_segmentation`.", + ) + + if len(orig_target_sizes) != len(max_target_sizes): + raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs.pred_masks.squeeze(2) + outputs_masks = nn.functional.interpolate( + outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False + ) + outputs_masks = (outputs_masks.sigmoid() > threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = nn.functional.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241 + def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85): + """ + Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. 
Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrSegmentationOutput`]):
+                Raw outputs of the model.
+            processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
+                Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
+                augmentation but before batching.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*):
+                Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
+                If left to None, it will default to the `processed_sizes`.
+            is_thing_map (`dict[int, bool]`, *optional*):
+                Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
+                If not set, defaults to the `is_thing_map` of COCO panoptic.
+            threshold (`float`, *optional*, defaults to 0.85):
+                Threshold to use to filter out queries.
+        Returns:
+            `list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
+            an image in the batch as predicted by the model.
+        """
+        logger.warning_once(
+            "`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_panoptic_segmentation`.",
+        )
+        if target_sizes is None:
+            target_sizes = processed_sizes
+        if len(processed_sizes) != len(target_sizes):
+            raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
+
+        if is_thing_map is None:
+            # default to is_thing_map of COCO panoptic
+            is_thing_map = {i: i <= 90 for i in range(201)}
+
+        out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
+        if not len(out_logits) == len(raw_masks) == len(target_sizes):
+            raise ValueError(
+                "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
+            )
+        empty_label = out_logits.shape[-1] - 1
+        preds = []
+
+        def to_tuple(tup):
+            if isinstance(tup, tuple):
+                return tup
+            return tuple(tup.tolist())
+
+        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+        ):
+            # we filter empty queries and detections below threshold
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
+            cur_scores = cur_scores[keep]
+            cur_labels = cur_labels[keep]
+            cur_masks = cur_masks[keep]
+            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+            cur_boxes = center_to_corners_format(cur_boxes[keep])
+
+            h, w = cur_masks.shape[-2:]
+            if len(cur_boxes) != len(cur_labels):
+                raise ValueError("Not as many boxes as there are classes")
+
+            # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_labels): + if not is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST) + + np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())) + np_seg_img = np_seg_img.view(final_h, final_w, 3) + np_seg_img = np_seg_img.numpy() + + m_id = torch.from_numpy(rgb_to_id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_labels.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_labels = cur_labels[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_labels[i].item() + segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a}) + del cur_labels + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258 + def post_process_object_detection( + self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. 
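+
+        Example (sketch; assumes `model` is a loaded `DetrForObjectDetection`, `image_processor` is this processor
+        and `image` is a `PIL.Image`):
+
+        ```python
+        import torch
+
+        inputs = image_processor(images=image, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(**inputs)
+        target_sizes = torch.tensor([image.size[::-1]])  # (height, width) of the original image
+        results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
+        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+            print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
+        ```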
+ """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None): + """ + Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the + batch. If unset, predictions will not be resized. + Returns: + `list[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. 
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218 + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If unset, predictions will not be resized. + return_coco_annotation (`bool`, *optional*): + Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) + format. + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241 + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports + PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + The outputs from [`DetrForSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If unset, predictions will not be resized. 
+ Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to + the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["DetrImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr_fast.py b/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..ffe040898497b6abc7c563799e11e0283125b569 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/image_processing_detr_fast.py @@ -0,0 +1,1259 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for DETR.""" + +import io +import pathlib +from collections import defaultdict +from typing import Any, Optional, Union + +import PIL +import torch +from torch import nn +from torchvision.io import read_image +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, + get_image_size_for_max_height_width, + get_max_height_width, + safe_squeeze, +) +from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + validate_annotations, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, +) +from ...utils.import_utils import requires +from .image_processing_detr import ( + compute_segments, + convert_segmentation_to_rle, + get_size_with_aspect_ratio, + remove_low_and_no_objects, +) + + +logger = logging.get_logger(__name__) + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33 +def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor: + """ + Convert a COCO polygon annotation to a mask. + + Args: + segmentations (`list[list[float]]`): + List of polygons, each polygon represented by a list of x-y coordinates. + height (`int`): + Height of the mask. + width (`int`): + Width of the mask. + """ + try: + from pycocotools import mask as coco_mask + except ImportError: + raise ImportError("Pycocotools is not installed in your environment.") + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8, device=device) + mask = torch.any(mask, axis=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, axis=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device) + + return masks + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50 +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by DETR. + """ + image_height, image_width = image.size()[-2:] + + image_id = target["image_id"] + image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) + + # Get all COCO annotations for the given image. 
+ annotations = target["annotations"] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) + + classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } + + if keypoints: + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + if return_segmentation_masks: + segmentation_masks = [obj["segmentation"] for obj in annotations] + masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device) + new_target["masks"] = masks[keep] + + return new_target + + +def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: + """ + Compute the bounding boxes around the provided panoptic segmentation masks. + + Args: + masks: masks in format `[number_masks, height, width]` where N is the number of masks + + Returns: + boxes: bounding boxes in format `[number_masks, 4]` in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + y = torch.arange(0, h, dtype=torch.float32, device=masks.device) + x = torch.arange(0, w, dtype=torch.float32, device=masks.device) + # see https://github.com/pytorch/pytorch/issues/50276 + y, x = torch.meshgrid(y, x, indexing="ij") + + x_mask = masks * torch.unsqueeze(x, 0) + x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0] + x_min = ( + torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0] + ) + + y_mask = masks * torch.unsqueeze(y, 0) + y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0] + y_min = ( + torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0] + ) + + return torch.stack([x_min, y_min, x_max, y_max], 1) + + +# 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py +# Copyright (c) 2018, Alexander Kirillov +# All rights reserved. +def rgb_to_id(color): + """ + Converts RGB color to unique ID. 
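+
+ Example (an illustrative sketch; the color value is arbitrary): a pixel with
+ `(R, G, B) = (1, 2, 3)` is encoded as `1 + 256 * 2 + 256**2 * 3 = 197121`; `id_to_rgb`
+ performs the inverse mapping.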
+ """ + if isinstance(color, torch.Tensor) and len(color.shape) == 3: + if color.dtype == torch.uint8: + color = color.to(torch.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +def prepare_coco_panoptic_annotation( + image: torch.Tensor, + target: dict, + masks_path: Union[str, pathlib.Path], + return_masks: bool = True, + input_data_format: Union[ChannelDimension, str] = None, +) -> dict: + """ + Prepare a coco panoptic annotation for DETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + annotation_path = pathlib.Path(masks_path) / target["file_name"] + + new_target = {} + new_target["image_id"] = torch.as_tensor( + [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device + ) + new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) + new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) + + if "segments_info" in target: + masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device) + masks = rgb_to_id(masks) + + ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device) + masks = masks == ids[:, None, None] + masks = masks.to(torch.bool) + if return_masks: + new_target["masks"] = masks + new_target["boxes"] = masks_to_boxes(masks) + new_target["class_labels"] = torch.as_tensor( + [segment_info["category_id"] for segment_info in target["segments_info"]], + dtype=torch.int64, + device=image.device, + ) + new_target["iscrowd"] = torch.as_tensor( + [segment_info["iscrowd"] for segment_info in target["segments_info"]], + dtype=torch.int64, + device=image.device, + ) + new_target["area"] = torch.as_tensor( + [segment_info["area"] for segment_info in target["segments_info"]], + dtype=torch.float32, + device=image.device, + ) + + return new_target + + +class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + return_segmentation_masks (`bool`, *optional*, defaults to `False`): + Whether to return segmentation masks. 
+ """ + + format: Optional[Union[str, AnnotationFormat]] + do_convert_annotations: Optional[bool] + return_segmentation_masks: Optional[bool] + + +@auto_docstring +@requires(backends=("torchvision", "torch")) +class DetrImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + format = AnnotationFormat.COCO_DETECTION + do_resize = True + do_rescale = True + do_normalize = True + do_pad = True + size = {"shortest_edge": 800, "longest_edge": 1333} + default_to_square = False + model_input_names = ["pixel_values", "pixel_mask"] + valid_kwargs = DetrFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[DetrFastImageProcessorKwargs]) -> None: + if "pad_and_return_pixel_mask" in kwargs: + kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask") + + size = kwargs.pop("size", None) + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None if size is None else 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + self.size = get_size_dict(size, max_size=max_size, default_to_square=False) + + # Backwards compatibility + do_convert_annotations = kwargs.get("do_convert_annotations") + do_normalize = kwargs.get("do_normalize") + if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None: + self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `DetrImageProcessorFast.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + + def prepare_annotation( + self, + image: torch.Tensor, + target: dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> dict: + """ + Prepare an annotation for feeding into DETR model. 
+ """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + elif format == AnnotationFormat.COCO_PANOPTIC: + return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_panoptic_annotation( + image, + target, + masks_path=masks_path, + return_masks=return_segmentation_masks, + input_data_format=input_data_format, + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use if resizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. + new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size["shortest_edge"], + size["longest_edge"], + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"]) + elif size.height and size.width: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + + image = F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + return image + + def resize_annotation( + self, + annotation: dict[str, Any], + orig_size: tuple[int, int], + target_size: tuple[int, int], + threshold: float = 0.5, + interpolation: Optional["F.InterpolationMode"] = None, + ): + """ + Resizes an annotation to a target size. + + Args: + annotation (`dict[str, Any]`): + The annotation dictionary. + orig_size (`tuple[int, int]`): + The original size of the input image. + target_size (`tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. 
+ resample (`InterpolationMode`, defaults to `F.InterpolationMode.NEAREST_EXACT`): + The resampling filter to use when resizing the masks. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST_EXACT + ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)] + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device + ) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks] + masks = torch.stack(masks).to(torch.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= torch.as_tensor( + [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device + ) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + def _update_annotation_for_padded_image( + self, + annotation: dict, + input_image_size: tuple[int, int], + output_image_size: tuple[int, int], + padding, + update_bboxes, + ) -> dict: + """ + Update the annotation for a padded image. + """ + new_annotation = {} + new_annotation["size"] = output_image_size + ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size)) + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = F.pad( + masks, + padding, + fill=0, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + def pad( + self, + image: torch.Tensor, + padded_size: tuple[int, int], + annotation: Optional[dict[str, Any]] = None, + update_bboxes: bool = True, + fill: int = 0, + ): + original_size = image.size()[-2:] + padding_bottom = padded_size[0] - original_size[0] + padding_right = padded_size[1] - original_size[1] + if padding_bottom < 0 or padding_right < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the padded size is larger than the " + f"original size. Got padded size: {padded_size}, original size: {original_size}." 
+ ) + if original_size != padded_size: + padding = [0, 0, padding_right, padding_bottom] + image = F.pad(image, padding, fill=fill) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, original_size, padded_size, padding, update_bboxes + ) + + # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device) + pixel_mask[: original_size[0], : original_size[1]] = 1 + + return image, pixel_mask, annotation + + @auto_docstring + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + **kwargs: Unpack[DetrFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + """ + if "pad_and_return_pixel_mask" in kwargs: + kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask") + logger.warning_once( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, " + "use `do_pad` instead." + ) + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` argument is deprecated and will be removed in a future version, use" + " `size['longest_edge']` instead." + ) + kwargs["size"] = kwargs.pop("max_size") + + return super().preprocess(images, annotations, masks_path, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + annotations: Optional[Union[AnnotationType, list[AnnotationType]]], + masks_path: Optional[Union[str, pathlib.Path]], + return_segmentation_masks: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_convert_annotations: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_pad: bool, + pad_size: Optional[SizeDict], + format: Optional[Union[str, AnnotationFormat]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + """ + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." 
+ ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + if ( + masks_path is not None + and format == AnnotationFormat.COCO_PANOPTIC + and not isinstance(masks_path, (pathlib.Path, str)) + ): + raise ValueError( + "The path to the directory containing the mask PNG files should be provided as a" + f" `pathlib.Path` or string object, but is {type(masks_path)} instead." + ) + + data = {} + + processed_images = [] + processed_annotations = [] + pixel_masks = [] # Initialize pixel_masks here + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + annotation = self.prepare_annotation( + image, + annotation, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=ChannelDimension.FIRST, + ) + + if do_resize: + resized_image = self.resize(image, size=size, interpolation=interpolation) + if annotations is not None: + annotation = self.resize_annotation( + annotation, + orig_size=image.size()[-2:], + target_size=resized_image.size()[-2:], + ) + image = resized_image + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + if do_convert_annotations and annotations is not None: + annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) + + processed_images.append(image) + processed_annotations.append(annotation) + images = processed_images + annotations = processed_annotations if annotations is not None else None + + if do_pad: + # depends on all resized image shapes so we need another loop + if pad_size is not None: + padded_size = (pad_size.height, pad_size.width) + else: + padded_size = get_max_height_width(images) + + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + if padded_size == image.size()[-2:]: + padded_images.append(image) + pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) + padded_annotations.append(annotation) + continue + image, pixel_mask, annotation = self.pad( + image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations + ) + padded_images.append(image) + padded_annotations.append(annotation) + pixel_masks.append(pixel_mask) + images = padded_images + annotations = padded_annotations if annotations is not None else None + data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) + + data.update({"pixel_values": torch.stack(images, dim=0)}) + encoded_inputs = BatchFeature(data, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + return encoded_inputs + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. 
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). For visualization, this should be the image size + after data augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + ) + + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_segmentation + def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. + threshold (`float`, *optional*, defaults to 0.9): + Threshold to use to filter out queries. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image + in the batch as predicted by the model. 
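+
+ Example (a hedged sketch of the non-deprecated replacement named in the deprecation warning below;
+ `image_processor` and `outputs` are assumed to exist from an earlier preprocessing and forward pass):
+
+ ```python
+ # Preferred, non-deprecated API; one (height, width) tuple per image in the batch.
+ semantic_maps = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])
+ ```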
+ """ + logger.warning_once( + "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_semantic_segmentation`.", + ) + out_logits, raw_masks = outputs.logits, outputs.pred_masks + empty_label = out_logits.shape[-1] - 1 + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.tolist()) + + for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes): + # we filter empty queries and detection below threshold + cur_scores, cur_labels = cur_logits.softmax(-1).max(-1) + keep = cur_labels.ne(empty_label) & (cur_scores > threshold) + cur_scores = cur_scores[keep] + cur_labels = cur_labels[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1 + + predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks} + preds.append(predictions) + return preds + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance + def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports + PyTorch. + + Args: + results (`list[Dict]`): + Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added. + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). + max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). + threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an + image in the batch as predicted by the model. 
+ """ + logger.warning_once( + "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_instance_segmentation`.", + ) + + if len(orig_target_sizes) != len(max_target_sizes): + raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs.pred_masks.squeeze(2) + outputs_masks = nn.functional.interpolate( + outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False + ) + outputs_masks = (outputs_masks.sigmoid() > threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = nn.functional.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic + def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85): + """ + Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch. + + Args: + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`): + Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data + augmentation but before batching. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*): + Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction. + If left to None, it will default to the `processed_sizes`. + is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + Dictionary mapping class indices to either True or False, depending on whether or not they are a thing. + If not set, defaults to the `is_thing_map` of COCO panoptic. + threshold (`float`, *optional*, defaults to 0.85): + Threshold to use to filter out queries. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for + an image in the batch as predicted by the model. 
+ """ + logger.warning_once( + "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_panoptic_segmentation`.", + ) + if target_sizes is None: + target_sizes = processed_sizes + if len(processed_sizes) != len(target_sizes): + raise ValueError("Make sure to pass in as many processed_sizes as target_sizes") + + if is_thing_map is None: + # default to is_thing_map of COCO panoptic + is_thing_map = {i: i <= 90 for i in range(201)} + + out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes + if not len(out_logits) == len(raw_masks) == len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks" + ) + empty_label = out_logits.shape[-1] - 1 + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + cur_scores, cur_labels = cur_logits.softmax(-1).max(-1) + keep = cur_labels.ne(empty_label) & (cur_scores > threshold) + cur_scores = cur_scores[keep] + cur_labels = cur_labels[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_boxes = center_to_corners_format(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + if len(cur_boxes) != len(cur_labels): + raise ValueError("Not as many boxes as there are classes") + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_labels): + if not is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST) + + np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())) + np_seg_img = np_seg_img.view(final_h, final_w, 3) + np_seg_img = np_seg_img.numpy() + + m_id = torch.from_numpy(rgb_to_id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_labels.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_labels = 
cur_labels[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_labels[i].item() + segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a}) + del cur_labels + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_object_detection + def post_process_object_detection( + self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None): + """ + Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the + batch. If unset, predictions will not be resized. + Returns: + `list[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). 
Each entry of each + `torch.Tensor` correspond to a semantic class id. + """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If unset, predictions will not be resized. + return_coco_annotation (`bool`, *optional*): + Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) + format. + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
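+
+ Example (a hedged end-to-end sketch; the checkpoint name and image file are illustrative assumptions,
+ not requirements of this method):
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import DetrForSegmentation, DetrImageProcessorFast
+
+ processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50-panoptic")
+ model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
+
+ image = Image.open("cats.jpg")  # any RGB image
+ inputs = processor(images=image, return_tensors="pt")
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ # One dict per image with "segmentation" and "segments_info", as described above.
+ results = processor.post_process_instance_segmentation(
+     outputs, threshold=0.5, target_sizes=[image.size[::-1]]  # PIL size is (width, height)
+ )
+ ```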
+ """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports + PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + The outputs from [`DetrForSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If unset, predictions will not be resized. 
+ Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to + the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["DetrImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/detr/modeling_detr.py b/venv/lib/python3.13/site-packages/transformers/models/detr/modeling_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..86835ca62cfc02f7e2f8f134c9824c3eda1f31e5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/detr/modeling_detr.py @@ -0,0 +1,1693 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DETR model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + auto_docstring, + is_timm_available, + logging, + requires_backends, +) +from ...utils.backbone_utils import load_backbone +from .configuration_detr import DetrConfig + + +if is_timm_available(): + from timm import create_model + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. + """ +) +class DetrDecoderOutput(BaseModelOutputWithCrossAttentions): + r""" + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. + """ +) +class DetrModelOutput(Seq2SeqModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. 
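+
+ Example (a hedged sketch that only shows how the extra attribute is surfaced; it assumes a `DetrModel`
+ configured with `auxiliary_loss=True` and `inputs` produced by a DETR image processor):
+
+ ```python
+ outputs = model(**inputs)
+ if outputs.intermediate_hidden_states is not None:
+     # (config.decoder_layers, batch_size, sequence_length, hidden_size), as documented above
+     print(outputs.intermediate_hidden_states.shape)
+ ```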
+ """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`DetrForObjectDetection`]. + """ +) +class DetrObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`DetrForSegmentation`]. + """ +) +class DetrSegmentationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). 
You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): + Segmentation mask logits for all queries. See also + [`~DetrImageProcessor.post_process_semantic_segmentation`], + [`~DetrImageProcessor.post_process_instance_segmentation`] or + [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic + segmentation masks respectively. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + +# BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/master/backbone.py +class DetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`. 
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = DetrFrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class DetrConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + + # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API + if config.use_timm_backbone: + # We default to values which were previously hard-coded. This enables configurability from the config + # using backbone arguments, while keeping the default behavior the same. + requires_backends(self, ["timm"]) + kwargs = getattr(config, "backbone_kwargs", {}) + kwargs = {} if kwargs is None else kwargs.copy() + out_indices = kwargs.pop("out_indices", (1, 2, 3, 4)) + num_channels = kwargs.pop("in_chans", config.num_channels) + if config.dilation: + kwargs["output_stride"] = kwargs.get("output_stride", 16) + backbone = create_model( + config.backbone, + pretrained=config.use_pretrained_backbone, + features_only=True, + out_indices=out_indices, + in_chans=num_channels, + **kwargs, + ) + else: + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = ( + self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels + ) + + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + + if "resnet" in backbone_model_type: + for name, parameter in self.model.named_parameters(): + if config.use_timm_backbone: + if "layer2" not in name and "layer3" not in name and "layer4" not in name: + parameter.requires_grad_(False) + else: + if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name: + parameter.requires_grad_(False) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class DetrConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. 
+ """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +class DetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class DetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(config): + n_steps = config.d_model // 2 + if config.position_embedding_type == "sine": + # TODO find a better way of exposing other arguments + position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True) + elif config.position_embedding_type == "learned": + position_embedding = DetrLearnedPositionEmbedding(n_steps) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") + + return position_embedding + + +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. 
+ + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]): + return tensor if object_queries is None else tensor + object_queries + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + spatial_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if object_queries is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, object_queries) + + # add key-value position embeddings to the key value states + if spatial_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, 
source_len)}, but is" + f" {attention_mask.size()}" + ) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( + attention_mask, -torch.inf + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class DetrEncoderLayer(nn.Module): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + object_queries: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class DetrDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = DetrAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + object_queries that are added to the hidden states + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + position embeddings that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + object_queries=query_position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + object_queries=query_position_embeddings, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + spatial_position_embeddings=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class DetrPreTrainedModel(PreTrainedModel): + config: DetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + xavier_std = self.config.init_xavier_std + + if isinstance(module, DetrMHAttentionMap): + nn.init.zeros_(module.k_linear.bias) + nn.init.zeros_(module.q_linear.bias) + nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + elif isinstance(module, DetrLearnedPositionEmbedding): + nn.init.uniform_(module.row_embeddings.weight) + nn.init.uniform_(module.column_embeddings.weight) + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class DetrEncoder(DetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`DetrEncoderLayer`]. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweak for DETR: + + - object_queries are added to the forward pass. 
+ + Args: + config: DetrConfig + """ + + def __init__(self, config: DetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + object_queries=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Object queries that are added to the queries in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + # we add object_queries as extra input to the encoder_layer + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class DetrDecoder(DetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for DETR: + + - object_queries and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + + Args: + config: DetrConfig + """ + + def __init__(self, config: DetrConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in DETR, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + object_queries=None, + query_position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on certain queries. 
Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Object queries that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + , *optional*): Position embeddings that are added to the values and keys in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + combined_attention_mask = None + + if attention_mask is not None and combined_attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # optional intermediate hidden states + intermediate = () if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + combined_attention_mask, + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient 
checkpointing + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + hidden_states = self.layernorm(hidden_states) + intermediate += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] + if v is not None + ) + return DetrDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + ) + + +@auto_docstring( + custom_intro=""" + The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without + any specific head on top. + """ +) +class DetrModel(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = DetrConvEncoder(config) + object_queries = build_position_encoding(config) + self.backbone = DetrConvModel(backbone, object_queries) + + # Create projection layer + self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1) + + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = DetrEncoder(config) + self.decoder = DetrDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, DetrModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50") + >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> # the last hidden states are the final query embeddings of the Transformer decoder + >>> # these are of shape (batch_size, num_queries, hidden_size) + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 100, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + # First, sent pixel_values + pixel_mask through Backbone to obtain the features + # pixel_values should be of shape (batch_size, num_channels, height, width) + # pixel_mask should be of shape (batch_size, height, width) + features, object_queries_list = self.backbone(pixel_values, pixel_mask) + + # get final feature map and downsampled mask + feature_map, mask = features[-1] + + if mask is None: + raise ValueError("Backbone does not return downsampled pixel mask") + + # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + projected_feature_map = self.input_projection(feature_map) + + # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC + # In other words, turn their shape into (batch_size, sequence_length, hidden_size) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) + object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) + + flattened_mask = mask.flatten(1) + + # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder + # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size) + # flattened_mask is a Tensor of shape (batch_size, height*width) + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=flattened_features, + attention_mask=flattened_mask, + object_queries=object_queries, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + queries = 
torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return DetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + ) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class DetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@auto_docstring( + custom_intro=""" + DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks + such as COCO detection. + """ +) +class DetrForObjectDetection(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # DETR encoder-decoder model + self.model = DetrModel(config) + + # Object detection heads + self.class_labels_classifier = nn.Linear( + config.d_model, config.num_labels + 1 + ) # We add one for the "no object" class + self.bbox_predictor = DetrMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, DetrForObjectDetection + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50") + >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98] + Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66] + Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76] + Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93] + Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through DETR base model to obtain encoder + decoder outputs + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # class logits + predicted bounding boxes + logits = self.class_labels_classifier(sequence_output) + pred_boxes = self.bbox_predictor(sequence_output).sigmoid() + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4] + outputs_class = self.class_labels_classifier(intermediate) + outputs_coord = self.bbox_predictor(intermediate).sigmoid() + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + auxiliary_outputs + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return DetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring( + custom_intro=""" + DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks + such as COCO panoptic. 
+ """ +) +class DetrForSegmentation(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # object detection model + self.detr = DetrForObjectDetection(config) + + # segmentation head + hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads + intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes + + self.mask_head = DetrMaskHeadSmallConv( + hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size + ) + + self.bbox_attention = DetrMHAttentionMap( + hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std + ) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each + dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels, + bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves + should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a + `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a + `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`. 
+ + Examples: + + ```python + >>> import io + >>> import requests + >>> from PIL import Image + >>> import torch + >>> import numpy + + >>> from transformers import AutoImageProcessor, DetrForSegmentation + >>> from transformers.image_transforms import rgb_to_id + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic") + >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps + >>> # Segmentation results are returned as a list of dictionaries + >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)]) + + >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found + >>> panoptic_seg = result[0]["segmentation"] + >>> # Get prediction score and segment_id to class_id mapping of each segment + >>> panoptic_segments_info = result[0]["segments_info"] + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=device) + + # First, get list of feature maps and position embeddings + features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask) + + # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + feature_map, mask = features[-1] + batch_size, num_channels, height, width = feature_map.shape + projected_feature_map = self.detr.model.input_projection(feature_map) + + # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC + # In other words, turn their shape into (batch_size, sequence_length, hidden_size) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) + object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) + + flattened_mask = mask.flatten(1) + + # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder + # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size) + # flattened_mask is a Tensor of shape (batch_size, height*width) + if encoder_outputs is None: + encoder_outputs = self.detr.model.encoder( + inputs_embeds=flattened_features, + attention_mask=flattened_mask, + object_queries=object_queries, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output) + query_position_embeddings = 
self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat( + batch_size, 1, 1 + ) + queries = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.detr.model.decoder( + inputs_embeds=queries, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Sixth, compute logits, pred_boxes and pred_masks + logits = self.detr.class_labels_classifier(sequence_output) + pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid() + + memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width) + mask = flattened_mask.view(batch_size, height, width) + + # FIXME h_boxes takes the last one computed, keep this in mind + # important: we need to reverse the mask, since in the original implementation the mask works reversed + # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32) + bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask) + + seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]]) + + pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1] + outputs_class = self.detr.class_labels_classifier(intermediate) + outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid() + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs + else: + output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return DetrSegmentationOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + pred_masks=pred_masks, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +def _expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py +class DetrMaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. 
Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + if dim % 8 != 0: + raise ValueError( + "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in" + " GroupNorm is set to 8" + ) + + inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] + + self.lay1 = nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = nn.GroupNorm(8, dim) + self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1]) + self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2]) + self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3]) + self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4]) + self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): + # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with + # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32). + # We expand the projected feature map to match the number of heads. + x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = nn.functional.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = nn.functional.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = nn.functional.relu(x) + + x = self.out_lay(x) + return x + + +class DetrMHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: Optional[Tensor] = None): + q = self.q_linear(q) + k = nn.functional.conv2d(k, 
self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) + + if mask is not None: + weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min) + weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size()) + weights = self.dropout(weights) + return weights + + +__all__ = [ + "DetrForObjectDetection", + "DetrForSegmentation", + "DetrModel", + "DetrPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d738fbc087888597da19735271366d4e35ab708c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dia import * + from .feature_extraction_dia import * + from .generation_dia import * + from .modeling_dia import * + from .processing_dia import * + from .tokenization_dia import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/configuration_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/configuration_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..d4dec60b3e4853574e4d528e7b641507a8c0b414 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/configuration_dia.py @@ -0,0 +1,376 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
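+# A minimal composition sketch (illustrative only; it uses the classes and defaults defined
+# in this module, imported via the module path rather than assuming top-level exports):
+#
+#     from transformers.models.dia.configuration_dia import (
+#         DiaConfig,
+#         DiaDecoderConfig,
+#         DiaEncoderConfig,
+#     )
+#
+#     decoder = DiaDecoderConfig()  # num_channels defaults to 9
+#     config = DiaConfig(
+#         encoder_config=DiaEncoderConfig(),
+#         decoder_config=decoder,
+#         delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],  # length must equal decoder.num_channels
+#     )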
+"""Dia model configuration""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DiaEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia + encoder according to the specified arguments, defining the encoder architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key and value heads for each attention layer in the Transformer encoder. + head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the attention head. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DiaModel`]. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"swish"` and `"gelu_new"` are supported. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. 
Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "dia_encoder" + + def __init__( + self, + max_position_embeddings: int = 1024, + num_hidden_layers: int = 12, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 16, + head_dim: int = 128, + intermediate_size: int = 4096, + norm_eps: float = 1e-5, + vocab_size: int = 256, + hidden_act: str = "silu", + rope_theta: float = 10000.0, + rope_scaling: Optional[dict] = None, + initializer_range: float = 0.02, + **kwargs, + ): + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.norm_eps = norm_eps + self.vocab_size = vocab_size + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + self.initializer_range = initializer_range + super().__init__(**kwargs) + + +class DiaDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia + decoder according to the specified arguments, defining the decoder architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + max_position_embeddings (`int`, *optional*, defaults to 3072): + The maximum sequence length that this model might ever be used with. + num_hidden_layers (`int`, *optional*, defaults to 18): + Number of hidden layers in the Transformer decoder. + hidden_size (`int`, *optional*, defaults to 2048): + Dimensionality of the decoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + Number of key and value heads for each attention layer in the Transformer decoder. + head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the attention head. + cross_num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each cross-attention layer in the Transformer decoder. + cross_head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the cross-attention head. + cross_num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key and value heads for each cross-attention layer in the Transformer decoder. + cross_hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the cross-attention layers. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + vocab_size (`int`, *optional*, defaults to 1028): + Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DiaModel`]. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`, + `"swish"` and `"gelu_new"` are supported. + num_channels (`int`, *optional*, defaults to 9): + Number of channels for the Dia decoder. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicating that this model is part of an encoder-decoder architecture. + """ + + model_type = "dia_decoder" + + def __init__( + self, + max_position_embeddings: int = 3072, + num_hidden_layers: int = 18, + hidden_size: int = 2048, + intermediate_size: int = 8192, + num_attention_heads: int = 16, + num_key_value_heads: int = 4, + head_dim: int = 128, + cross_num_attention_heads: int = 16, + cross_head_dim: int = 128, + cross_num_key_value_heads: int = 16, + cross_hidden_size: int = 1024, + norm_eps: float = 1e-5, + vocab_size: int = 1028, + hidden_act: str = "silu", + num_channels: int = 9, + rope_theta: float = 10000.0, + rope_scaling: Optional[dict] = None, + initializer_range: float = 0.02, + use_cache: bool = True, + is_encoder_decoder: bool = True, + **kwargs, + ): + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.cross_num_key_value_heads = cross_num_key_value_heads + self.cross_num_attention_heads = cross_num_attention_heads + self.cross_head_dim = cross_head_dim + self.cross_hidden_size = cross_hidden_size + self.norm_eps = norm_eps + self.vocab_size = vocab_size + self.hidden_act = hidden_act + self.num_channels = num_channels + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + self.initializer_range = initializer_range + self.use_cache = use_cache + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +class DiaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a + Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_config (`DiaEncoderConfig`, *optional*): + Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used. + decoder_config (`DiaDecoderConfig`, *optional*): + Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicating that this model uses an encoder-decoder architecture. + pad_token_id (`int`, *optional*, defaults to 1025): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1024): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 1026): + Beginning of stream token id. + delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`): + The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + + Example: + + ```python + >>> from transformers import DiaConfig, DiaModel + + >>> # Initializing a DiaConfig with default values + >>> configuration = DiaConfig() + + >>> # Initializing a DiaModel (with random weights) from the configuration + >>> model = DiaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dia" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig} + + def __init__( + self, + encoder_config: Optional[DiaEncoderConfig] = None, + decoder_config: Optional[DiaDecoderConfig] = None, + norm_eps: float = 1e-5, + is_encoder_decoder: bool = True, + pad_token_id: int = 1025, + eos_token_id: int = 1024, + bos_token_id: int = 1026, + delay_pattern: Optional[list[int]] = None, + initializer_range: float = 0.02, + use_cache: bool = True, + **kwargs, + ): + if isinstance(encoder_config, dict): + encoder_config = DiaEncoderConfig(**encoder_config) + if isinstance(decoder_config, dict): + decoder_config = DiaDecoderConfig(**decoder_config) + self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig() + self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig() + self.norm_eps = norm_eps + self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15] + self.initializer_range = initializer_range + self.use_cache = use_cache + + assert self.decoder_config.num_channels == len(self.delay_pattern), ( + "Number of channels must match delay pattern length." 
+ ) + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + def get_text_config(self, *args, **kwargs): + """Defaulting to audio config as it's the decoder in this case which is usually the text backbone""" + return self.decoder_config + + +__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/feature_extraction_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/feature_extraction_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..b4376b773b27365932774a35746b2928cf0af707 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/feature_extraction_dia.py @@ -0,0 +1,183 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Dia""" + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class DiaFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs an Dia feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used for padding. + hop_length (`int`, *optional*, defaults to 512): + Overlap length between successive windows. + """ + + model_input_names = ["input_values", "n_quantizers"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 16000, + padding_value: float = 0.0, + hop_length: int = 512, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.hop_length = hop_length + + def __call__( + self, + raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). 
+ + Args: + raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. 
Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # convert stereo to mono if necessary, unique to Dia + for idx, example in enumerate(raw_audio): + if self.feature_size == 2 and example.ndim == 2: + raw_audio[idx] = np.mean(example, -1) + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.ndim != 1: # note the conversion before + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + input_values = BatchFeature({"input_values": raw_audio}) + + # temporarily treat it as if we were mono as we also convert stereo to mono + original_feature_size = self.feature_size + self.feature_size = 1 + + # normal padding on batch + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=True, + pad_to_multiple_of=self.hop_length, + ) + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + # rewrite back to original feature size + self.feature_size = original_feature_size + + return padded_inputs + + +__all__ = ["DiaFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/generation_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/generation_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..c297de7203d4e5b30a189047233976d310179907 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/generation_dia.py @@ -0,0 +1,463 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
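+# Orientation for the generation mixin defined below (the worked values are illustrative and use
+# the defaults from configuration_dia.py: pad_token_id=1025, bos_token_id=1026):
+# - classifier-free guidance (`guidance_scale != 1`) duplicates the encoder inputs with a zeroed
+#   unconditional copy and recombines the two logit streams via
+#   `DiaClassifierFreeGuidanceLogitsProcessor`;
+# - `decoder_delay_mask` forces tokens during decoding: positions equal to `pad_token_id` are left
+#   to the model, every other position is overwritten with the mask value (see `apply_delay_mask`):
+#       delay_mask row : [1026, 1026, 1025, 1025]
+#       model tokens   : [  17,   42,   55,   60]
+#       masked result  : [1026, 1026,   55,   60]
+# - generated sequences are reshaped from (batch * channels, seq_len) back to
+#   (batch, seq_len, channels) at the end of `generate`.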
+ +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist + +from ...generation.logits_process import ( + DiaClassifierFreeGuidanceLogitsProcessor, + DiaEOSChannelFilterLogitsProcessor, + DiaEOSDelayPatternLogitsProcessor, + LogitsProcessorList, + TemperatureLogitsWarper, +) +from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation.streamers import BaseStreamer +from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_utils import PreTrainedModel +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DiaGenerationMixin(GenerationMixin): + # Indicates CFG which needs preparation to be properly handled by repeats + _uses_cfg = None + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: Optional[torch.LongTensor] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + device: Optional[str] = None, + model_kwargs: Optional[dict[str, Any]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + ) -> LogitsProcessorList: + # Need either custom order or custom processor instead + # (Temporarily disabling those for the super function) + original_guidance_scale = generation_config.guidance_scale + original_temperature = generation_config.temperature + generation_config.guidance_scale = None + generation_config.temperature = None + + # Get base processors and those we can integrate easily + custom_processors = LogitsProcessorList() + + if original_temperature is not None and original_temperature != 1.0: + custom_processors.append(TemperatureLogitsWarper(original_temperature)) + + custom_processors.append( + DiaEOSChannelFilterLogitsProcessor( + num_channels=len(self.config.delay_pattern), + eos_token_id=self.config.eos_token_id, + ) + ) + + merged_processors = super()._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=encoder_input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=custom_processors, + device=device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # Custom processors we need at specific positions + if original_guidance_scale is not None and original_guidance_scale != 1: + cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor( + guidance_scale=original_guidance_scale, + guidance_top_k=generation_config.top_k, + ) + merged_processors.insert(0, cfg_processor) + + merged_processors.append( + DiaEOSDelayPatternLogitsProcessor( + delay_pattern=self.config.delay_pattern, + eos_token_id=self.config.eos_token_id, + max_generation_len=generation_config.max_length, + device=device, + ) + ) + + # Enable temporarily disabled values back + generation_config.guidance_scale = original_guidance_scale + generation_config.temperature = original_temperature + + return merged_processors + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any + ) -> tuple[GenerationConfig, dict]: + generation_config, 
model_kwargs = super()._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + + # We allow generation up to max length + max delay pattern + # (will revert back to max length after generation) + generation_config.max_length += max(self.config.delay_pattern) + + # Internal flag to indicate CFG that needs to prepare unconditioned input + self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1 + + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = super()._prepare_model_inputs( + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + # If CFG is requested we fill in the unconditioned parts + if self._uses_cfg: + unconditioned_inputs = torch.zeros_like(inputs) + inputs = torch.cat([inputs, unconditioned_inputs], dim=0) + + if model_kwargs.get("attention_mask", None) is not None: + model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1) + + return inputs, input_name, model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: dict[str, torch.Tensor], + decoder_start_token_id: torch.Tensor, + device: Optional[torch.device] = None, + ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out + decoder_input_ids = decoder_attention_mask = None + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + if model_kwargs is not None and "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs.pop("decoder_attention_mask") + + # We allow generating without preparation (no proper delay) but discourage it + if decoder_input_ids is None or decoder_attention_mask is None: + logger.warning_once( + "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:" + f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}." + f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation." + ) + + num_channels = self.config.decoder_config.num_channels + real_batch_size = batch_size // 2 if self._uses_cfg else batch_size + + if decoder_input_ids is None: + decoder_input_ids = torch.full( + (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device + ) + + decoder_attention_mask = torch.ones( + size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device + ) + + # 2. Determine the valid input and what works as mask within the input + delay_mask = decoder_input_ids.long() + valid_input_size = ( + decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max() + ) + decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long() + decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long() + + # 3. 
Overwrite into model kwargs + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + model_kwargs["decoder_delay_mask"] = delay_mask + + return decoder_input_ids, model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + encoder_outputs=None, # Using this to easily get the batch size + decoder_delay_mask=None, + **kwargs, + ): + # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape + batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0] + input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2) + + # Base method handles most things except CFG and the delay pattern mask + model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs) + + # Post processing for CFG and overwriting via delay pattern mask + # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask) + model_inputs["decoder_input_ids"] = self.apply_delay_mask( + input_ids, self.config.pad_token_id, decoder_delay_mask + ) + + # Depending on cache usage we need to pass all or just one + if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0: + model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :] + + # Be compile friendly + model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous() + + # 2. Apply CFG duplication if needed + if self._uses_cfg: + for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]: + if model_inputs.get(key, None) is not None: + # double first dimension and keep everything else the same + repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1)) + model_inputs[key] = model_inputs[key].repeat(*repeat_pattern) + + return model_inputs + + @staticmethod + def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor: + if delay_mask is None: + return input_ids + + mask_len = min(input_ids.shape[1], delay_mask.shape[1]) + valid_mask = delay_mask[:, :mask_len, :] + valid_input = input_ids[:, :mask_len, :] + + # Overwrite the respective parts of the input + input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask) + + return input_ids + + def _main_generate_loop( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + use_model_defaults: Optional[bool] = None, + custom_generate: Optional[str] = None, + **kwargs, + ): + # ********** mostly taken from main generate function up to calling the different methods (see NOTE) ********** + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + generation_mode_kwargs = self._extract_generation_mode_kwargs( + custom_generate, + kwargs, + synced_gpus, + assistant_model, + streamer, + ) + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + generation_mode = generation_config.get_generation_mode(assistant_model) + + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + # 3. Define model inputs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # 4. Define other model kwargs + if "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer")) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + # NOTE: incorrect `input_ids.shape[1]` previously + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value as it can need more than the last token logits + if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: + model_kwargs["logits_to_keep"] = 1 + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. 
+ # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = generation_config.max_length - 1 + if ( + inputs_tensor.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += inputs_tensor.shape[1] + self._prepare_cache_for_generation( + generation_config, model_kwargs, generation_mode, batch_size, max_cache_length + ) + + # 8. prepare logits processors and stopping criteria + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria, + tokenizer=generation_mode_kwargs.get("tokenizer"), + ) + + # Set model_kwargs `use_cache` so we can use it later in forward runs + model_kwargs["use_cache"] = generation_config.use_cache + # ******************* taken from main generate function up to calling the different methods ******************* + + # Prepare inner 2D logic in generation loop + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) + + # 10. go into different generation modes + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. expand input_ids with `num_return_sequences` additional sequences per batch + if generation_config.num_return_sequences > 1: + raise ValueError("`num_return_sequences>1` is incompatible with Dia.") + + # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + return self._sample( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + **generation_mode_kwargs, + **model_kwargs, + ) + else: + raise ValueError( + "Got incompatible mode for generation, should be one of greedy or sampling. " + "Ensure that beam search is de-activated by setting `num_beams=1`." 
+ ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + use_model_defaults: Optional[bool] = None, + custom_generate: Optional[str] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + # We expect the initial input ids to be the complete mask (delayed input) + delay_mask = kwargs.get("decoder_input_ids") + if delay_mask is not None: + delay_mask = delay_mask.clone() + + output = self._main_generate_loop( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + use_model_defaults=use_model_defaults, + custom_generate=custom_generate, + **kwargs, + ) + + return_dict_in_generate = not isinstance(output, torch.Tensor) + + if return_dict_in_generate: + output_sequences = output.sequences + else: + output_sequences = output + + # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels) + num_channels = self.config.decoder_config.num_channels + bsz = output_sequences.shape[0] // num_channels + output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2) + + # Apply delay mask + output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask) + + if return_dict_in_generate: + output.sequences = output_sequences + else: + output = output_sequences + + return output diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/modeling_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/modeling_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..cf662b224aabd884521d7e14b8a167886377f4b5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/modeling_dia.py @@ -0,0 +1,958 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dia/modular_dia.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dia.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig +from .generation_dia import DiaGenerationMixin + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class DiaPreTrainedModel(PreTrainedModel): + config: DiaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = True + main_input_name = "input_ids" + _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] + + +class DiaMultiChannelEmbedding(nn.Module): + """In order to efficiently compute the audio embedding from the 9 different channels, + we vectorize the embedding process by using a single embedding layer and an offset. + Example: + - num_embeds = 4 + - vocab_size = 8 + - num_channels = 3 + We would have offsets = [0, 8, 16] + If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8], + then tokens = audio_codes + offsets + = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24] + This allows us to use a single embedding layer for all channels. 
+ """ + + def __init__(self, config: DiaDecoderConfig): + super().__init__() + self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size) + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,) + self.register_buffer("offsets", offsets, persistent=False) + + def forward(self, audio_codes: torch.Tensor) -> torch.Tensor: + tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1) + embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size) + return embeds.sum(dim=2) + + +class DiaMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +@use_kernel_forward_from_hub("RMSNorm") +class DiaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DiaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DiaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DiaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
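+    For illustration, with n_rep=2 a (1, 2, 3, 8) key/value tensor becomes (1, 4, 3, 8), duplicating each KV head in place.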
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class DiaSelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = is_causal + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = 
past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.cross_hidden_size = config.cross_hidden_size + self.num_heads = self.config.cross_num_attention_heads + self.num_key_value_heads = self.config.cross_num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = config.cross_head_dim + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False + if past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values + else: + key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + + if past_key_values is not None: + # save all states to the cache + key_states, value_states = past_key_values.cross_attention_cache.update( + key_states, + value_states, + self.layer_idx, + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape((*input_shape, -1)).contiguous() + attn_output = 
self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiaEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaEncoderConfig, layer_idx: int): + super().__init__() + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False) + self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.post_sa_norm(hidden_states) + mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights + + +class DiaEncoder(DiaPreTrainedModel): + def __init__(self, config: DiaEncoderConfig): + super().__init__(config) + self.config = config + + self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.rotary_embeddings = DiaRotaryEmbedding(config) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[BaseModelOutput, tuple]: + hidden_states = self.embedding(input_ids) + + # RoPE + # Note: We expect right padding and hence always generate + # the position ids on the fly to reduce preparation overhead + position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :] + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + encoder_states += (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif 
self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class DiaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True) + self.cross_attention = DiaCrossAttention(config, layer_idx) + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + self_attn_cache = past_key_values + if isinstance(self_attn_cache, EncoderDecoderCache): + self_attn_cache = self_attn_cache.self_attention_cache + + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings, + attention_mask, + # Needs to be an arg in order to function properly + # on inplace operations to be carried (e.g. 
compile) + self_attn_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.pre_ca_norm(hidden_states) + cross_states, cross_attn_weights = self.cross_attention( + normed_states, + encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = residual + cross_states + + residual = hidden_states + normed_states = self.pre_mlp_norm(hidden_states) + mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights, cross_attn_weights + + +class DiaDecoder(DiaPreTrainedModel): + """Transformer Decoder Stack using DenseGeneral.""" + + def __init__(self, config: DiaDecoderConfig): + super().__init__(config) + self.num_channels = config.num_channels + self.vocab_size = config.vocab_size + self.embeddings = DiaMultiChannelEmbedding(config) + self.rotary_embeddings = DiaRotaryEmbedding(config) + self.layers = nn.ModuleList( + [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`): + The original `decoder_input_ids` in 3D shape to facilitate more efficient computations. 
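+            With Dia's default of 9 codebooks, a batch of two 100-step sequences is thus passed as `(2, 100, 9)`.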
+ + [What are input IDs?](../glossary#input-ids) + """ + + batch_size, seq_length = input_ids.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=input_ids.device + ) + if position_ids is None: + position_ids = cache_position[None, :] + + # RoPE + hidden_states = self.embeddings(input_ids) + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device) + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + hidden_states.shape[:2], + hidden_states, + ) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + position_embeddings, + attention_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns = all_self_attns + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
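+                # (When not `None`, the prepared mask is additive: 0 where attention is allowed, dtype-min where masked.)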
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +@auto_docstring( + custom_intro=""" + The bare Dia model outputting raw hidden-states without any specific head on top. + """ +) +class DiaModel(DiaPreTrainedModel): + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.encoder = DiaEncoder(config.encoder_config) + self.decoder = DiaDecoder(config.decoder_config) + self.post_init() + + def get_encoder(self): + return self.encoder + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + """ + + if input_ids is None and encoder_outputs is None: + raise ValueError( + "You should either provide text ids or the cached text encodings. Neither has been found." 
+        )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if self.is_gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # By default we initialize the decoder with bos tokens if nothing has been provided
+        bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+        if decoder_input_ids is None:
+            decoder_input_ids = torch.full(
+                size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+            )
+        # Ensure 3D
+        if decoder_input_ids.ndim == 2:
+            decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            position_ids=decoder_position_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs[0],
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
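+
+    Example (a hedged sketch; the checkpoint id and generation settings are illustrative):
+
+    ```python
+    >>> from transformers import AutoProcessor, DiaForConditionalGeneration
+
+    >>> processor = AutoProcessor.from_pretrained("nari-labs/Dia-1.6B")
+    >>> model = DiaForConditionalGeneration.from_pretrained("nari-labs/Dia-1.6B")
+
+    >>> inputs = processor(text=["[S1] Dia is a text to dialogue model."], return_tensors="pt")
+    >>> audio_codes = model.generate(**inputs, max_new_tokens=256)
+    ```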
+ """ +) +class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin): + base_model_prefix = "model" + + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.model = DiaModel(config) + + self.num_channels = config.decoder_config.num_channels + self.vocab_size = config.decoder_config.vocab_size + self.logits_dense = nn.Linear( + config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False + ) + self.loss_type = "ForMaskedLM" + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in + `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100` + are ignored (masked). 
+ """ + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + last_hidden_state = outputs[0] + batch_size = last_hidden_state.shape[0] + # 3D <-> 2D makes it necessary to prioritize channel dim + audio_logits = ( + self.logits_dense(last_hidden_state) + .view((batch_size, -1, self.num_channels, self.vocab_size)) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_channels, -1, self.vocab_size) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=audio_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/modular_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/modular_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..f99d32a01d9cffab79bd72852d776447e084681b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/modular_dia.py @@ -0,0 +1,773 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Dia model.""" + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...cache_utils import DynamicCache, EncoderDecoderCache +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaRMSNorm, + LlamaRotaryEmbedding, + eager_attention_forward, +) +from ..phi3.modeling_phi3 import Phi3MLP +from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig +from .generation_dia import DiaGenerationMixin + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class DiaPreTrainedModel(PreTrainedModel): + config: DiaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = True + main_input_name = "input_ids" + _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] + + +class DiaMultiChannelEmbedding(nn.Module): + """In order to efficiently compute the audio embedding from the 9 different channels, + we vectorize the embedding process by using a single embedding layer and an offset. + Example: + - num_embeds = 4 + - vocab_size = 8 + - num_channels = 3 + We would have offsets = [0, 8, 16] + If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8], + then tokens = audio_codes + offsets + = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24] + This allows us to use a single embedding layer for all channels. 
+ """ + + def __init__(self, config: DiaDecoderConfig): + super().__init__() + self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size) + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,) + self.register_buffer("offsets", offsets, persistent=False) + + def forward(self, audio_codes: torch.Tensor) -> torch.Tensor: + tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1) + embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size) + return embeds.sum(dim=2) + + +class DiaMLP(Phi3MLP): + pass + + +class DiaRMSNorm(LlamaRMSNorm): + pass + + +class DiaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class DiaSelfAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = is_causal + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + +class DiaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.cross_hidden_size = config.cross_hidden_size + self.num_heads = self.config.cross_num_attention_heads + self.num_key_value_heads = self.config.cross_num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = config.cross_head_dim + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_updated = past_key_values.is_updated.get(self.layer_idx) if 
past_key_values is not None else False + if past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values + else: + key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + + if past_key_values is not None: + # save all states to the cache + key_states, value_states = past_key_values.cross_attention_cache.update( + key_states, + value_states, + self.layer_idx, + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape((*input_shape, -1)).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiaEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaEncoderConfig, layer_idx: int): + super().__init__() + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False) + self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.post_sa_norm(hidden_states) + mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights + + +class DiaEncoder(DiaPreTrainedModel): + def __init__(self, config: DiaEncoderConfig): + super().__init__(config) + self.config = config + + self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.rotary_embeddings = DiaRotaryEmbedding(config) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[BaseModelOutput, tuple]: + hidden_states = self.embedding(input_ids) + + # RoPE + # Note: We expect right padding and hence always generate + # the position ids on the fly to reduce preparation overhead + position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :] + 
position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + encoder_states += (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class DiaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True) + self.cross_attention = DiaCrossAttention(config, layer_idx) + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + self_attn_cache = past_key_values + if isinstance(self_attn_cache, EncoderDecoderCache): + self_attn_cache = self_attn_cache.self_attention_cache + + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings, + attention_mask, + # Needs to be an arg in order to function properly + # on inplace operations to be carried (e.g. 
compile) + self_attn_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.pre_ca_norm(hidden_states) + cross_states, cross_attn_weights = self.cross_attention( + normed_states, + encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = residual + cross_states + + residual = hidden_states + normed_states = self.pre_mlp_norm(hidden_states) + mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights, cross_attn_weights + + +class DiaDecoder(DiaPreTrainedModel): + """Transformer Decoder Stack using DenseGeneral.""" + + def __init__(self, config: DiaDecoderConfig): + super().__init__(config) + self.num_channels = config.num_channels + self.vocab_size = config.vocab_size + self.embeddings = DiaMultiChannelEmbedding(config) + self.rotary_embeddings = DiaRotaryEmbedding(config) + self.layers = nn.ModuleList( + [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`): + The original `decoder_input_ids` in 3D shape to facilitate more efficient computations. 
+ + [What are input IDs?](../glossary#input-ids) + """ + + batch_size, seq_length = input_ids.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=input_ids.device + ) + if position_ids is None: + position_ids = cache_position[None, :] + + # RoPE + hidden_states = self.embeddings(input_ids) + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device) + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + hidden_states.shape[:2], + hidden_states, + ) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + position_embeddings, + attention_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns = all_self_attns + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +@auto_docstring( + custom_intro=""" + The bare Dia model outputting raw hidden-states without any specific head on top. + """ +) +class DiaModel(DiaPreTrainedModel): + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.encoder = DiaEncoder(config.encoder_config) + self.decoder = DiaDecoder(config.decoder_config) + self.post_init() + + def get_encoder(self): + return self.encoder + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + """ + + if input_ids is None and encoder_outputs is None: + raise ValueError( + "You should either provide text ids or the cached text encodings. Neither has been found." 
+        )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if self.is_gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # By default we initialize the decoder with bos tokens if nothing has been provided
+        bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+        if decoder_input_ids is None:
+            decoder_input_ids = torch.full(
+                size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+            )
+        # Ensure 3D
+        if decoder_input_ids.ndim == 2:
+            decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            position_ids=decoder_position_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs[0],
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+ """ +) +class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin): + base_model_prefix = "model" + + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.model = DiaModel(config) + + self.num_channels = config.decoder_config.num_channels + self.vocab_size = config.decoder_config.vocab_size + self.logits_dense = nn.Linear( + config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False + ) + self.loss_type = "ForMaskedLM" + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in + `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100` + are ignored (masked). 
+ """ + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + last_hidden_state = outputs[0] + batch_size = last_hidden_state.shape[0] + # 3D <-> 2D makes it necessary to prioritize channel dim + audio_logits = ( + self.logits_dense(last_hidden_state) + .view((batch_size, -1, self.num_channels, self.vocab_size)) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_channels, -1, self.vocab_size) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=audio_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/processing_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/processing_dia.py new file mode 100644 index 0000000000000000000000000000000000000000..402f5152a64bda378ccdf5edd512c86fe643145c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dia/processing_dia.py @@ -0,0 +1,474 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Processor class for Dia""" + +import math +from pathlib import Path +from typing import Optional, Union + +from ...audio_utils import AudioInput, make_list_of_audio +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...utils import is_soundfile_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_soundfile_available(): + import soundfile as sf + + +class DiaAudioKwargs(AudioKwargs, total=False): + bos_token_id: int + eos_token_id: int + pad_token_id: int + delay_pattern: list[int] + generation: bool + + +class DiaProcessorKwargs(ProcessingKwargs, total=False): + audio_kwargs: DiaAudioKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "padding_side": "right", + "add_special_tokens": False, + }, + "audio_kwargs": { + "eos_token_id": 1024, + "pad_token_id": 1025, + "bos_token_id": 1026, + "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15], + "generation": True, + "sampling_rate": 44100, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class DiaProcessor(ProcessorMixin): + r""" + Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into + a single processor. It inherits, the audio feature extraction, tokenizer, and audio encode/decode functio- + nalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more + information. + + Args: + feature_extractor (`DiaFeatureExtractor`): + An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`DiaTokenizer`): + An instance of [`DiaTokenizer`]. The tokenizer is a required input. + audio_tokenizer (`DacModel`): + An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is is a required input. + """ + + feature_extractor_class = "DiaFeatureExtractor" + tokenizer_class = "DiaTokenizer" + audio_tokenizer_class = "DacModel" + + def __init__(self, feature_extractor, tokenizer, audio_tokenizer): + super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer) + + def __call__( + self, + text: Union[str, list[str]], + audio: Optional[AudioInput] = None, + output_labels: Optional[bool] = False, + **kwargs: Unpack[DiaProcessorKwargs], + ): + """ + Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is + forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the + DacModel's [`~DacModel.encode`]. The `text` argument to [`~DiaTokenizer.__call__`]. Please refer + to the docstring of the above methods for more information. + """ + if not is_torch_available(): + raise ValueError( + "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't " + "find it in your environment. You can install torch via `pip install torch`." 
+ )
+
+ if text is None:
+ raise ValueError("You need to specify the `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+
+ text_kwargs = output_kwargs["text_kwargs"]
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ common_kwargs = output_kwargs["common_kwargs"]
+
+ return_tensors = common_kwargs.pop("return_tensors", None)
+ if return_tensors != "pt":
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+ data = {}
+
+ # Text
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string or a list of strings.")
+
+ encodings = self.tokenizer(text, **text_kwargs)
+ data.update(encodings)
+
+ # Audio
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+ audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
+ audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+ generation = audio_kwargs.pop("generation", True)
+ if (
+ audio_bos_token_id is None
+ or audio_eos_token_id is None
+ or audio_pad_token_id is None
+ or delay_pattern is None
+ ):
+ raise ValueError(
+ "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
+ "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
+ )
+
+ if generation and output_labels:
+ raise ValueError(
+ f"Computing labels is incompatible with generation mode, got generation={generation}, output_labels={output_labels}."
+ )
+
+ batch_size = data["input_ids"].shape[0]
+ num_channels = len(delay_pattern)
+ max_delay = max(delay_pattern)
+
+ # Voice cloning generation / general training
+ if audio is not None:
+ audio = make_list_of_audio(audio)
+ input_audios = self.feature_extractor(audio, **audio_kwargs)
+
+ compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
+ max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
+
+ decoder_input_ids = []
+ decoder_attention_mask = []
+ # TODO: dac with batching is currently broken, but non-batch is working
+ # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
+ for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
+ # get current length with hop length in mind (as if it were sampled as a single audio)
+ base_pad_len = self.feature_extractor.hop_length
+ current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
+
+ encoded_sequence_len = current_audio_len // compression_rate
+ padding_len = max_encoded_sequence_len - encoded_sequence_len
+
+ # compute non-padded forward pass; one extra bos (and eos if training) is added
+ with torch.no_grad():
+ audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
+ input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
+
+ if not generation:
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
+ )
+
+ # apply padding
+ # +1 for the bos within the real sequence
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
+ )
+ num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
+ num_valid_inputs += 0 if generation else 1 # eos if training
+ attention_mask = torch.tensor([0] *
padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :] + + decoder_input_ids.append(input_ids) + decoder_attention_mask.append(attention_mask) + + decoder_input_ids = torch.cat(decoder_input_ids, dim=0) + decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0) + # TTS generation + elif generation: + # all bos to start with TTS + decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long) + + # we preemptively add the delay + decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long) + else: + raise ValueError("If you try to train, you should provide audio data as well.") + + if batch_size != decoder_input_ids.shape[0]: + raise ValueError( + f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and " + f"audio samples = {decoder_input_ids.shape[0]} instead." + ) + + # prepare shift indices per delay + max_seq_len = decoder_attention_mask.shape[-1] + max_audio_len = max_seq_len - max_delay + precomputed_idx = self.build_indices( + bsz=batch_size, + seq_len=max_seq_len, + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=False, + ) + + # create delay pattern input + # the pad token will be used for masking which input is valid for prediction during generation + prefill = torch.full( + (batch_size, max_seq_len, num_channels), + fill_value=audio_pad_token_id, + dtype=torch.int, + ) + prefill[:, :max_audio_len] = decoder_input_ids + + delayed_decoder_input_ids = self.apply_audio_delay( + audio=prefill, + pad_token_id=audio_pad_token_id, + bos_token_id=audio_bos_token_id, + precomputed_idx=precomputed_idx, + ) + + data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask}) + + if output_labels: + # Base idea is to shift on the sequence dim + labels = data["decoder_input_ids"].clone()[:, 1:] + labels[labels == audio_pad_token_id] = -100 + labels[labels == audio_bos_token_id] = -100 + + data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long() + data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1] + data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1] + + return BatchFeature(data=data, tensor_type=return_tensors) + + def batch_decode( + self, + decoder_input_ids: "torch.Tensor", + audio_prompt_len: Optional[int] = None, + **kwargs: Unpack[DiaProcessorKwargs], + ) -> list["torch.Tensor"]: + """ + Decodes a batch of audio codebook sequences into their respective audio waveforms via the + `audio_tokenizer`. See [`~DacModel.decode`] for more information. + + Args: + decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder. + audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning). + """ + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + + delay_pattern = audio_kwargs.pop("delay_pattern", None) + audio_bos_token_id = audio_kwargs.pop("bos_token_id", None) + audio_pad_token_id = audio_kwargs.pop("pad_token_id", None) + if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None: + raise ValueError( + "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, " + "and `delay_pattern`. You may have accidentally overwritten one of those." 
+ ) + + # either decode the whole audio sequence or only the generated parts + if audio_prompt_len is not None: + audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long) + start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0]) + else: + start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1) + # -1 for the eos token + end_of_generation_idx = ( + decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1 + ) + + # revert delay + bsz, seq_len, num_channels = decoder_input_ids.shape + precomputed_idx = self.build_indices( + bsz=bsz, + seq_len=seq_len, + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=True, + ) + + output_sequences = self.apply_audio_delay( + audio=decoder_input_ids, + # We do not care about these values as we cut them out + # with `start_of_generation_idx` and `end_of_generation_idx` + pad_token_id=-1, + bos_token_id=-1, + precomputed_idx=precomputed_idx, + ).transpose(1, 2) + + # retrieve the correct sequences each + audios = [] + # TODO: see above, dac doesn't work in batches yet + with torch.no_grad(): + for i in range(start_of_generation_idx.shape[0]): + output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...] + output_i = output_i.to(self.audio_tokenizer.device) + audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze() + audios.append(audio_i) + + return audios + + def decode( + self, + decoder_input_ids: "torch.Tensor", + audio_prompt_len: Optional[int] = None, + **kwargs: Unpack[DiaProcessorKwargs], + ) -> "torch.Tensor": + """ + Decodes a single sequence of audio codebooks into the respective audio waveform via the + `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information. + """ + if decoder_input_ids.shape[0] != 1: + raise ValueError( + f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead." + ) + + return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0] + + def get_audio_prompt_len( + self, + decoder_attention_mask: "torch.Tensor", + **kwargs: Unpack[DiaProcessorKwargs], + ) -> int: + """Utility function to get the audio prompt length.""" + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + + delay_pattern = audio_kwargs.pop("delay_pattern", None) + if delay_pattern is None: + raise ValueError( + "To enable the utility of retrieving the prompt length for Dia, we need the " + "`delay_pattern`. You may have accidentally overwritten this." + ) + return decoder_attention_mask.shape[1] - max(delay_pattern) + + # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia + def save_audio( + self, + audio: AudioInput, + saving_path: Union[str, Path, list[Union[str, Path]]], + **kwargs: Unpack[DiaProcessorKwargs], + ): + # TODO: @eustlb, this should be in AudioProcessor + if not is_soundfile_available(): + raise ImportError("Please install `soundfile` to save audio files.") + + # ensure correct audio input + audio = make_list_of_audio(audio) + + # ensure correct saving path + if isinstance(saving_path, (str, Path)): + saving_path = [saving_path] + elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)): + raise ValueError("Invalid input path. 
Please provide a string, or a list of strings") + + if len(audio) != len(saving_path): + raise ValueError("The number of audio and saving paths must be the same") + + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + sampling_rate = audio_kwargs["sampling_rate"] + + for audio_value, p in zip(audio, saving_path): + if isinstance(audio_value, torch.Tensor): + audio_value = audio_value.cpu().float().numpy() + sf.write(p, audio_value, sampling_rate) + + @staticmethod + def build_indices( + bsz: int, + seq_len: int, + num_channels: int, + delay_pattern: list[int], + revert: bool = False, + ) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel] + or in[seq, channel] = out[seq + delay[channel], channel] if `revert`. + Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD. + """ + delay_array = torch.tensor(delay_pattern, dtype=torch.int32) + + # (0..seq_len-1) + sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None] + # + or - delay depending if we delay or revert the delay + if not revert: + sequence_idx = sequence_idx - delay_array[None, None, :] + else: + sequence_idx = sequence_idx + delay_array[None, None, :] + # if delay goes over the range we clamp back to valid values + valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1) + + batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels) + channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels) + + all_idx = torch.stack( + [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)], + dim=1, + ).long() + + return sequence_idx, all_idx + + @staticmethod + def apply_audio_delay( + audio: "torch.Tensor", + pad_token_id: int, + bos_token_id: int, + precomputed_idx: tuple["torch.Tensor", "torch.Tensor"], + ) -> "torch.Tensor": + """ + Applies or reverts the delay pattern to batched audio tokens using precomputed indices, + inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len. 
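+
+ As a small illustration (token values assumed), with `delay_pattern=[0, 1]` channel 1 is shifted
+ one step to the right:
+
+ ```python
+ >>> audio = torch.tensor([[[10, 20], [11, 21]]]) # [bsz=1, seq_len=2, num_channels=2]
+ >>> idx = DiaProcessor.build_indices(bsz=1, seq_len=2, num_channels=2, delay_pattern=[0, 1])
+ >>> DiaProcessor.apply_audio_delay(audio, pad_token_id=-2, bos_token_id=-1, precomputed_idx=idx)
+ >>> # -> [[[10, -1], [11, 20]]]: channel 0 is unchanged, channel 1 starts with BOS (-1)
+ ```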
+
+ Args:
+ audio: audio tokens of shape [bsz, seq_len, num_channels]
+ pad_token_id: the PAD token
+ bos_token_id: the BOS token
+ precomputed_idx: from `build_indices`
+
+ Returns:
+ final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
+ """
+ # Move everything to the same device
+ device = audio.device
+ sequence_idx, all_idx = precomputed_idx
+ sequence_idx = sequence_idx.to(device)
+ all_idx = all_idx.to(device)
+
+ # Gather per precomputed indices
+ batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
+ gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
+
+ # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
+ mask_bos = sequence_idx < 0
+ mask_pad = sequence_idx >= audio.shape[1]
+ final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
+
+ return final_audio
+
+
+__all__ = ["DiaProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dia/tokenization_dia.py b/venv/lib/python3.13/site-packages/transformers/models/dia/tokenization_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e205906ea709ee2c20f25b0bf6f4fa66ab1f4a4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dia/tokenization_dia.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Dia."""
+
+from typing import Optional
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaTokenizer(PreTrainedTokenizer):
+ """
+ Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ max_length (`int`, *optional*, defaults to 1024):
+ The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
+ offset (`int`, *optional*, defaults to 0):
+ The offset of the tokenizer.
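+
+ A short illustrative sketch (the exact ids assume the defaults above, where `[S1]` maps to 1 and
+ plain characters map to their raw utf-8 byte values):
+
+ ```python
+ >>> from transformers import DiaTokenizer
+
+ >>> tokenizer = DiaTokenizer()
+ >>> tokenizer("[S1] Hi")["input_ids"] # roughly [1, 32, 72, 105]
+ ```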
+ """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + pad_token: Optional[str] = "", + unk_token: Optional[str] = "", + max_length: Optional[int] = 1024, + offset: int = 0, + **kwargs, + ): + # We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well + pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token + unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token + + self._utf_vocab_size = 2**8 # utf is 8 bits + self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")} + self.offset = offset + super().__init__( + unk_token=unk_token, + pad_token=pad_token, + max_length=max_length, + **kwargs, + ) + + @property + def vocab_size(self): + return self._utf_vocab_size + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + tokens = [chr(i) for i in text.encode("utf-8")] + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + + if len(token) != 1: + token_id = None + else: + token_id = ord(token) + self.offset + + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = chr(index - self.offset) + return token + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + for token in tokens: + if token in self.added_tokens_decoder: + added_token_obj = self.added_tokens_decoder[token] + tok_string = str(added_token_obj).encode("utf-8") + elif token in self.added_tokens_encoder: + tok_string = token.encode("utf-8") + else: + tok_string = token.encode("utf-8") # Assume general string token + bstring += tok_string + string = bstring.decode("utf-8", errors="ignore") + return string + + # No vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + return () + + +__all__ = ["DiaTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/diffllama/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/diffllama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c162fce0a48bd164bd0e0a615b942ee4805a12aa --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/diffllama/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_diffllama import *
+ from .modeling_diffllama import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/diffllama/configuration_diffllama.py b/venv/lib/python3.13/site-packages/transformers/models/diffllama/configuration_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..210607271927ab2f3a7aa1ec1e874fb296c32a73
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/diffllama/configuration_diffllama.py
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.

+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DiffLlama model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class DiffLlamaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate a DiffLlama
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
+ will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiffLlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 16):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used.
When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`).
Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ lambda_std_dev (`float`, *optional*, defaults to 0.1):
+ The standard deviation for initialization of parameter lambda in attention layer.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_heads.
+
+ ```python
+ >>> from transformers import DiffLlamaModel, DiffLlamaConfig
+
+ >>> # Initializing a DiffLlama diffllama-7b style configuration
+ >>> configuration = DiffLlamaConfig()
+
+ >>> # Initializing a model from the diffllama-7b style configuration
+ >>> model = DiffLlamaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "diffllama"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2048,
+ intermediate_size=8192,
+ num_hidden_layers=16,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ lambda_std_dev=0.1,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.lambda_std_dev = lambda_std_dev
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
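+ # e.g. a legacy `{"type": "linear", "factor": 2.0}` is read as `{"rope_type": "linear", "factor": 2.0}`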
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["DiffLlamaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/diffllama/modeling_diffllama.py b/venv/lib/python3.13/site-packages/transformers/models/diffllama/modeling_diffllama.py new file mode 100644 index 0000000000000000000000000000000000000000..599f427280fce1ab7e640e4e438e3e7052ea1caf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/diffllama/modeling_diffllama.py @@ -0,0 +1,767 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_diffllama.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on Llama implementations in this library and Microsoft's +# Differential Transformer implementations. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_diffllama import DiffLlamaConfig + + +logger = logging.get_logger(__name__) + + +class DiffLlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def lambda_init_fn(layer_idx): + return 0.8 - 0.6 * math.exp(-0.3 * layer_idx) + + +class DiffLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + # under this are not used + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.lambda_init = lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, target_len, _ = hidden_states.size() + q_len = target_len + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + 
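+ # q/k/v now have shape (bsz, num_(kv_)heads, q_len, head_dim); RoPE below rotates the head_dim axis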
+ cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = torch.matmul(attn_weights, value_states) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiffLlamaFlashAttention2(DiffLlamaAttention): + """ + DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
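+ # Records whether the installed flash-attn still produces top-left aligned causal masks (flash_attn < 2.1)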
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, None]: + if isinstance(past_key_values, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(DiffLlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + value_states1, value_states2 = torch.chunk(value_states, 2, dim=2) + value_states1 = value_states1.repeat(1, 1, 2, 1) + value_states2 = value_states2.repeat(1, 1, 2, 1) + + attn_output1 = _flash_attention_forward( + query_states, + key_states, + value_states1, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output2 = _flash_attention_forward( + query_states, + key_states, + value_states2, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None + + +class DiffLlamaSdpaAttention(DiffLlamaAttention): + """ + DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + # Adapted from DiffLlamaAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, None + + +@use_kernel_forward_from_hub("RMSNorm") +class DiffLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DiffLlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +DIFFLLAMA_ATTENTION_CLASSES = { + "eager": DiffLlamaAttention, + "flash_attention_2": DiffLlamaFlashAttention2, + "sdpa": DiffLlamaSdpaAttention, +} + + +class DiffLlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiffLlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = DiffLlamaMLP(config) + self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class 
DiffLlamaPreTrainedModel(PreTrainedModel): + config: DiffLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DiffLlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + + _can_compile_fullgraph = True + _supports_attention_backend = False + _can_record_outputs = { + "hidden_states": DiffLlamaDecoderLayer, + "attentions": DiffLlamaAttention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + + +class DiffLlamaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DiffLlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class DiffLlamaModel(DiffLlamaPreTrainedModel): + def __init__(self, config: DiffLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): 
+ super().__init__(config) + self.model = DiffLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM + + >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel): + pass + + +class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel): + pass + + +__all__ = [ + "DiffLlamaPreTrainedModel", + "DiffLlamaModel", + "DiffLlamaForCausalLM", + "DiffLlamaForSequenceClassification", + "DiffLlamaForQuestionAnswering", + "DiffLlamaForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/diffllama/modular_diffllama.py b/venv/lib/python3.13/site-packages/transformers/models/diffllama/modular_diffllama.py new file mode 100644 index 0000000000000000000000000000000000000000..253b99edff0d7a557da404fa680ce8403e22ccf1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/diffllama/modular_diffllama.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on Llama implementations in this library and Microsoft's +# Differential Transformer implementations. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional + +import torch +from torch import nn + +from ...cache_utils import Cache, StaticCache +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ...utils.deprecation import deprecate_kwarg +from ..gemma.modeling_gemma import GemmaForCausalLM +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaPreTrainedModel, + apply_rotary_pos_emb, + repeat_kv, +) +from ..mistral.modeling_mistral import MistralMLP +from .configuration_diffllama import DiffLlamaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" +_CONFIG_FOR_DOC = "DiffLlamaConfig" + + +class DiffLlamaMLP(MistralMLP): + pass + + +def lambda_init_fn(layer_idx): + return 0.8 - 0.6 * math.exp(-0.3 * layer_idx) + + +class DiffLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + # under this are not used + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.lambda_init = lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, target_len, _ = hidden_states.size() + q_len = target_len + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+    """
+    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
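+        # As described above, this flag is True for older flash-attn builds (<2.1) that still produce a
+        # top-left aligned causal mask, so `_flash_attention_forward` can compensate for the difference.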
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, None]: + if isinstance(past_key_values, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(DiffLlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + value_states1, value_states2 = torch.chunk(value_states, 2, dim=2) + value_states1 = value_states1.repeat(1, 1, 2, 1) + value_states2 = value_states2.repeat(1, 1, 2, 1) + + attn_output1 = _flash_attention_forward( + query_states, + key_states, + value_states1, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output2 = _flash_attention_forward( + query_states, + key_states, + value_states2, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None + + +class DiffLlamaSdpaAttention(DiffLlamaAttention): + """ + DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + # Adapted from DiffLlamaAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, None + + +DIFFLLAMA_ATTENTION_CLASSES = { + "eager": DiffLlamaAttention, + "flash_attention_2": DiffLlamaFlashAttention2, + "sdpa": DiffLlamaSdpaAttention, +} + + +class DiffLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: DiffLlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + +class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): + _supports_flex_attn = False + _supports_attention_backend = False + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + + +class DiffLlamaModel(LlamaModel): + pass + + +class DiffLlamaForCausalLM(GemmaForCausalLM): + pass + + +class DiffLlamaForSequenceClassification(LlamaForSequenceClassification): + pass + + +class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +class DiffLlamaForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "DiffLlamaPreTrainedModel", + "DiffLlamaModel", + "DiffLlamaForCausalLM", + "DiffLlamaForSequenceClassification", + "DiffLlamaForQuestionAnswering", + "DiffLlamaForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinat/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dinat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b64cdbb3c7eb0467f6112225b8c0d9e1f65f9e99 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dinat/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinat import * + from .modeling_dinat import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinat/configuration_dinat.py b/venv/lib/python3.13/site-packages/transformers/models/dinat/configuration_dinat.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d7fa509c5a3b2f5efc3b936cf1761b4ab0e107 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dinat/configuration_dinat.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dilated Neighborhood Attention Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class DinatConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Dinat + [shi-labs/dinat-mini-in1k-224](https://huggingface.co/shi-labs/dinat-mini-in1k-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 64): + Dimensionality of patch embedding. + depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 5]`): + Number of layers in each level of the encoder. + num_heads (`list[int]`, *optional*, defaults to `[2, 4, 8, 16]`): + Number of attention heads in each layer of the Transformer encoder. + kernel_size (`int`, *optional*, defaults to 7): + Neighborhood Attention kernel size. + dilations (`list[list[int]]`, *optional*, defaults to `[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]`): + Dilation value of each NA layer in the Transformer encoder. + mlp_ratio (`float`, *optional*, defaults to 3.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + layer_scale_init_value (`float`, *optional*, defaults to 0.0): + The initial value for the layer scale. Disabled if <=0. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
+ + Example: + + ```python + >>> from transformers import DinatConfig, DinatModel + + >>> # Initializing a Dinat shi-labs/dinat-mini-in1k-224 style configuration + >>> configuration = DinatConfig() + + >>> # Initializing a model (with random weights) from the shi-labs/dinat-mini-in1k-224 style configuration + >>> model = DinatModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinat" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + patch_size=4, + num_channels=3, + embed_dim=64, + depths=[3, 4, 6, 5], + num_heads=[2, 4, 8, 16], + kernel_size=7, + dilations=[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]], + mlp_ratio=3.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + initializer_range=0.02, + layer_norm_eps=1e-5, + layer_scale_init_value=0.0, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.kernel_size = kernel_size + self.dilations = dilations + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.layer_scale_init_value = layer_scale_init_value + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["DinatConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinat/modeling_dinat.py b/venv/lib/python3.13/site-packages/transformers/models/dinat/modeling_dinat.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7ec37b0ea8489e5db87df22c1876fd4548fe86 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dinat/modeling_dinat.py @@ -0,0 +1,855 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Dilated Neighborhood Attention Transformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + OptionalDependencyNotAvailable, + auto_docstring, + is_natten_available, + logging, + requires_backends, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_dinat import DinatConfig + + +if is_natten_available(): + from natten.functional import natten2dav, natten2dqkrpb +else: + + def natten2dqkrpb(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + def natten2dav(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + +logger = logging.get_logger(__name__) + + +# drop_path and DinatDropPath are from the timm library. + + +@dataclass +@auto_docstring( + custom_intro=""" + Dinat encoder's outputs, with potential hidden states and attentions. + """ +) +class DinatEncoderOutput(ModelOutput): + r""" + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Dinat model's outputs that also contains a pooling of the last hidden states. + """ +) +class DinatModelOutput(ModelOutput): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Dinat outputs for image classification. + """ +) +class DinatImageClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). 
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +class DinatEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = DinatPatchEmbeddings(config) + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor]: + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class DinatPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + patch_size = config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + self.num_channels = num_channels + + if patch_size == 4: + pass + else: + # TODO: Support arbitrary patch sizes. + raise ValueError("Dinat only supports patch size of 4 at the moment.") + + self.projection = nn.Sequential( + nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + ) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + embeddings = embeddings.permute(0, 2, 3, 1) + + return embeddings + + +class DinatDownsampler(nn.Module): + """ + Convolutional Downsampling Layer. + + Args: + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
+ """ + + def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.dim = dim + self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, input_feature: torch.Tensor) -> torch.Tensor: + input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + input_feature = self.norm(input_feature) + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat +class DinatDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class NeighborhoodAttention(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size, dilation): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.kernel_size = kernel_size + self.dilation = dilation + + # rpb is learnable relative positional biases; same concept is used Swin. 
+ self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1))) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Apply the scale factor before computing attention weights. It's usually more efficient because + # attention weights are typically a bigger tensor compared to query. + # It gives identical results because scalars are commutable in matrix multiplication. + query_layer = query_layer / math.sqrt(self.attention_head_size) + + # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases. + attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation) + context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class NeighborhoodAttentionOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class NeighborhoodAttentionModule(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size, dilation): + super().__init__() + self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation) + self.output = NeighborhoodAttentionOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class DinatIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class DinatOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class DinatLayer(nn.Module): + def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.kernel_size = config.kernel_size + self.dilation = dilation + self.window_size = self.kernel_size * self.dilation + self.layernorm_before = 
nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = NeighborhoodAttentionModule( + config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation + ) + self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DinatIntermediate(config, dim) + self.output = DinatOutput(config, dim) + self.layer_scale_parameters = ( + nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True) + if config.layer_scale_init_value > 0 + else None + ) + + def maybe_pad(self, hidden_states, height, width): + window_size = self.window_size + pad_values = (0, 0, 0, 0, 0, 0) + if height < window_size or width < window_size: + pad_l = pad_t = 0 + pad_r = max(0, window_size - width) + pad_b = max(0, window_size - height) + pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + # pad hidden_states if they are smaller than kernel size x dilation + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + + attention_outputs = self.attention(hidden_states, output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_output = attention_output[:, :height, :width, :].contiguous() + + if self.layer_scale_parameters is not None: + attention_output = self.layer_scale_parameters[0] * attention_output + + hidden_states = shortcut + self.drop_path(attention_output) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.output(self.intermediate(layer_output)) + + if self.layer_scale_parameters is not None: + layer_output = self.layer_scale_parameters[1] * layer_output + + layer_output = hidden_states + self.drop_path(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class DinatStage(nn.Module): + def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample): + super().__init__() + self.config = config + self.dim = dim + self.layers = nn.ModuleList( + [ + DinatLayer( + config=config, + dim=dim, + num_heads=num_heads, + dilation=dilations[i], + drop_path_rate=drop_path_rate[i], + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + _, height, width, _ = hidden_states.size() + for i, layer_module in enumerate(self.layers): + layer_outputs = layer_module(hidden_states, output_attentions) + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + hidden_states = self.downsample(hidden_states_before_downsampling) + + stage_outputs = (hidden_states, hidden_states_before_downsampling) + + if output_attentions: + stage_outputs += 
layer_outputs[1:] + return stage_outputs + + +class DinatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.num_levels = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.levels = nn.ModuleList( + [ + DinatStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + dilations=config.dilations[i_layer], + drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None, + ) + for i_layer in range(self.num_levels) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, DinatEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.levels): + layer_outputs = layer_module(hidden_states, output_attentions) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + + if output_hidden_states and output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DinatEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +@auto_docstring +class DinatPreTrainedModel(PreTrainedModel): + config: DinatConfig + base_model_prefix = "dinat" + main_input_name = "pixel_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class DinatModel(DinatPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.config = config + 
self.num_levels = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1)) + + self.embeddings = DinatEmbeddings(config) + self.encoder = DinatEncoder(config) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, DinatModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return DinatModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """ +) +class DinatForImageClassification(DinatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.num_labels = config.num_labels + self.dinat = DinatModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, DinatImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinat( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DinatImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + NAT backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class DinatBackbone(DinatPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + requires_backends(self, ["natten"]) + + self.embeddings = DinatEmbeddings(config) + self.encoder = DinatEncoder(config) + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self._out_features, self.channels): + hidden_states_norms[stage] = nn.LayerNorm(num_channels) + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") + >>> model = AutoBackbone.from_pretrained( + ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 512, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + return_dict=True, + ) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + batch_size, num_channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state = hidden_state.view(batch_size, height * width, num_channels) + hidden_state = self.hidden_states_norms[stage](hidden_state) + hidden_state = hidden_state.view(batch_size, height, width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.13/site-packages/transformers/models/doge/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/doge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91aeca0340f162439b6b7485f325f5343802dd6e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/doge/__init__.py @@ -0,0 +1,28 @@ +# coding=utf-8 +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
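The Dinat classes defined above (`DinatModel`, `DinatForImageClassification`) carry no usage example of their own, so here is a minimal classification sketch. It is not part of the upstream diff, and the `shi-labs/dinat-mini-in1k-224` checkpoint id is an assumption about an available pretrained weight.

```python
# Illustrative sketch only; the checkpoint id is assumed, not taken from this diff.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, DinatForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("shi-labs/dinat-mini-in1k-224")
model = DinatForImageClassification.from_pretrained("shi-labs/dinat-mini-in1k-224")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, num_labels)

# Index of the highest-scoring ImageNet class.
print(model.config.id2label[logits.argmax(-1).item()])
```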
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_doge import * + from .modeling_doge import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/doge/configuration_doge.py b/venv/lib/python3.13/site-packages/transformers/models/doge/configuration_doge.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a93fa198f2073ca0e34e613a7d9bed01f892d0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/doge/configuration_doge.py @@ -0,0 +1,241 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/doge/modular_doge.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_doge.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. +# +# The Doge family of small language models is trained by SmallDoge Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class DogeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge + model according to the specified arguments, defining the model architecture like [SmallDoge/Doge-320M](https://huggingface.co/SmallDoge/Doge-320M). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the Doge2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`] + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + hidden_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for each sequence transformation and state transformation module. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. + NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. + Doge family of small models use `{ 'rope_type': 'dynamic', 'factor': 4.0, 'original_max_position_embeddings': 2048 }` as the default value. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. + In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. + The original max position embeddings used during pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. + If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`). + Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (<`original_max_position_embeddings`). + Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
+ If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. + When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. + For more details checkout [this paper](https://huggingface.co/papers/2305.13245). + If it is not specified, will default to `num_attention_heads`. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None`. + keep_window_size (`int`, *optional*, defaults to 2048): + The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. + is_moe (`bool`, *optional*, defaults to `False`): + Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize. + num_experts (`int`, *optional*, defaults to 16384): + Number of routed experts in the model. This is only used when `is_moe=True`. + num_experts_per_tok (`int`, *optional*, defaults to 64): + Number of selected experts to route per-token. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. 
+ + ```python + >>> from transformers import DogeConfig, DogeModel + + >>> # Initializing a Doge-320M style configuration + >>> configuration = DogeConfig() + + >>> # Initializing a model from the Doge-320M style configuration + >>> model = DogeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "doge" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `DogeModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.dt_proj": "rowwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.input_layernorm.weight": "sequence_parallel", + "layers.*.input_residual.weight": "sequence_parallel", + "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "layers.*.post_attention_residual.weight": "sequence_parallel", + "norm.weight": "sequence_parallel", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + "layers.*.mlp.router_gate": "colwise_rep", + "layers.*.mlp.down_embed": "rowwise_rep", + "layers.*.mlp.up_embed": "rowwise_rep", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32768, + hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=32, + hidden_dropout=0.0, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + tie_word_embeddings=False, + max_position_embeddings=2048, + rope_theta=10000.0, + rope_scaling=None, + num_attention_heads=8, + num_key_value_heads=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=None, + keep_window_size=2048, + is_moe=False, + num_experts=16384, + num_experts_per_tok=64, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_dropout = hidden_dropout + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.keep_window_size = keep_window_size + self.is_moe = is_moe + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # for backward compatibility + if num_key_value_heads is None: + self.num_key_value_heads = num_attention_heads + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["DogeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/doge/modeling_doge.py b/venv/lib/python3.13/site-packages/transformers/models/doge/modeling_doge.py new file mode 100644 index 0000000000000000000000000000000000000000..05c0ae688c95039f8653619239074d424ab0a856 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/doge/modeling_doge.py @@ -0,0 +1,812 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/doge/modular_doge.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_doge.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. +# +# The Doge family of small language models is trained by SmallDoge Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
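As a quick illustration of the configuration class just defined (a sketch, not part of the upstream diff): the docstring above names `{'rope_type': 'dynamic', 'factor': 4.0, 'original_max_position_embeddings': 2048}` as the rope scaling used by the Doge family, and `__init__` falls back to `num_attention_heads` when `num_key_value_heads` is left unset.

```python
# Sketch assuming a transformers build that ships the Doge model.
from transformers import DogeConfig

config = DogeConfig(
    rope_scaling={"rope_type": "dynamic", "factor": 4.0, "original_max_position_embeddings": 2048},
)
# Backward-compatibility fallback: with num_key_value_heads=None, MHA is used (8 heads, 8 KV heads).
assert config.num_key_value_heads == config.num_attention_heads == 8

# Legacy configs may still carry the old "type" key; it is copied to "rope_type"
# before rope_config_validation runs.
legacy = DogeConfig(rope_scaling={"type": "dynamic", "factor": 4.0, "original_max_position_embeddings": 2048})
assert legacy.rope_scaling["rope_type"] == "dynamic"
```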
+ +import math +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...integrations.flex_attention import compile_friendly_flex_attention +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import AttentionInterface, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_doge import DogeConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + +@use_kernel_forward_from_hub("RMSNorm") +class DogeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DogeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DogeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DogeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def flex_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + block_mask = None + causal_mask = None + if isinstance(attention_mask, BlockMask): + block_mask = attention_mask + else: + causal_mask = attention_mask + + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) + if causal_mask is not None: + score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx] + if head_mask is not None: + score = score + head_mask[batch_idx][head_idx][0][0] + return score + + attn_output, attention_weights = compile_friendly_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask, + enable_gqa=True, + scale=scaling, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. 
+ return_lse=True, + ) + # lse is returned in float32 + attention_weights = attention_weights.to(value.dtype) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attention_weights + + +ALL_ATTENTION_FUNCTIONS = AttentionInterface() +ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward + + +class DogeAttention(nn.Module): + def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.keep_window_size = config.keep_window_size + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + # dynamic mask for the QK^T attention weights matrix + self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.dt_proj = nn.Linear( + config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # calculate dynamic mask from value_states + dt_states = self.dt_proj( + value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) + ) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + attn_mask = self.prepare_dynamic_mask( + hidden_states=hidden_states, + dt_states=dt_states, + keep_window_size=self.keep_window_size, + attention_mask=attention_mask, + ) + attn_mask = repeat_kv(attn_mask, self.num_key_value_groups) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + 
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attn_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + def prepare_dynamic_mask( + self, + hidden_states: torch.Tensor, + dt_states: torch.Tensor, + keep_window_size: int = 2048, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention. + + Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. + + Args: + hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision. + dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`. + keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. + attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`. + """ + min_dtype = torch.finfo(hidden_states.dtype).min + dtype = hidden_states.dtype + attn_mask = dt_states[:, :, None, :].expand( + -1, -1, hidden_states.shape[1], -1 + ) # [batch_size, num_heads, query_len, key_len] + if attention_mask is not None and not isinstance(attention_mask, BlockMask): + if attention_mask.dtype == torch.bool: + dtype = hidden_states.dtype + attention_mask = torch.where( + attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype + ) + attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype) + if attn_mask.shape[-1] > keep_window_size: + active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) + topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices + active_mask = active_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) + return attn_mask + + +class DogeMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DogeCDMoE(nn.Module): + def __init__(self, config: DogeConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + self.num_experts = config.num_experts + self.num_keys = math.floor(math.sqrt(self.num_experts)) + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # shared expert + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + 
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + + # router gate for retrieval experts + self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False) + + # routed experts + self.down_embed = nn.Embedding(self.num_experts, self.hidden_size) + self.up_embed = nn.Embedding(self.num_experts, self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + + # get routing logits with router gate + router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1) + + # get experts with the highest routing logits + (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1) + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) + scores, position_indices = all_scores.topk(self.top_k, dim=-1) + indices = all_indices.gather(-1, position_indices) + routing_weights = F.softmax(scores, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # mix routed experts states with shared expert states + down_embed = self.down_embed(indices) + up_embed = self.up_embed(indices) + experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1) + experts_weights = self.act_fn(experts_weights) * routing_weights + experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1) + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = hidden_states + experts_states + return hidden_states, router_logits + + +class DogeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_dropout = config.hidden_dropout + + self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = DogeAttention(config=config, layer_idx=layer_idx) + self.input_residual = nn.Parameter(torch.ones(config.hidden_size)) + + self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config) + self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + # sequence transformation + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + 
use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) + hidden_states = self.input_residual * residual + hidden_states + + # state transformation + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) + hidden_states = self.post_attention_residual * residual + hidden_states + + return hidden_states + + +@auto_docstring +class DogePreTrainedModel(PreTrainedModel): + config: DogeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DogeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(DogeCDMoE, index=1), + "hidden_states": DogeDecoderLayer, + "attentions": DogeAttention, + } + + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + if isinstance(module, DogeAttention): + if hasattr(module, "A"): + module.A.data.zero_() + elif isinstance(module, DogeDecoderLayer): + if hasattr(module, "input_residual"): + module.input_residual.data.fill_(1.0) + if hasattr(module, "post_attention_residual"): + module.post_attention_residual.data.fill_(1.0) + + +@auto_docstring +class DogeModel(DogePreTrainedModel): + def __init__(self, config: DogeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DogeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + 
cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + num_keys: Optional[int] = None, + top_k: int = 2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [2, batch_size * sequence_length, num_keys]. + num_experts: + Number of experts + num_keys: + Number of keys + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + compute_dtype = gate_logits[0].dtype + compute_device = gate_logits[0].device + all_expert_indices = [] + all_routing_weights = [] + + for layer_gate_logits in gate_logits: + layer_gate_logits = layer_gate_logits.to(compute_device) + + (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1) + + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) + + _, position_indices = all_scores.topk(top_k, dim=-1) + expert_indices = all_indices.gather(-1, position_indices) + + routing_weights = F.softmax(all_scores, dim=-1) + + all_expert_indices.append(expert_indices) + all_routing_weights.append(routing_weights) + all_expert_indices = torch.cat(all_expert_indices, dim=0) + all_routing_weights = torch.cat(all_routing_weights, dim=0) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + all_expert_indices = all_expert_indices.view(-1) + tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device) + pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device) + tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0] + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(all_routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = len(gate_logits) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k)) + .reshape(-1) + .to(compute_device) + ) + all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()] + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device) + pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device) + tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum( + expert_attention_mask + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + +@auto_docstring +class DogeForCausalLM(DogePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DogeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + 
self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + output_router_logits: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, DogeForCausalLM + + >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M") + >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + math.floor(math.sqrt(self.num_experts)), + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel): + pass + + +__all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/doge/modular_doge.py b/venv/lib/python3.13/site-packages/transformers/models/doge/modular_doge.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c95e627376982b3042465624a618bdf1278b87 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/doge/modular_doge.py @@ -0,0 +1,800 @@ +# coding=utf-8 +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. +# +# The Doge family of small language models is trained by SmallDoge Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
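Before the modular file restates the configuration, a toy walk-through of the product-key routing used by `DogeCDMoE` further up may help (a sketch, not part of the upstream diff): each token gets two score vectors of length `num_keys`, every `(x, y)` pair addresses one routed expert (so `num_keys ** 2` candidates, e.g. `128 ** 2 = 16384` for the default `num_experts`), and only the `top_k` best-scoring experts are kept.

```python
# Toy illustration with num_keys=4 (i.e. 16 routed experts) and top_k=2; scores are made up.
import torch

num_keys, top_k = 4, 2
scores_x = torch.tensor([0.9, 0.1, 0.5, 0.3])  # router scores along the first key axis
scores_y = torch.tensor([0.2, 0.8, 0.4, 0.6])  # router scores along the second key axis

# Every (x, y) pair addresses one expert: score[i, j] = scores_x[i] + scores_y[j],
# and the flat expert id is i * num_keys + j.
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)                       # shape (4, 4)
all_indices = torch.arange(num_keys).unsqueeze(-1) * num_keys + torch.arange(num_keys)

top_scores, pos = all_scores.view(-1).topk(top_k)
experts = all_indices.view(-1)[pos]
print(experts.tolist(), top_scores.tolist())  # experts [1, 3] with scores ~[1.7, 1.5]
```

In the module itself the per-axis scores and indices come from `router_gate(hidden_states).topk(num_keys)`, and the rows of `down_embed` / `up_embed` selected by the winning indices act as the routed experts' weights on top of the shared MLP path.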
+"""PyTorch Doge model.""" + +import math +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...integrations.flex_attention import compile_friendly_flex_attention +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import AttentionInterface, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, is_torch_flex_attn_available +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder +from ..llama.modeling_llama import ( + LlamaForSequenceClassification, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, + repeat_kv, +) +from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + +class DogeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge + model according to the specified arguments, defining the model architecture like [SmallDoge/Doge-320M](https://huggingface.co/SmallDoge/Doge-320M). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the Doge2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`] + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + hidden_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for each sequence transformation and state transformation module. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. 
+ NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. + Doge family of small models use `{ 'rope_type': 'dynamic', 'factor': 4.0, 'original_max_position_embeddings': 2048 }` as the default value. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. + In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. + The original max position embeddings used during pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. + If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`). + Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (<`original_max_position_embeddings`). + Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. + If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. + When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. + For more details checkout [this paper](https://huggingface.co/papers/2305.13245). + If it is not specified, will default to `num_attention_heads`. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None`. + keep_window_size (`int`, *optional*, defaults to 2048): + The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. + is_moe (`bool`, *optional*, defaults to `False`): + Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize. + num_experts (`int`, *optional*, defaults to 16384): + Number of routed experts in the model. This is only used when `is_moe=True`. + num_experts_per_tok (`int`, *optional*, defaults to 64): + Number of selected experts to route per-token. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import DogeConfig, DogeModel + + >>> # Initializing a Doge-320M style configuration + >>> configuration = DogeConfig() + + >>> # Initializing a model from the Doge-320M style configuration + >>> model = DogeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "doge" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `DogeModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.dt_proj": "rowwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.input_layernorm.weight": "sequence_parallel", + "layers.*.input_residual.weight": "sequence_parallel", + "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "layers.*.post_attention_residual.weight": "sequence_parallel", + "norm.weight": "sequence_parallel", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + "layers.*.mlp.router_gate": "colwise_rep", + "layers.*.mlp.down_embed": "rowwise_rep", + "layers.*.mlp.up_embed": "rowwise_rep", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32768, + hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=32, + hidden_dropout=0.0, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + tie_word_embeddings=False, + max_position_embeddings=2048, + rope_theta=10000.0, + rope_scaling=None, + num_attention_heads=8, + num_key_value_heads=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=None, + keep_window_size=2048, + is_moe=False, + num_experts=16384, + num_experts_per_tok=64, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = 
intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_dropout = hidden_dropout + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.keep_window_size = keep_window_size + self.is_moe = is_moe + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # for backward compatibility + if num_key_value_heads is None: + self.num_key_value_heads = num_attention_heads + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DogeRMSNorm(LlamaRMSNorm): + pass + + +class DogeRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +def flex_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + block_mask = None + causal_mask = None + if isinstance(attention_mask, BlockMask): + block_mask = attention_mask + else: + causal_mask = attention_mask + + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) + if causal_mask is not None: + score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx] + if head_mask is not None: + score = score + head_mask[batch_idx][head_idx][0][0] + return score + + attn_output, attention_weights = compile_friendly_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask, + enable_gqa=True, + scale=scaling, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. 
+ return_lse=True, + ) + # lse is returned in float32 + attention_weights = attention_weights.to(value.dtype) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attention_weights + + +ALL_ATTENTION_FUNCTIONS = AttentionInterface() +ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward + + +class DogeAttention(nn.Module): + def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.keep_window_size = config.keep_window_size + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + # dynamic mask for the QK^T attention weights matrix + self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.dt_proj = nn.Linear( + config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # calculate dynamic mask from value_states + dt_states = self.dt_proj( + value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) + ) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + attn_mask = self.prepare_dynamic_mask( + hidden_states=hidden_states, + dt_states=dt_states, + keep_window_size=self.keep_window_size, + attention_mask=attention_mask, + ) + attn_mask = repeat_kv(attn_mask, self.num_key_value_groups) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + 
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attn_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + def prepare_dynamic_mask( + self, + hidden_states: torch.Tensor, + dt_states: torch.Tensor, + keep_window_size: int = 2048, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention. + + Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. + + Args: + hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision. + dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`. + keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. + attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`. + """ + min_dtype = torch.finfo(hidden_states.dtype).min + dtype = hidden_states.dtype + attn_mask = dt_states[:, :, None, :].expand( + -1, -1, hidden_states.shape[1], -1 + ) # [batch_size, num_heads, query_len, key_len] + if attention_mask is not None and not isinstance(attention_mask, BlockMask): + if attention_mask.dtype == torch.bool: + dtype = hidden_states.dtype + attention_mask = torch.where( + attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype + ) + attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype) + if attn_mask.shape[-1] > keep_window_size: + active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) + topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices + active_mask = active_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) + return attn_mask + + +class DogeMLP(LlamaMLP): + pass + + +class DogeCDMoE(nn.Module): + def __init__(self, config: DogeConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + self.num_experts = config.num_experts + self.num_keys = math.floor(math.sqrt(self.num_experts)) + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # shared expert + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + + # router gate for retrieval experts + self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False) + + # routed experts + self.down_embed = nn.Embedding(self.num_experts, self.hidden_size) + self.up_embed = nn.Embedding(self.num_experts, self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + + # get 
routing logits with router gate + router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1) + + # get experts with the highest routing logits + (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1) + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) + scores, position_indices = all_scores.topk(self.top_k, dim=-1) + indices = all_indices.gather(-1, position_indices) + routing_weights = F.softmax(scores, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # mix routed experts states with shared expert states + down_embed = self.down_embed(indices) + up_embed = self.up_embed(indices) + experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1) + experts_weights = self.act_fn(experts_weights) * routing_weights + experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1) + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = hidden_states + experts_states + return hidden_states, router_logits + + +class DogeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_dropout = config.hidden_dropout + + self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = DogeAttention(config=config, layer_idx=layer_idx) + self.input_residual = nn.Parameter(torch.ones(config.hidden_size)) + + self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config) + self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + # sequence transformation + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) + hidden_states = self.input_residual * residual + hidden_states + + # state transformation + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) + hidden_states = self.post_attention_residual * residual + hidden_states + + return hidden_states + + +class 
DogePreTrainedModel(LlamaPreTrainedModel): + _supports_flash_attn = False + _can_compile_fullgraph = False + _can_record_outputs = { + "router_logits": OutputRecorder(DogeCDMoE, index=1), + "hidden_states": DogeDecoderLayer, + "attentions": DogeAttention, + } + + def _init_weights(self, module): + """Initialize the weights""" + PreTrainedModel._init_weights(self, module) + if isinstance(module, DogeAttention): + if hasattr(module, "A"): + module.A.data.zero_() + elif isinstance(module, DogeDecoderLayer): + if hasattr(module, "input_residual"): + module.input_residual.data.fill_(1.0) + if hasattr(module, "post_attention_residual"): + module.post_attention_residual.data.fill_(1.0) + + +class DogeModel(MixtralModel): + pass + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + num_keys: Optional[int] = None, + top_k: int = 2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [2, batch_size * sequence_length, num_keys]. + num_experts: + Number of experts + num_keys: + Number of keys + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
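As a rough illustration of the input layout this function expects (one `[2, batch*seq, num_keys]` tensor per decoder layer, as produced by `router_gate`), a toy call could look like the following; all sizes are arbitrary and chosen only for the sketch:

```python
import torch

# two layers, 8 tokens each, num_keys = 4 -> num_experts = 4 * 4 = 16
gate_logits = tuple(torch.randn(2, 8, 4) for _ in range(2))

aux = load_balancing_loss_func(gate_logits, num_experts=16, num_keys=4, top_k=2)
print(aux)  # scalar tensor, roughly 1.0 when routing is perfectly balanced
```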
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + compute_dtype = gate_logits[0].dtype + compute_device = gate_logits[0].device + all_expert_indices = [] + all_routing_weights = [] + + for layer_gate_logits in gate_logits: + layer_gate_logits = layer_gate_logits.to(compute_device) + + (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1) + + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) + + _, position_indices = all_scores.topk(top_k, dim=-1) + expert_indices = all_indices.gather(-1, position_indices) + + routing_weights = F.softmax(all_scores, dim=-1) + + all_expert_indices.append(expert_indices) + all_routing_weights.append(routing_weights) + all_expert_indices = torch.cat(all_expert_indices, dim=0) + all_routing_weights = torch.cat(all_routing_weights, dim=0) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + all_expert_indices = all_expert_indices.view(-1) + tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device) + pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device) + tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0] + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(all_routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = len(gate_logits) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k)) + .reshape(-1) + .to(compute_device) + ) + all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()] + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device) + pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device) + tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum( + expert_attention_mask + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + +class DogeForCausalLM(MixtralForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = DogeModel(config) + self.num_experts = config.num_experts + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + 
use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + output_router_logits: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, DogeForCausalLM + + >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M") + >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + math.floor(math.sqrt(self.num_experts)), + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class DogeForSequenceClassification(LlamaForSequenceClassification): + pass + + +__all__ = [ + "DogeConfig", + "DogeForCausalLM", + "DogeModel", + "DogePreTrainedModel", + "DogeForSequenceClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/donut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..834c451f78fa0d4c5fe91f59719b6505c4c4e4e5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_donut_swin import * + from .feature_extraction_donut import * + from .image_processing_donut import * + from .image_processing_donut_fast import * + from .modeling_donut_swin import * + from .processing_donut import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/configuration_donut_swin.py b/venv/lib/python3.13/site-packages/transformers/models/donut/configuration_donut_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..9aac07dace7688273be0bdc57da0a12663c2fb5b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/configuration_donut_swin.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Donut Swin Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DonutSwinConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DonutSwinModel`]. It is used to instantiate a + Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Donut + [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. 
+ num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + + Example: + + ```python + >>> from transformers import DonutSwinConfig, DonutSwinModel + + >>> # Initializing a Donut naver-clova-ix/donut-base style configuration + >>> configuration = DonutSwinConfig() + + >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration + >>> model = DonutSwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + + +__all__ = ["DonutSwinConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/feature_extraction_donut.py b/venv/lib/python3.13/site-packages/transformers/models/donut/feature_extraction_donut.py new file mode 
100644 index 0000000000000000000000000000000000000000..e37a58ddd3055e040c6c29cbd5f5cc3c34270cbe --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/feature_extraction_donut.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Donut.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_donut import DonutImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class DonutFeatureExtractor(DonutImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class DonutFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use DonutImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["DonutFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut.py b/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut.py new file mode 100644 index 0000000000000000000000000000000000000000..7dec96422c5d82e51fe9c4291930d9827bd54e11 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Donut.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import is_vision_available, requires + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +@requires(backends=("vision",)) +class DonutImageProcessor(BaseImageProcessor): + r""" + Constructs a Donut image processor. 
+ + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `False`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `random_padding` is set to `True` in `preprocess`, each image is padded with a + random amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Image standard deviation. 
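A minimal usage sketch (the 512x384 target size and the all-zero dummy image are arbitrary illustrative choices, not defaults):

```python
import numpy as np
from PIL import Image
from transformers import DonutImageProcessor

processor = DonutImageProcessor(size={"height": 512, "width": 384})
image = Image.fromarray(np.zeros((640, 480, 3), dtype=np.uint8))  # dummy RGB image

inputs = processor(image, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 512, 384): resized, thumbnailed and padded to `size`
```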
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_thumbnail: bool = True, + do_align_long_axis: bool = False, + do_pad: bool = True, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + size = size if size is not None else {"height": 2560, "width": 1920} + if isinstance(size, (tuple, list)): + # The previous feature extractor size parameter was in (width, height) format + size = size[::-1] + size = get_size_dict(size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_thumbnail = do_thumbnail + self.do_align_long_axis = do_align_long_axis + self.do_pad = do_pad + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def align_long_axis( + self, + image: np.ndarray, + size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`np.ndarray`): + The image to be aligned. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + `np.ndarray`: The aligned image. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(image) + + if input_data_format == ChannelDimension.LAST: + rot_axes = (0, 1) + elif input_data_format == ChannelDimension.FIRST: + rot_axes = (1, 2) + else: + raise ValueError(f"Unsupported data format: {input_data_format}") + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3, axes=rot_axes) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def pad_image( + self, + image: np.ndarray, + size: dict[str, int], + random_padding: bool = False, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad the image to the specified size. + + Args: + image (`np.ndarray`): + The image to be padded. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. 
+ data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = np.random.randint(low=0, high=delta_height + 1) + pad_left = np.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format, input_data_format=input_data_format) + + def thumbnail( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`np.ndarray`): + The image to be resized. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return resize( + image, + size=(height, width), + resample=resample, + reducing_gap=2.0, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resizes `image` to `(height, width)` specified by `size` using the PIL library. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size) + shortest_edge = min(size["height"], size["width"]) + output_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format + ) + resized_image = resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return resized_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_thumbnail: Optional[bool] = None, + do_align_long_axis: Optional[bool] = None, + do_pad: Optional[bool] = None, + random_padding: bool = False, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to min(size["height"], + size["width"]) with the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + random_padding (`bool`, *optional*, defaults to `self.random_padding`): + Whether to use random padding when padding the image. If `True`, each image in the batch with be padded + with a random amount of padding on each side up to the size of the largest image in the batch. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. 
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + if isinstance(size, (tuple, list)): + # Previous feature extractor had size in (width, height) format + size = size[::-1] + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail + do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis + do_pad = do_pad if do_pad is not None else self.do_pad + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_align_long_axis: + images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_thumbnail: + images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images] + + if do_pad: + images = [ + self.pad_image( + image=image, size=size, random_padding=random_padding, input_data_format=input_data_format + ) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["DonutImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut_fast.py b/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..29e06831b1b48d1bccc6945a49c675f9bcd01e9a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/image_processing_donut_fast.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Donut.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, +) + + +logger = logging.get_logger(__name__) + + +class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + Args: + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. 
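These two switches can also be passed per call to the fast processor defined below. A small sketch with an arbitrary target size and a dummy tensor image (assumes `torch` and `torchvision` are installed; shapes shown are what the settings above would produce):

```python
import torch
from transformers import DonutImageProcessorFast

processor = DonutImageProcessorFast(size={"height": 512, "width": 384})
image = torch.zeros(3, 480, 640, dtype=torch.uint8)  # dummy channels-first image

# the long axis is rotated to match `size`, then resized, thumbnailed and padded
inputs = processor(image, do_align_long_axis=True, do_thumbnail=True, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 512, 384])
```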
+ """ + + do_thumbnail: Optional[bool] + do_align_long_axis: Optional[bool] + + +@auto_docstring +class DonutImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 2560, "width": 1920} + do_resize = True + do_rescale = True + do_normalize = True + do_thumbnail = True + do_align_long_axis = False + do_pad = True + valid_kwargs = DonutFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[DonutFastImageProcessorKwargs]): + size = kwargs.pop("size", None) + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[DonutFastImageProcessorKwargs]) -> BatchFeature: + if "size" in kwargs: + size = kwargs.pop("size") + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + return super().preprocess(images, **kwargs) + + def align_long_axis( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`torch.Tensor`): + The image to be aligned. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + + Returns: + `torch.Tensor`: The aligned image. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + height_dim, width_dim = image.dim() - 2, image.dim() - 1 + image = torch.rot90(image, 3, dims=[height_dim, width_dim]) + + return image + + def pad_image( + self, + image: "torch.Tensor", + size: SizeDict, + random_padding: bool = False, + ) -> "torch.Tensor": + """ + Pad the image to the specified size. + + Args: + image (`torch.Tensor`): + The image to be padded. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = size.height, size.width + input_height, input_width = image.shape[-2:] + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = torch.random.randint(low=0, high=delta_height + 1) + pad_left = torch.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def pad(self, *args, **kwargs): + logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") + return self.pad_image(*args, **kwargs) + + def thumbnail( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. 
+ + Args: + image (`torch.Tensor`): + The image to be resized. + size (`dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return self.resize( + image, + size=SizeDict(width=width, height=height), + interpolation=F.InterpolationMode.BICUBIC, + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + do_thumbnail: bool, + do_align_long_axis: bool, + do_pad: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_align_long_axis: + stacked_images = self.align_long_axis(image=stacked_images, size=size) + if do_resize: + shortest_edge = min(size.height, size.width) + stacked_images = self.resize( + image=stacked_images, size=SizeDict(shortest_edge=shortest_edge), interpolation=interpolation + ) + if do_thumbnail: + stacked_images = self.thumbnail(image=stacked_images, size=size) + if do_pad: + stacked_images = self.pad_image(image=stacked_images, size=size, random_padding=False) + + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = 
["DonutImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/modeling_donut_swin.py b/venv/lib/python3.13/site-packages/transformers/models/donut/modeling_donut_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..c5736b16183b7b9ca1363617ef1d4e149855df64 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/modeling_donut_swin.py @@ -0,0 +1,1031 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Donut Swin Transformer model. + +This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden +states.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging, torch_int +from .configuration_donut_swin import DonutSwinConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + DonutSwin encoder's outputs, with potential hidden states and attentions. + """ +) +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin +class DonutSwinEncoderOutput(ModelOutput): + r""" + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + DonutSwin model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin +class DonutSwinModelOutput(ModelOutput): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. 
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + DonutSwin outputs for image classification. + """ +) +# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin +class DonutSwinImageClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin +class DonutSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. 
+ """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin +class DonutSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of 
shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class DonutSwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
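+
+    Each non-overlapping 2x2 group of neighboring patches is concatenated along the channel dimension (giving
+    `4 * dim` channels) and projected back to `2 * dim` with a linear layer, halving the spatial resolution.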
+ """ + + def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be divisible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
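+
+    Concretely, each sample in the batch is dropped with probability `drop_prob` during training, and surviving
+    samples are rescaled by `1 / (1 - drop_prob)` so that the expected value of the output matches the input.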
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class DonutSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin +class DonutSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + hidden_shape = (batch_size, dim, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
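+        # Note: at this point `batch_size` is (num_windows * real batch size) and `dim` is
+        # window_size * window_size, because attention is computed independently within each window;
+        # the learned relative position bias added below is therefore shared across all windows.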
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput +class DonutSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin +class DonutSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size) + self.output = DonutSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + 
head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate +class DonutSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput +class DonutSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin +class DonutSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DonutSwinIntermediate(config, dim) + self.output = DonutSwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return 
attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin +class DonutSwinStage(GradientCheckpointingLayer): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + DonutSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, 
norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin +class DonutSwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.layers = nn.ModuleList( + [ + DonutSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, DonutSwinEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if 
output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DonutSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +@auto_docstring +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut +class DonutSwinPreTrainedModel(PreTrainedModel): + config: DonutSwinConfig + base_model_prefix = "donut" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["DonutSwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, DonutSwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, DonutSwinSelfAttention): + module.relative_position_bias_table.data.zero_() + + +@auto_docstring +class DonutSwinModel(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + """ + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, DonutSwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return DonutSwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
+ + + """ +) +# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut +class DonutSwinForImageClassification(DonutSwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.donut = DonutSwinModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, DonutSwinImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.donut( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DonutSwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/donut/processing_donut.py b/venv/lib/python3.13/site-packages/transformers/models/donut/processing_donut.py new file mode 100644 index 0000000000000000000000000000000000000000..c75e2fcaa54203bdb4f3d64b977cc160c11c11fc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/donut/processing_donut.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Donut. 
+""" + +import re +import warnings +from contextlib import contextmanager +from typing import Optional, Union + +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +class DonutProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +logger = logging.get_logger(__name__) + + +class DonutProcessor(ProcessorMixin): + r""" + Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single + processor. + + [`DonutProcessor`] offers all the functionalities of [`DonutImageProcessor`] and + [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. See the [`~DonutProcessor.__call__`] and + [`~DonutProcessor.decode`] for more information. + + Args: + image_processor ([`DonutImageProcessor`], *optional*): + An instance of [`DonutImageProcessor`]. The image processor is a required input. + tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`], *optional*): + An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self._in_target_context_manager = False + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[DonutProcessorKwargs], + ): + """ + When used in normal mode, this method forwards all its arguments to AutoImageProcessor's + [`~AutoImageProcessor.__call__`] and returns its output. If used in the context + [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's + [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. 
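+
+        Example (an illustrative sketch: `naver-clova-ix/donut-base` is a public Donut checkpoint and
+        `document.png` stands in for a local document image):
+
+        ```python
+        >>> from PIL import Image
+        >>> from transformers import DonutProcessor
+
+        >>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
+        >>> image = Image.open("document.png").convert("RGB")
+        >>> inputs = processor(images=image, return_tensors="pt")
+        >>> list(inputs.keys())
+        ['pixel_values']
+        ```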
+        """
+        if self._in_target_context_manager:
+            return self.current_processor(images, text, **kwargs)
+
+        if images is None and text is None:
+            raise ValueError("You need to specify either an `images` or `text` input to process.")
+
+        output_kwargs = self._merge_kwargs(
+            DonutProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if images is not None:
+            inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+        if text is not None:
+            if images is not None:
+                output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+        if text is None:
+            return inputs
+        elif images is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]  # for BC
+            inputs["input_ids"] = encodings["input_ids"]
+            return inputs
+
+    @property
+    def model_input_names(self):
+        image_processor_input_names = self.image_processor.model_input_names
+
+        return list(image_processor_input_names + ["input_ids", "labels"])
+
+    @contextmanager
+    def as_target_processor(self):
+        """
+        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
+        """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your images inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
+        self.current_processor = self.tokenizer
+        yield
+        self.current_processor = self.image_processor
+        self._in_target_context_manager = False
+
+    def token2json(self, tokens, is_inner_value=False, added_vocab=None):
+        """
+        Convert a (generated) token sequence into an ordered JSON format.
+        """
+        if added_vocab is None:
+            added_vocab = self.tokenizer.get_added_vocab()
+
+        output = {}
+
+        while tokens:
+            # We want r"<s_(.*?)>" but without ReDOS risk, so do it manually in two parts
+            potential_start = re.search(r"<s_", tokens, re.IGNORECASE)
+            if potential_start is None:
+                break
+            start_token = tokens[potential_start.start() :]
+            if ">" not in start_token:
+                break
+            start_token = start_token[: start_token.index(">") + 1]
+            key = start_token[len("<s_") : -len(">")]
+            key_escaped = re.escape(key)
+
+            end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
+            if end_token is None:
+                tokens = tokens.replace(start_token, "")
+            else:
+                end_token = end_token.group()
+                start_token_escaped = re.escape(start_token)
+                end_token_escaped = re.escape(end_token)
+                content = re.search(
+                    f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
+                )
+                if content is not None:
+                    content = content.group(1).strip()
+                    if r"<s_" in content and r"</s_" in content:  # non-leaf node
+                        value = self.token2json(content, is_inner_value=True, added_vocab=added_vocab)
+                        if value:
+                            if len(value) == 1:
+                                value = value[0]
+                            output[key] = value
+                    else:  # leaf nodes
+                        output[key] = []
+                        for leaf in content.split(r"<sep/>"):
+                            leaf = leaf.strip()
+                            if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
+                                leaf = leaf[1:-2]  # for categorical special tokens
+                            output[key].append(leaf)
+                        if len(output[key]) == 1:
+                            output[key] = output[key][0]
+
+            tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
+            if tokens[:6] == r"<sep/>":  # non-leaf nodes
+                return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
+
+        if output:
+            return [output] if is_inner_value else output
+        else:
+            return [] if is_inner_value else {"text_sequence": tokens}
+
+    @property
+    def feature_extractor_class(self):
+        warnings.warn(
+            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+            FutureWarning,
+        )
+        return self.image_processor_class
+
+    @property
+    def feature_extractor(self):
+        warnings.warn(
+            "`feature_extractor` is deprecated and will be removed in v5. 
Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["DonutProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dots1/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dots1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60223e4df87f1d925c94f4737c215e904a47ac36 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dots1/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dots1 import * + from .modeling_dots1 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/dots1/configuration_dots1.py b/venv/lib/python3.13/site-packages/transformers/models/dots1/configuration_dots1.py new file mode 100644 index 0000000000000000000000000000000000000000..c8596ddc3828b70d28df3d115bb46558d49ab210 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dots1/configuration_dots1.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Dots1Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dots1Model`]. It is used to instantiate a + `dots.llm1` model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + [rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`Dots1Model`]. 
+ hidden_size (`int`, *optional*, defaults to 4608): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 10944): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 62): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + Number of key/value heads for Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, Multi + Head Attention (MHA) is used. If `num_key_value_heads=1`, Multi Query Attention (MQA) is used. Otherwise, + Grouped Query Attention (GQA) is used. If not specified, defaults to `num_attention_heads`. + n_shared_experts (`int`, *optional*, default=None): + Number of shared experts. None means dense model. + n_routed_experts (`int`, *optional*, default=None): + Number of routed experts. None means dense model. + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 1): + Number of selected groups for each token (selected experts only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, default=None): + Number of selected experts. None means dense model. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers at the beginning of the model before the first MoE layer. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the weights of the routed experts. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string). + max_position_embeddings (`int`, *optional*, defaults to 2048): + Maximum sequence length the model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + Epsilon used by the RMS normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. Only relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the input and output word embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary for scaling RoPE embeddings. Supports `{"type": strategy name, "factor": scaling factor}`. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the self-attention projections. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout ratio for the attention probabilities. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts. + sliding_window (`int`, *optional*, defaults to 4096): + Size of the sliding window for attention. If not specified, defaults to `4096`. + max_window_layers (`int`, *optional*, defaults to 62): + The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any + additional layer afterwards will use SWA (Sliding Window Attention). 
+ layer_types (`list`, *optional*): + Attention pattern for each layer. + + Examples: + ```python + >>> from transformers import Dots1Model, Dots1Config + + >>> # Initializing a Dots1 style configuration + >>> configuration = Dots1Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dots1" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "local_colwise", + "layers.*.mlp.experts.*.up_proj": "local_colwise", + "layers.*.mlp.experts.*.down_proj": "local_rowwise", + "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list + "layers.*.mlp.shared_experts.gate_proj": "local_colwise", + "layers.*.mlp.shared_experts.up_proj": "local_colwise", + "layers.*.mlp.shared_experts.down_proj": "local_rowwise", + "layers.*.mlp.shared_experts": "local", + "layers.*.mlp.gate_proj": "local_colwise", + "layers.*.mlp.up_proj": "local_colwise", + "layers.*.mlp.down_proj": "local_rowwise", + "layers.*.mlp": "gather", # This is the only moment where results are gathered + } + + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=4608, + intermediate_size=10944, + moe_intermediate_size=1408, + num_hidden_layers=62, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + n_group=1, + topk_group=1, + num_experts_per_tok=None, + first_k_dense_replace=0, + norm_topk_prob=False, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + routed_scaling_factor=1.0, + sliding_window=4096, + max_window_layers=62, + layer_types=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.n_group = n_group + self.topk_group = topk_group + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.routed_scaling_factor = routed_scaling_factor + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window 
is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Dots1Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dots1/modeling_dots1.py b/venv/lib/python3.13/site-packages/transformers/models/dots1/modeling_dots1.py new file mode 100644 index 0000000000000000000000000000000000000000..7a741543800e20245a8eca8d295a2c1d62ff2ca9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dots1/modeling_dots1.py @@ -0,0 +1,612 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dots1.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
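An aside on the `Dots1Config` definition that closes above: when `layer_types` is not passed, it is derived from `sliding_window` and `max_window_layers`. The sketch below is illustrative only, assuming a transformers build that ships the dots1 model; the layer counts and MoE settings are toy values, not the published checkpoint's defaults.

```python
from transformers import Dots1Config

# Toy values for illustration only, not the real dots.llm1 settings
config = Dots1Config(
    num_hidden_layers=4,
    sliding_window=4096,
    max_window_layers=2,      # the first 2 layers keep full attention
    first_k_dense_replace=1,  # layer 0 stays a dense MLP, later layers become MoE
    n_routed_experts=8,
    n_shared_experts=1,
    num_experts_per_tok=2,
)
print(config.layer_types)
# ['full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']
```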
+from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_dots1 import Dots1Config + + +@use_kernel_forward_from_hub("RMSNorm") +class Dots1RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Dots1RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Dots1RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Dots1Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Dots1Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Dots1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Dots1MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Dots1MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = nn.ModuleList( + [Dots1MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)] + ) + self.gate = Dots1TopkRouter(config) + self.shared_experts = Dots1MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! 
I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class Dots1TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class Dots1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Dots1Config, 
layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Dots1MoE(config) + else: + self.mlp = Dots1MLP(config) + + self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Dots1PreTrainedModel(PreTrainedModel): + config: Dots1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Dots1DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Dots1DecoderLayer, + "attentions": Dots1Attention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Dots1TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class Dots1Model(Dots1PreTrainedModel): + def __init__(self, config: Dots1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Dots1RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Dots1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Dots1ForCausalLM + + >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst") + >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/dots1/modular_dots1.py b/venv/lib/python3.13/site-packages/transformers/models/dots1/modular_dots1.py new file mode 100644 index 0000000000000000000000000000000000000000..345265a14080c7e8978c37b34124e208536b1da2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/dots1/modular_dots1.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
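A note on `Dots1TopkRouter.get_topk_indices` defined earlier in `modeling_dots1.py`: experts are first ranked group-wise, only the best `topk_group` groups survive, and the final `top_k` experts are chosen inside those groups. The sketch below re-implements those steps on a tiny tensor so the masking is easier to follow; the sizes are made up and the `e_score_correction_bias` term is deliberately omitted.

```python
import torch

# 1 token, 8 experts split into n_group=4 groups of 2; keep topk_group=2 groups, top_k=2 experts.
scores = torch.tensor([[0.9, 0.1, 0.2, 0.75, 0.3, 0.4, 0.7, 0.6]])  # sigmoid of router logits
n_experts, n_group, topk_group, top_k = 8, 4, 2, 2

# Rank groups by the sum of their two best expert scores
group_scores = scores.view(-1, n_group, n_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]

# Zero out every expert outside the winning groups, then take the per-token top_k
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, n_experts // n_group).reshape(-1, n_experts)
masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1]
print(topk_indices)  # experts 0 and 6: the strongest experts inside the two kept groups (0 and 3)
```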
+from ...modeling_outputs import CausalLMOutputWithPast +from ...processing_utils import Unpack +from ...utils import logging +from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3DecoderLayer, + DeepseekV3MLP, + DeepseekV3MoE, + DeepseekV3PreTrainedModel, + DeepseekV3TopkRouter, +) +from ..qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3ForCausalLM, + Qwen3Model, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + TransformersKwargs, +) +from .configuration_dots1 import Dots1Config + + +logger = logging.get_logger(__name__) + + +class Dots1RMSNorm(Qwen3RMSNorm): + pass + + +class Dots1RotaryEmbedding(Qwen3RotaryEmbedding): + pass + + +class Dots1Attention(Qwen3Attention): + pass + + +class Dots1MLP(DeepseekV3MLP): + pass + + +class Dots1MoE(DeepseekV3MoE): + pass + + +class Dots1TopkRouter(DeepseekV3TopkRouter): + pass + + +class Dots1DecoderLayer(DeepseekV3DecoderLayer): + def __init__(self, config: Dots1Config, layer_idx: int): + super().__init__(config, layer_idx) + self.attention_type = config.layer_types[layer_idx] + + +class Dots1PreTrainedModel(DeepseekV3PreTrainedModel): + pass + + +class Dots1Model(Qwen3Model): + pass + + +class Dots1ForCausalLM(Qwen3ForCausalLM): + def forward( + self, + **super_kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Dots1ForCausalLM + + >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst") + >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + return super().forward(**super_kwargs) + + +__all__ = [ + "Dots1PreTrainedModel", + "Dots1Model", + "Dots1ForCausalLM", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/efficientnet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24d58e81167ec8729d04fa52ce96ebc1737a5982 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
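Still on `modeling_dots1.py` above: `repeat_kv` expands the grouped key/value heads so that each query head has a matching K/V head in eager attention. A standalone, hedged shape check follows; the helper is copied locally so the snippet runs on its own, and the tensor sizes are toy values rather than the model's real dimensions.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Copy of the helper above: (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n_rep, seq, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Toy sizes: 8 KV heads serving 32 query heads -> each KV head is repeated 4 times.
kv = torch.randn(1, 8, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=32 // 8)
print(expanded.shape)  # torch.Size([1, 32, 16, 64])
assert torch.equal(expanded[:, 0], expanded[:, 3])  # heads 0-3 are all copies of KV head 0
assert torch.equal(expanded[:, 4], kv[:, 1])        # heads 4-7 come from KV head 1
```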
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_efficientnet import * + from .image_processing_efficientnet import * + from .image_processing_efficientnet_fast import * + from .modeling_efficientnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/efficientnet/configuration_efficientnet.py b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/configuration_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..549311902903357604bcf09bd28ddb391d007793 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/configuration_efficientnet.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EfficientNet model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class EfficientNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EfficientNetModel`]. It is used to instantiate an + EfficientNet model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the EfficientNet + [google/efficientnet-b7](https://huggingface.co/google/efficientnet-b7) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 600): + The input image size. + width_coefficient (`float`, *optional*, defaults to 2.0): + Scaling coefficient for network width at each stage. + depth_coefficient (`float`, *optional*, defaults to 3.1): + Scaling coefficient for network depth at each stage. + depth_divisor `int`, *optional*, defaults to 8): + A unit of network width. + kernel_sizes (`list[int]`, *optional*, defaults to `[3, 3, 5, 3, 5, 5, 3]`): + List of kernel sizes to be used in each block. + in_channels (`list[int]`, *optional*, defaults to `[32, 16, 24, 40, 80, 112, 192]`): + List of input channel sizes to be used in each block for convolutional layers. + out_channels (`list[int]`, *optional*, defaults to `[16, 24, 40, 80, 112, 192, 320]`): + List of output channel sizes to be used in each block for convolutional layers. 
+ depthwise_padding (`list[int]`, *optional*, defaults to `[]`): + List of block indices with square padding. + strides (`list[int]`, *optional*, defaults to `[1, 2, 2, 2, 1, 2, 1]`): + List of stride sizes to be used in each block for convolutional layers. + num_block_repeats (`list[int]`, *optional*, defaults to `[1, 2, 2, 3, 3, 4, 1]`): + List of the number of times each block is to repeated. + expand_ratios (`list[int]`, *optional*, defaults to `[1, 6, 6, 6, 6, 6, 6]`): + List of scaling coefficient of each block. + squeeze_expansion_ratio (`float`, *optional*, defaults to 0.25): + Squeeze expansion ratio. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`, + `"selu", `"gelu_new"`, `"silu"` and `"mish"` are supported. + hidden_dim (`int`, *optional*, defaults to 1280): + The hidden dimension of the layer before the classification head. + pooling_type (`str` or `function`, *optional*, defaults to `"mean"`): + Type of final pooling to be applied before the dense classification head. Available options are [`"mean"`, + `"max"`] + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + batch_norm_eps (`float`, *optional*, defaults to 1e-3): + The epsilon used by the batch normalization layers. + batch_norm_momentum (`float`, *optional*, defaults to 0.99): + The momentum used by the batch normalization layers. + dropout_rate (`float`, *optional*, defaults to 0.5): + The dropout rate to be applied before final classifier layer. + drop_connect_rate (`float`, *optional*, defaults to 0.2): + The drop rate for skip connections. + + Example: + ```python + >>> from transformers import EfficientNetConfig, EfficientNetModel + + >>> # Initializing a EfficientNet efficientnet-b7 style configuration + >>> configuration = EfficientNetConfig() + + >>> # Initializing a model (with random weights) from the efficientnet-b7 style configuration + >>> model = EfficientNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "efficientnet" + + def __init__( + self, + num_channels: int = 3, + image_size: int = 600, + width_coefficient: float = 2.0, + depth_coefficient: float = 3.1, + depth_divisor: int = 8, + kernel_sizes: list[int] = [3, 3, 5, 3, 5, 5, 3], + in_channels: list[int] = [32, 16, 24, 40, 80, 112, 192], + out_channels: list[int] = [16, 24, 40, 80, 112, 192, 320], + depthwise_padding: list[int] = [], + strides: list[int] = [1, 2, 2, 2, 1, 2, 1], + num_block_repeats: list[int] = [1, 2, 2, 3, 3, 4, 1], + expand_ratios: list[int] = [1, 6, 6, 6, 6, 6, 6], + squeeze_expansion_ratio: float = 0.25, + hidden_act: str = "swish", + hidden_dim: int = 2560, + pooling_type: str = "mean", + initializer_range: float = 0.02, + batch_norm_eps: float = 0.001, + batch_norm_momentum: float = 0.99, + dropout_rate: float = 0.5, + drop_connect_rate: float = 0.2, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.image_size = image_size + self.width_coefficient = width_coefficient + self.depth_coefficient = depth_coefficient + self.depth_divisor = depth_divisor + self.kernel_sizes = kernel_sizes + self.in_channels = in_channels + self.out_channels = out_channels + self.depthwise_padding = depthwise_padding + self.strides = strides + self.num_block_repeats = num_block_repeats + self.expand_ratios 
= expand_ratios + self.squeeze_expansion_ratio = squeeze_expansion_ratio + self.hidden_act = hidden_act + self.hidden_dim = hidden_dim + self.pooling_type = pooling_type + self.initializer_range = initializer_range + self.batch_norm_eps = batch_norm_eps + self.batch_norm_momentum = batch_norm_momentum + self.dropout_rate = dropout_rate + self.drop_connect_rate = drop_connect_rate + self.num_hidden_layers = sum(num_block_repeats) * 4 + + +class EfficientNetOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + +__all__ = ["EfficientNetConfig", "EfficientNetOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet.py b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ea822d75ca27a8acff86a6049bbc1c79a1c3f8ca --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet.py @@ -0,0 +1,369 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for EfficientNet.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import rescale, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +class EfficientNetImageProcessor(BaseImageProcessor): + r""" + Constructs a EfficientNet image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in `preprocess`. + size (`dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`): + Size of the image after `resize`. Can be overridden by `size` in `preprocess`. + resample (`PILImageResampling` filter, *optional*, defaults to 0): + Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`. + do_center_crop (`bool`, *optional*, defaults to `False`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. 
Can be overridden by `do_center_crop` in `preprocess`. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 289, "width": 289}`): + Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + rescale_offset (`bool`, *optional*, defaults to `False`): + Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + include_top (`bool`, *optional*, defaults to `True`): + Whether to rescale the image again. Should be set to True if the inputs are used for image classification. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PIL.Image.NEAREST, + do_center_crop: bool = False, + crop_size: Optional[dict[str, int]] = None, + rescale_factor: Union[int, float] = 1 / 255, + rescale_offset: bool = False, + do_rescale: bool = True, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + include_top: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 346, "width": 346} + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else {"height": 289, "width": 289} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.rescale_offset = rescale_offset + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.include_top = include_top + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.NEAREST, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.NEAREST`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.NEAREST`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def rescale( + self, + image: np.ndarray, + scale: Union[int, float], + offset: bool = True, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`int` or `float`): + Scale to apply to the image. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + rescaled_image = rescale( + image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + if offset: + rescaled_image = rescaled_image - 1 + + return rescaled_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample=None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + rescale_offset: Optional[bool] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + include_top: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after `resize`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + PILImageResampling filter to use if resizing the image Only has an effect if `do_resize` is set to + `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be + padded with zeros and then cropped + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. 
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`): + Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + include_top (`bool`, *optional*, defaults to `self.include_top`): + Rescales the image again for image classification if set to True. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - `None`: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + rescale_offset = rescale_offset if rescale_offset is not None else self.rescale_offset + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + include_top = include_top if include_top is not None else self.include_top + + size = size if size is not None else self.size + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. 
+ images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale( + image=image, scale=rescale_factor, offset=rescale_offset, input_data_format=input_data_format + ) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + if include_top: + images = [ + self.normalize(image=image, mean=0, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["EfficientNetImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet_fast.py b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..77e787614a10f204e2ed1716595c38b7a7ba0a6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/image_processing_efficientnet_fast.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for EfficientNet.""" + +from functools import lru_cache +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) + + +class EfficientNetFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + Args: + rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`): + Whether to rescale the image between [-max_range/2, scale_range/2] instead of [0, scale_range]. 
+ include_top (`bool`, *optional*, defaults to `self.include_top`): + Normalize the image again with the standard deviation only for image classification if set to True. + """ + + rescale_offset: bool + include_top: bool + + +@auto_docstring +class EfficientNetImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.NEAREST + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 346, "width": 346} + crop_size = {"height": 289, "width": 289} + do_resize = True + do_center_crop = False + do_rescale = True + rescale_factor = 1 / 255 + rescale_offset = False + do_normalize = True + include_top = True + valid_kwargs = EfficientNetFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def rescale( + self, + image: "torch.Tensor", + scale: float, + offset: Optional[bool] = True, + **kwargs, + ) -> "torch.Tensor": + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`torch.Tensor`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + + Returns: + `torch.Tensor`: The rescaled image. + """ + + rescaled_image = image * scale + + if offset: + rescaled_image -= 1 + + return rescaled_image + + @lru_cache(maxsize=10) + def _fuse_mean_std_and_rescale_factor( + self, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + device: Optional["torch.device"] = None, + rescale_offset: Optional[bool] = False, + ) -> tuple: + if do_rescale and do_normalize and not rescale_offset: + # Fused rescale and normalize + image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) + do_rescale = False + return image_mean, image_std, do_rescale + + def rescale_and_normalize( + self, + images: "torch.Tensor", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, list[float]], + image_std: Union[float, list[float]], + rescale_offset: bool = False, + ) -> "torch.Tensor": + """ + Rescale and normalize images. 
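+
+        When `do_rescale` and `do_normalize` are both `True` and `rescale_offset` is
+        `False`, the two steps are fused into a single normalization using the identity
+
+            (image * scale - mean) / std == (image - mean / scale) / (std / scale)
+
+        i.e. `image_mean` and `image_std` are pre-multiplied by `1 / rescale_factor`
+        (see `_fuse_mean_std_and_rescale_factor` above) and the explicit rescale step
+        is skipped.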
+ """ + image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + device=images.device, + rescale_offset=rescale_offset, + ) + # if/elif as we use fused rescale and normalize if both are set to True + if do_rescale: + images = self.rescale(images, rescale_factor, rescale_offset) + if do_normalize: + images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) + + return images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + rescale_offset: bool, + do_normalize: bool, + include_top: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std, rescale_offset + ) + if include_top: + stacked_images = self.normalize(stacked_images, 0, image_std) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + +__all__ = ["EfficientNetImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/efficientnet/modeling_efficientnet.py b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/modeling_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a263ff20760c323739220c8d155eceaff7419409 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/efficientnet/modeling_efficientnet.py @@ -0,0 +1,561 @@ +# coding=utf-8 +# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch EfficientNet model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from .configuration_efficientnet import EfficientNetConfig + + +logger = logging.get_logger(__name__) + + +def round_filters(config: EfficientNetConfig, num_channels: int): + r""" + Round number of filters based on depth multiplier. + """ + divisor = config.depth_divisor + num_channels *= config.width_coefficient + new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor) + + # Make sure that round down does not go down by more than 10%. + if new_dim < 0.9 * num_channels: + new_dim += divisor + + return int(new_dim) + + +def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True): + r""" + Utility function to get the tuple padding value for the depthwise convolution. + + Args: + kernel_size (`int` or `tuple`): + Kernel size of the convolution layers. + adjust (`bool`, *optional*, defaults to `True`): + Adjusts padding value to apply to right and bottom sides of the input. + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + if adjust: + return (correct[1] - 1, correct[1], correct[0] - 1, correct[0]) + else: + return (correct[1], correct[1], correct[0], correct[0]) + + +class EfficientNetEmbeddings(nn.Module): + r""" + A module that corresponds to the stem module of the original work. + """ + + def __init__(self, config: EfficientNetConfig): + super().__init__() + + self.out_dim = round_filters(config, 32) + self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1)) + self.convolution = nn.Conv2d( + config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False + ) + self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + features = self.padding(pixel_values) + features = self.convolution(features) + features = self.batchnorm(features) + features = self.activation(features) + + return features + + +class EfficientNetDepthwiseConv2d(nn.Conv2d): + def __init__( + self, + in_channels, + depth_multiplier=1, + kernel_size=3, + stride=1, + padding=0, + dilation=1, + bias=True, + padding_mode="zeros", + ): + out_channels = in_channels * depth_multiplier + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode, + ) + + +class EfficientNetExpansionLayer(nn.Module): + r""" + This corresponds to the expansion phase of each block in the original implementation. 
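+
+    A 1x1 convolution widens the channel dimension from `in_dim` to `out_dim` (which
+    `EfficientNetBlock` sets to `in_dim * expand_ratio`), followed by batch
+    normalization and the configured activation. For example, `in_dim=24` with
+    `expand_ratio=6` expands to 144 channels (illustrative values, not taken from a
+    specific checkpoint).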
+ """ + + def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int): + super().__init__() + self.expand_conv = nn.Conv2d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + padding="same", + bias=False, + ) + self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps) + self.expand_act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + # Expand phase + hidden_states = self.expand_conv(hidden_states) + hidden_states = self.expand_bn(hidden_states) + hidden_states = self.expand_act(hidden_states) + + return hidden_states + + +class EfficientNetDepthwiseLayer(nn.Module): + r""" + This corresponds to the depthwise convolution phase of each block in the original implementation. + """ + + def __init__( + self, + config: EfficientNetConfig, + in_dim: int, + stride: int, + kernel_size: int, + adjust_padding: bool, + ): + super().__init__() + self.stride = stride + conv_pad = "valid" if self.stride == 2 else "same" + padding = correct_pad(kernel_size, adjust=adjust_padding) + + self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding) + self.depthwise_conv = EfficientNetDepthwiseConv2d( + in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False + ) + self.depthwise_norm = nn.BatchNorm2d( + num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.depthwise_act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + # Depthwise convolution + if self.stride == 2: + hidden_states = self.depthwise_conv_pad(hidden_states) + + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.depthwise_norm(hidden_states) + hidden_states = self.depthwise_act(hidden_states) + + return hidden_states + + +class EfficientNetSqueezeExciteLayer(nn.Module): + r""" + This corresponds to the Squeeze and Excitement phase of each block in the original implementation. + """ + + def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False): + super().__init__() + self.dim = expand_dim if expand else in_dim + self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio)) + + self.squeeze = nn.AdaptiveAvgPool2d(output_size=1) + self.reduce = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim_se, + kernel_size=1, + padding="same", + ) + self.expand = nn.Conv2d( + in_channels=self.dim_se, + out_channels=self.dim, + kernel_size=1, + padding="same", + ) + self.act_reduce = ACT2FN[config.hidden_act] + self.act_expand = nn.Sigmoid() + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + inputs = hidden_states + hidden_states = self.squeeze(hidden_states) + hidden_states = self.reduce(hidden_states) + hidden_states = self.act_reduce(hidden_states) + + hidden_states = self.expand(hidden_states) + hidden_states = self.act_expand(hidden_states) + hidden_states = torch.mul(inputs, hidden_states) + + return hidden_states + + +class EfficientNetFinalBlockLayer(nn.Module): + r""" + This corresponds to the final phase of each block in the original implementation. 
+ """ + + def __init__( + self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool + ): + super().__init__() + self.apply_dropout = stride == 1 and not id_skip + self.project_conv = nn.Conv2d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + padding="same", + bias=False, + ) + self.project_bn = nn.BatchNorm2d( + num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.dropout = nn.Dropout(p=drop_rate) + + def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor: + hidden_states = self.project_conv(hidden_states) + hidden_states = self.project_bn(hidden_states) + + if self.apply_dropout: + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + embeddings + + return hidden_states + + +class EfficientNetBlock(nn.Module): + r""" + This corresponds to the expansion and depthwise convolution phase of each block in the original implementation. + + Args: + config ([`EfficientNetConfig`]): + Model configuration class. + in_dim (`int`): + Number of input channels. + out_dim (`int`): + Number of output channels. + stride (`int`): + Stride size to be used in convolution layers. + expand_ratio (`int`): + Expand ratio to set the output dimensions for the expansion and squeeze-excite layers. + kernel_size (`int`): + Kernel size for the depthwise convolution layer. + drop_rate (`float`): + Dropout rate to be used in the final phase of each block. + id_skip (`bool`): + Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase + of each block. Set to `True` for the first block of each stage. + adjust_padding (`bool`): + Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution + operation, set to `True` for inputs with odd input sizes. + """ + + def __init__( + self, + config: EfficientNetConfig, + in_dim: int, + out_dim: int, + stride: int, + expand_ratio: int, + kernel_size: int, + drop_rate: float, + id_skip: bool, + adjust_padding: bool, + ): + super().__init__() + self.expand_ratio = expand_ratio + self.expand = self.expand_ratio != 1 + expand_in_dim = in_dim * expand_ratio + + if self.expand: + self.expansion = EfficientNetExpansionLayer( + config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride + ) + + self.depthwise_conv = EfficientNetDepthwiseLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + stride=stride, + kernel_size=kernel_size, + adjust_padding=adjust_padding, + ) + self.squeeze_excite = EfficientNetSqueezeExciteLayer( + config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand + ) + self.projection = EfficientNetFinalBlockLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + out_dim=out_dim, + stride=stride, + drop_rate=drop_rate, + id_skip=id_skip, + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + embeddings = hidden_states + # Expansion and depthwise convolution phase + if self.expand_ratio != 1: + hidden_states = self.expansion(hidden_states) + hidden_states = self.depthwise_conv(hidden_states) + + # Squeeze and excite phase + hidden_states = self.squeeze_excite(hidden_states) + hidden_states = self.projection(embeddings, hidden_states) + return hidden_states + + +class EfficientNetEncoder(nn.Module): + r""" + Forward propagates the embeddings through each EfficientNet block. 
+ + Args: + config ([`EfficientNetConfig`]): + Model configuration class. + """ + + def __init__(self, config: EfficientNetConfig): + super().__init__() + self.config = config + self.depth_coefficient = config.depth_coefficient + + def round_repeats(repeats): + # Round number of block repeats based on depth multiplier. + return int(math.ceil(self.depth_coefficient * repeats)) + + num_base_blocks = len(config.in_channels) + num_blocks = sum(round_repeats(n) for n in config.num_block_repeats) + + curr_block_num = 0 + blocks = [] + for i in range(num_base_blocks): + in_dim = round_filters(config, config.in_channels[i]) + out_dim = round_filters(config, config.out_channels[i]) + stride = config.strides[i] + kernel_size = config.kernel_sizes[i] + expand_ratio = config.expand_ratios[i] + + for j in range(round_repeats(config.num_block_repeats[i])): + id_skip = j == 0 + stride = 1 if j > 0 else stride + in_dim = out_dim if j > 0 else in_dim + adjust_padding = curr_block_num not in config.depthwise_padding + drop_rate = config.drop_connect_rate * curr_block_num / num_blocks + + block = EfficientNetBlock( + config=config, + in_dim=in_dim, + out_dim=out_dim, + stride=stride, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + drop_rate=drop_rate, + id_skip=id_skip, + adjust_padding=adjust_padding, + ) + blocks.append(block) + curr_block_num += 1 + + self.blocks = nn.ModuleList(blocks) + self.top_conv = nn.Conv2d( + in_channels=out_dim, + out_channels=round_filters(config, 1280), + kernel_size=1, + padding="same", + bias=False, + ) + self.top_bn = nn.BatchNorm2d( + num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.top_activation = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.FloatTensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> BaseModelOutputWithNoAttention: + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for block in self.blocks: + hidden_states = block(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.top_conv(hidden_states) + hidden_states = self.top_bn(hidden_states) + hidden_states = self.top_activation(hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +@auto_docstring +class EfficientNetPreTrainedModel(PreTrainedModel): + config: EfficientNetConfig + base_model_prefix = "efficientnet" + main_input_name = "pixel_values" + _no_split_modules = [] + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +@auto_docstring +class EfficientNetModel(EfficientNetPreTrainedModel): + def __init__(self, config: EfficientNetConfig): + super().__init__(config) + self.config = config + self.embeddings = EfficientNetEmbeddings(config) + self.encoder = EfficientNetEncoder(config) + + # Final pooling layer + if config.pooling_type == "mean": + self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True) + elif config.pooling_type == "max": + 
self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True) + else: + raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}") + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # Apply pooling + last_hidden_state = encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280) + pooled_output = pooled_output.reshape(pooled_output.shape[:2]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. + for ImageNet. + """ +) +class EfficientNetForImageClassification(EfficientNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.efficientnet = EfficientNetModel(config) + # Classifier head + self.dropout = nn.Dropout(p=config.dropout_rate) + self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
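+
+        Example (a minimal usage sketch; it assumes the `google/efficientnet-b7`
+        checkpoint and a `PIL.Image` named `image` are available):
+
+        ```python
+        >>> from transformers import AutoImageProcessor, EfficientNetForImageClassification
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("google/efficientnet-b7")
+        >>> model = EfficientNetForImageClassification.from_pretrained("google/efficientnet-b7")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> logits = model(**inputs).logits
+        >>> predicted_class = logits.argmax(-1).item()
+        ```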
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8555f58d1866451c38abb5559ef5bef9545f0b0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_emu3 import * + from .image_processing_emu3 import * + from .modeling_emu3 import * + from .processing_emu3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/configuration_emu3.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/configuration_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..79407bfbf0d73cfac40297ccfa472f5e451f352d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/configuration_emu3.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Emu3VQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3VQVAE`]. It is used to instantiate an VQ-VAE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a configuration to the VQ model presented in Emu3 paper. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + codebook_size (`int`, *optional*, defaults to 32768): + Codebook size of the VQ model. + embed_dim (`int`, *optional*, defaults to 4): + Dimension of the quantized vector in codebook. + latent_channels (`int`, *optional*, defaults to 4): + Dimension of the output channel of encoder and the input channel of decoder + double_latent (`bool`, *optional*, defaults to `False`): + Whether double the output dim of the encoder. + in_channels (`int`, *optional*, defaults to 3): + Input channel of encoder. + out_channels (`int`, *optional*, defaults to 3): + Output channel of decoder. + temporal_downsample_factor (`int`, *optional*, defaults to 4): + Temporal downsample factor. + base_channels (`int`, *optional*, defaults to 256): + Basic channel number of the intermediate blocks. + channel_multiplier (`list[int]`, *optional*, defaults to `[1, 2, 2, 4]`): + Channel scaling factor of the intermediate blocks. + num_res_blocks (`int`, *optional*, defaults to 2): + Residual block number in each stage. + attn_resolutions (`list[int]`, *optional*, defaults to `[3]`): + Stage indices to apply attention. + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations in the attention layer. + num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ + ```python + >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig + + >>> # Initializing a video VQ model of Emu3 configuration + >>> configuration = Emu3VQVAEConfig() + + >>> # Initializing a model from the Emu3 VQ model style configuration + >>> model = Emu3VQVAE(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_vqgan" + base_config_key = "vq_config" + + def __init__( + self, + codebook_size: int = 32768, + embed_dim: int = 4, + latent_channels: int = 4, + double_latent: bool = False, + in_channels: int = 3, + out_channels: int = 3, + temporal_downsample_factor: int = 4, + base_channels: int = 256, + channel_multiplier: list[int] = [1, 2, 2, 4], + num_res_blocks: int = 2, + attn_resolutions: list[int] = [3], + hidden_size: int = 1024, + num_attention_heads: int = 1, + attention_dropout: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.codebook_size = codebook_size + self.embed_dim = embed_dim + self.latent_channels = latent_channels + self.double_latent = double_latent + self.in_channels = in_channels + self.out_channels = out_channels + self.temporal_downsample_factor = temporal_downsample_factor + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + +class Emu3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3TextModel`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 184622): + Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Emu3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 9216): + The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens, + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 151643): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 151849): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 151850): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + ```python + >>> from transformers import Emu3Model, Emu3Config + + >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration + >>> configuration = Emu3Config() + + >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration + >>> model = Emu3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_text_model" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 184622, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = 8, + hidden_act: str = "silu", + max_position_embeddings: int = 9216, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: int = 151643, + bos_token_id: int = 151849, + eos_token_id: int = 151850, + tie_word_embeddings: bool = False, + rope_theta: float = 1000000.0, + rope_scaling: Optional[dict[str, Any]] = None, + mlp_bias=False, + attention_bias=False, + attention_dropout: float = 0.1, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.mlp_bias = mlp_bias + self.attention_bias = attention_bias + self.initializer_range = initializer_range + rope_config_validation(self) + + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Emu3Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vq_config (`Union[Dict, Emu3VQVAEConfig]`, *optional*): + Emu3VQVAEConfig instance containing the configuration for the VQ-VAE model. + text_config (`Union[Dict, Emu3TextConfig]``, *optional*): + Emu3TextConfig instance containing the configuration for the language model. 
+ vocabulary_map (`dict`, *optional*): + A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + """ + + model_type = "emu3" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig} + + def __init__( + self, + vq_config: Union[dict, Emu3VQVAEConfig] = None, + text_config: Union[dict, Emu3TextConfig] = None, + vocabulary_map: Optional[dict[int, int]] = None, + **kwargs, + ): + if vq_config is None: + vq_config = Emu3VQVAEConfig() + elif isinstance(vq_config, dict): + vq_config = Emu3VQVAEConfig(**vq_config) + + if text_config is None: + text_config = Emu3TextConfig() + elif isinstance(text_config, dict): + text_config = Emu3TextConfig(**text_config) + + self.vq_config = vq_config + self.text_config = text_config + self.vocabulary_map = vocabulary_map + self.image_token_id = vocabulary_map.get("") if vocabulary_map is not None else None + + super().__init__(**kwargs) + + +__all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/image_processing_emu3.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/image_processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..50ce82e01de8ae5f17b009e40b1c28071e7537fb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/image_processing_emu3.py @@ -0,0 +1,529 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + make_nested_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +if is_vision_available(): + from PIL import Image + +logger = logging.get_logger(__name__) + + +def smart_resize( + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. 
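+
+    For example, with the values `Emu3ImageProcessor` passes by default (`factor=8`,
+    `min_pixels=512 * 512`, `max_pixels=1024 * 1024`), a 1000x700 image already fits the
+    pixel budget, so each side is only snapped to the nearest multiple of the factor and
+    the function returns `(1000, 704)` (an illustrative calculation, not a documented
+    reference value).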
+ + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Emu3ImageProcessor(BaseImageProcessor): + r""" + Constructs a Emu3 image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + min_pixels (`int`, *optional*, defaults to `512 * 512`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `1024 * 1024`): + The max pixels of the image to resize the image. 
+ spatial_factor (`int`, *optional*, defaults to 8): + The spatial downsample factor the image will be downsampled in feature extracting phase + """ + + model_input_names = ["pixel_values", "image_sizes"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + do_pad: bool = True, + min_pixels: int = 512 * 512, + max_pixels: int = 1024 * 1024, + spatial_factor: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.spatial_factor = spatial_factor + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`list[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. 
Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_flat_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.spatial_factor, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + images = np.array(processed_images) + return images + + def _pad_for_batching( + self, + pixel_values: list[np.ndarray], + image_sizes: list[list[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`list[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + image_sizes (`list[list[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + list[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max(size[0] for size in image_sizes), + max(size[1] for size in image_sizes), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + do_pad: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + + if images is not None: + images = self.fetch_images(images) + images = make_nested_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + pixel_values = [] + for image in images: + if image: + image = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(image) + + image_sizes = [image.shape[-2:] for image in pixel_values] + if do_pad: + pixel_values = self._pad_for_batching(pixel_values, image_sizes) + pixel_values = np.array(pixel_values) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Union[str, TensorType] = "PIL.Image.Image", + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess. + The parameters should be same as in preprocess. + Args: + images (`ImageInput`): + Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + if isinstance(images[0], Image.Image): + return images if len(images) > 1 else images[0] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + for image in images: + image = to_numpy_array(image) + if do_normalize: + image = self.unnormalize( + image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + image = image.clip(0, 255).astype(np.uint8) + + if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) + pixel_values.append(Image.fromarray(image)) + else: + pixel_values.extend(image) + + data = {"pixel_values": pixel_values} + return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None + + return BatchFeature(data=data, tensor_type=return_tensors) + + def unnormalize( + self, + image: np.ndarray, + image_mean: Union[float, Iterable[float]], + image_std: Union[float, Iterable[float]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`. 
+ image = (image * image_std) + image_mean + Args: + image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + Batch of pixel values to postprocess. + image_mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + image_std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + num_channels = 3 + + if isinstance(image_mean, Iterable): + if len(image_mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}") + else: + image_mean = [image_mean] * num_channels + + if isinstance(image_std, Iterable): + if len(image_std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}") + else: + image_std = [image_std] * num_channels + + rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std)) + rev_image_std = tuple(1 / std for std in image_std) + image = self.normalize( + image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format + ) + return image + + +__all__ = ["Emu3ImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/modeling_emu3.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/modeling_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..ad070efc1d3e18bf773394411a4042e8a2641e8a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/modeling_emu3.py @@ -0,0 +1,1638 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/emu3/modular_emu3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_emu3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
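# --- Editor's sketch, not part of the upstream transformers sources -------------------
# The `unnormalize` helper of `Emu3ImageProcessor` above inverts normalization by reusing
# `normalize` with transformed statistics: normalize computes (x - mean') / std', so with
# mean' = -mean/std and std' = 1/std it returns x * std + mean, the documented inverse.
# Minimal NumPy check of that identity; the shapes and statistics below are made-up values
# used only for this illustration.
import numpy as np

x = np.random.rand(3, 4, 4).astype(np.float32)                      # fake normalized image, (C, H, W)
mean = np.array([0.48, 0.46, 0.41], dtype=np.float32)[:, None, None]
std = np.array([0.27, 0.26, 0.28], dtype=np.float32)[:, None, None]

direct = x * std + mean                                             # image = (image * image_std) + image_mean
via_normalize = (x - (-mean / std)) / (1.0 / std)                   # what normalize() yields with the reversed stats
assert np.allclose(direct, via_normalize, atol=1e-5)
# ---------------------------------------------------------------------------------------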
+ +import math +from functools import cached_property +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Emu3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, 
cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Emu3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Emu3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Emu3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Emu3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Emu3MLP(config) + self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.attention_dropout) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states 
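# --- Editor's sketch, not part of the upstream transformers sources -------------------
# Shape walk-through for the RoPE and grouped-query helpers defined above
# (`apply_rotary_pos_emb`, `repeat_kv`). The toy sizes are assumptions of this
# illustration only; the function relies on the module-level `torch` import above.
def _rope_and_gqa_shape_demo():
    batch, n_heads, n_kv_heads, seq_len, head_dim = 2, 4, 2, 5, 8
    q = torch.randn(batch, n_heads, seq_len, head_dim)
    k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
    cos = torch.randn(batch, seq_len, head_dim)
    sin = torch.randn(batch, seq_len, head_dim)

    # unsqueeze_dim=1 inserts the head axis so cos/sin broadcast over every head
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == (batch, n_heads, seq_len, head_dim)
    assert k_rot.shape == (batch, n_kv_heads, seq_len, head_dim)

    # repeat_kv duplicates each key/value head so eager attention sees n_heads of them
    k_full = repeat_kv(k_rot, n_rep=n_heads // n_kv_heads)
    assert k_full.shape == (batch, n_heads, seq_len, head_dim)
# ---------------------------------------------------------------------------------------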
+ + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: tuple[int], + stride: tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def 
forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if 
self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. 
This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + 
hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + 
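        # [Editor's note] The (batch, temporal) axes are folded into a single leading dimension here
        # so the 2D convolutional stack below treats every frame as an independent image; the
        # temporal axis is restored after `conv_out`, before the temporal convolutions run.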
+ # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv + Taigman](https://huggingface.co/papers/2203.13131). 
+ """ +) +class Emu3VQVAE(PreTrainedModel): + config: Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) 
+ + video = self.decoder(post_quant, quant) + video = video.reshape( + batch_size, + temporal * self.config.temporal_downsample_factor, + self.config.out_channels, + height * self.spatial_scale_factor, + width * self.spatial_scale_factor, + ) + return video[:, 0] if is_image else video + + +class Emu3ImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.eol_token_id = vocab_map.get("<|extra_200|>") + self.image_token_id = vocab_map.get("") + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def image_tokens_str(self): + return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def img2bpe(self): + return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str} + + @cached_property + def bpe2img(self): + return {v: k for k, v in self.img2bpe.items()} + + @cached_property + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: list[torch.Tensor]) -> torch.Tensor: + device = img_batch.device + eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + img_tokens = torch.cat([img_tokens, eol_row], dim=-1) + return img_tokens.to(device) + + def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_batch = img_batch[..., :-1] # remove last row of EOL tokens + img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +@auto_docstring +class Emu3PreTrainedModel(PreTrainedModel): + config: Emu3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Emu3DecoderLayer", + ] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_param_buffer_assignment = False + _supports_flex_attn = True + _supports_attention_backend = True + + +class Emu3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Emu3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Emu3TextModel(Emu3PreTrainedModel): + _can_record_outputs = { + "hidden_states": Emu3DecoderLayer, + "attentions": Emu3Attention, + } + + def __init__(self, config: Emu3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Emu3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": 
(["hidden_states"], ["logits"])} + config: Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Emu3Model(Emu3PreTrainedModel): + _checkpoint_conversion_mapping = {"text_model.model": "text_model"} + + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3TextModel._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.text_model = decoder + + def get_decoder(self): + return self.text_model + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. 
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + split_sizes = [ + (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1) + for height, width in image_sizes + ] + image_features = self.get_input_embeddings()(image_tokens) + image_features = torch.split(image_features, split_sizes) + return image_features + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_sizes) + image_embeds = torch.cat(image_embeds, dim=0) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return outputs + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] + _checkpoint_conversion_mapping = { + "^text_model.model": "model.text_model", + "^vqmodel": "model.vqmodel", + "^text_model.lm_head": "lm_head", + } + + def __init__(self, config): + super().__init__(config) + self.model = Emu3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + # Make 
modules available through conditional class for BC + @property + def text_model(self): + return self.model.text_model + + @property + def vqmodel(self): + return self.model.vqmodel + + @property + def vocabulary_mapping(self): + return self.model.vocabulary_mapping + + def decode_image_tokens(self, **kwargs): + return self.model.decode_image_tokens(**kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... 
] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", + "Emu3Model", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/modular_emu3.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/modular_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..32599727b24c12dadc5dd237b63bcebcaecd1316 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/modular_emu3.py @@ -0,0 +1,1222 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
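+#
+# This file is a "modular transformers" definition: the classes below are written
+# against other models' implementations (Chameleon, Llama, Siglip) and are expanded
+# by the library's modular converter into the standalone modeling_emu3.py above.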
+
+import math
+from functools import cached_property
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..chameleon.modeling_chameleon import (
+    ChameleonPreTrainedModel,
+    ChameleonVQVAEEncoderConvDownsample,
+)
+from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, TransformersKwargs
+from ..siglip.modeling_siglip import SiglipAttention
+from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Emu3Attention(LlamaAttention):
+    pass
+
+
+# Has extra dropout which no other model in the library has
+class Emu3DecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: Emu3Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.dropout = nn.Dropout(config.attention_dropout)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + self.dropout(hidden_states)
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.dropout(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAEVectorQuantizer(nn.Module):
+    """
+    A module for vector quantization using learned embedding vectors.
+
+    This module implements the quantization process similar to the one described in
+    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+    input vectors into discrete codebook vectors, which are learned during training.
+    The current implementation improves over previous ones by avoiding costly matrix
+    multiplications and allowing for post-hoc remapping of indices.
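+
+    A minimal usage sketch (shapes follow `forward`; the channel dimension must equal
+    `config.embed_dim`, and the default config is used only for illustration):
+
+    ```python
+    >>> import torch
+    >>> from transformers.models.emu3.configuration_emu3 import Emu3VQVAEConfig
+
+    >>> config = Emu3VQVAEConfig()
+    >>> quantizer = Emu3VQVAEVectorQuantizer(config)
+    >>> latents = torch.randn(1, 1, config.embed_dim, 8, 8)  # (batch, temporal, channels, height, width)
+    >>> codes = quantizer(latents)  # LongTensor of codebook indices, shape (1, 1, 8, 8)
+    ```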
+ """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample): + pass + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: tuple[int], + stride: tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + 
hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(SiglipAttention): + def __init__(self, config: Emu3VQVAEConfig): + 
super().__init__(config) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = 
hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = 
Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. 
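+    The encoder reduces spatial resolution by a factor of `2 ** (len(channel_multiplier) - 1)`
+    and the temporal axis by `temporal_downsample_factor`; the decoder reverses both reductions.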
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv + Taigman](https://huggingface.co/papers/2203.13131). + """ +) +class Emu3VQVAE(PreTrainedModel): + config: Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = 
quant.shape[-1]
+        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
+        post_quant = self.post_quant_conv(quant)
+
+        quant = quant.permute(0, 2, 1, 3, 4)
+        post_quant = post_quant.permute(0, 2, 1, 3, 4)
+
+        video = self.decoder(post_quant, quant)
+        video = video.reshape(
+            batch_size,
+            temporal * self.config.temporal_downsample_factor,
+            self.config.out_channels,
+            height * self.spatial_scale_factor,
+            width * self.spatial_scale_factor,
+        )
+        return video[:, 0] if is_image else video
+
+
+class Emu3ImageVocabularyMapping:
+    """
+    A class for mapping discrete image tokens from VQGAN to BPE tokens.
+    """
+
+    def __init__(self, vocab_map):
+        self.vocab_map = vocab_map
+        self.eol_token_id = vocab_map.get("<|extra_200|>")
+        self.image_token_id = vocab_map.get("<image>")
+
+    @cached_property
+    def image_tokens(self):
+        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def image_tokens_str(self):
+        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def img2bpe(self):
+        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}
+
+    @cached_property
+    def bpe2img(self):
+        return {v: k for k, v in self.img2bpe.items()}
+
+    @cached_property
+    def bpe2img_mapping_tensor(self):
+        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
+        for k, v in self.bpe2img.items():
+            mapping[k] = v
+        return mapping
+
+    @cached_property
+    def img2bpe_mapping_tensor(self):
+        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
+        for k, v in self.img2bpe.items():
+            mapping[k] = v
+        return mapping
+
+    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
+        device = img_batch.device
+        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
+        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
+        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
+        return img_tokens.to(device)
+
+    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
+        device = img_batch.device
+        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
+        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
+        return img_tokens.to(device)
+
+
+class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
+    _no_split_modules = [
+        "Emu3DecoderLayer",
+    ]
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+
+class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": Emu3DecoderLayer,
+        "attentions": Emu3Attention,
+    }
+
+    def __init__(self, config: Emu3Config):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+
+
+class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
+    config: Emu3TextConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Emu3TextModel(config)
+
+    def forward(self, **super_kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import Emu3Processor, Emu3ForCausalLM
+        >>> import torch
+
+        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
+        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")
+
+        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)
+
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        ```"""
+        super().forward(**super_kwargs)
+
+
+class Emu3Model(Emu3PreTrainedModel):
+    _checkpoint_conversion_mapping = {"text_model.model": "text_model"}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.text_model = Emu3TextModel._from_config(config.text_config)
+        self.vqmodel = Emu3VQVAE(config.vq_config)
+        self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_model.set_input_embeddings(value)
+
+    def set_decoder(self, decoder):
+        self.text_model = decoder
+
+    def get_decoder(self):
+        return self.text_model
+
+    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
+        """
+        Tokenizes images into discrete tokens with the VQGAN module and converts the
+        obtained image tokens into BPE tokens. Each row of image tokens is terminated
+        with an end-of-line (EOL) token; the "boi" and "eoi" wrapping tokens are added
+        by the processor.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+                The sizes of the images in the batch, being (height, width) for each image.
+        """
+        image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
+        bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
+        bpe_tokens = torch.cat(bpe_tokens_list)
+        return bpe_tokens
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
+        """
+        Tokenizes images into discrete tokens with the VQGAN module and embeds
+        them with the text embeddings layer.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+                The sizes of the images in the batch, being (height, width) for each image.
+        """
+        image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+        split_sizes = [
+            (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
+            for height, width in image_sizes
+        ]
+        image_features = self.get_input_embeddings()(image_tokens)
+        image_features = torch.split(image_features, split_sizes)
+        return image_features
+
+    @torch.no_grad
+    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
+        """
+        Decodes generated image tokens from the language model into continuous pixel
+        values with the VQGAN module via upsampling.
+
+        Args:
+            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
+                The tensors corresponding to the input images.
+            height (`int`):
+                Height of the generated image before upsampling.
+            width (`int`):
+                Width of the generated image before upsampling.
+        """
+        sequences = image_tokens[:, :-3].view(-1, height, width + 1)
+        image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
+        image = self.vqmodel.decode(image_tokens)
+        return image
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains the multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
+        token count equals the length of the multimodal features. If the lengths differ, an error is raised.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, CausalLMOutputWithPast]:
+        r"""
+        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
+            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
+            [`Emu3ImageProcessor`] for processing images).
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values, image_sizes)
+            image_embeds = torch.cat(image_embeds, dim=0)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.text_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return outputs
+
+
+class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+    base_model_prefix = ""
+    _tied_weights_keys = ["lm_head.weight"]
+    _checkpoint_conversion_mapping = {
+        "^text_model.model": "model.text_model",
+        "^vqmodel": "model.vqmodel",
+        "^text_model.lm_head": "lm_head",
+    }
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Emu3Model(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.lm_head
+
+    def set_decoder(self, decoder):
+        self.model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    # Make modules available through conditional class for BC
+    @property
+    def text_model(self):
+        return self.model.text_model
+
+    @property
+    def vqmodel(self):
+        return self.model.vqmodel
+
+    @property
+    def vocabulary_mapping(self):
+        return self.model.vocabulary_mapping
+
+    def decode_image_tokens(self, **kwargs):
+        return self.model.decode_image_tokens(**kwargs)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, CausalLMOutputWithPast]:
+        r"""
+        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
+            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
+            [`Emu3ImageProcessor`] for processing images).
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
+        >>> import torch
+        >>> import requests
+        >>> from PIL import Image
+
+        >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
+        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")
+
+        >>> conversation = [
+        ...     {
+        ...         "role": "system",
+        ...         "content": [
+        ...             {"type": "text", "text": "You are a helpful assistant."},
+        ...         ],
+        ...     },
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "image"},
+        ...             {"type": "text", "text": "Please describe the image."},
+        ...         ],
+        ...     },
+        ...
] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", + "Emu3Model", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/emu3/processing_emu3.py b/venv/lib/python3.13/site-packages/transformers/models/emu3/processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..67ccab795733563359dbba53e17591cab2878506 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/emu3/processing_emu3.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_vision_available
+
+
+if is_vision_available():
+    from .image_processing_emu3 import smart_resize
+
+
+class Emu3TextKwargs(TextKwargs, total=False):
+    return_for_image_generation: bool
+
+
+class Emu3ImagesKwargs(ImagesKwargs, total=False):
+    ratio: str
+    image_area: int
+
+
+class Emu3ProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: Emu3TextKwargs
+    images_kwargs: Emu3ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "return_for_image_generation": False,
+            "return_mm_token_type_ids": False,
+        },
+        "images_kwargs": {
+            "ratio": "1:1",
+            "image_area": 518400,
+        },
+    }
+
+
+class Emu3Processor(ProcessorMixin):
+    r"""
+    Constructs an Emu3 processor which wraps an Emu3 image processor and a GPT2 tokenizer into a single
+    processor.
+
+    [`Emu3Processor`] offers all the functionalities of [`Emu3ImageProcessor`] and [`GPT2TokenizerFast`].
+    See [`~Emu3Processor.__call__`] and [`~Emu3Processor.decode`] for more information.
+
+    Args:
+        image_processor ([`Emu3ImageProcessor`]):
+            The image processor is a required input.
+        tokenizer ([`GPT2TokenizerFast`]):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
+    image_processor_class = "Emu3ImageProcessor"
+
+    def __init__(
+        self,
+        image_processor,
+        tokenizer,
+        chat_template=None,
+        **kwargs,
+    ):
+        self.image_token = tokenizer.image_token  # image_token as placeholder to be replaced by vq-vae tokens
+        self.image_token_id = tokenizer.image_token_id
+        self.image_start_token = tokenizer.boi_token  # "<|image start|>" fixed tokens for start and end of image
+        self.image_end_token = tokenizer.eoi_token  # "<|image end|>"
+        self.fake_token_around_image = tokenizer.image_wrapper_token  # "<|image token|>" every image starts with it
+        self.eof_token = tokenizer.eof_token  # "<|extra_201|>"
+        self.bos_token = tokenizer.bos_token
+        self.downsample_ratio = 8
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Emu3ProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to GPT2TokenizerFast's [`~GPT2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        Emu3ImageProcessor's [`~Emu3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+        """
+        # check if images and text inputs are reversed for BC
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) or not isinstance(text[0], str):
+            raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+        output_kwargs = self._merge_kwargs(
+            Emu3ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        return_for_image_generation = output_kwargs["text_kwargs"].pop("return_for_image_generation", False)
+        ratio = output_kwargs["images_kwargs"].pop("ratio", None)
+        image_area = output_kwargs["images_kwargs"].pop("image_area", None)
+
+        if return_for_image_generation and images is not None:
+            raise ValueError("You should not provide `images` when `return_for_image_generation=True`")
+
+        if not return_for_image_generation and text is None and images is None:
+            raise ValueError("You must provide either text or images when `return_for_image_generation=False`")
+
+        image_features = {}
+        image_start_tokens = f"{self.image_start_token}"
+        image_end_tokens = f"{self.eof_token}{self.image_end_token}"
+
+        # generate text from image + text input, so we add placeholders for image tokens
+        if not return_for_image_generation and images is not None:
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+            image_sizes = iter(image_features.image_sizes)
+
+            prompt_strings = []
+            for sample in text:
+                while self.image_token in sample:
+                    image_size = next(image_sizes)
+                    height, width = image_size
+                    height = height // self.downsample_ratio
+                    width = width // self.downsample_ratio
+                    image_seq_length = height * (width + 1)  # +1 for extra row when converting to BPE in modeling code
+
+                    image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
+                    sample = sample.replace(self.image_token, image_placeholder, 1)
+                sample = f"{self.bos_token}{sample}"  # add BOS because GPT tokenizer doesn't add it
+                prompt_strings.append(sample)
+            text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
+
+        # generate image from text input, so we add begin-of-image tokens from where image generation starts
+        elif return_for_image_generation:
+            height, width = self.calculate_generate_size(ratio, image_area,
self.downsample_ratio) + image_prompt = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}" + text = [f"{self.bos_token}{sample}{image_prompt}" for sample in text] + image_features["image_sizes"] = [[height, width]] * len(text) + + # else just generate from text-only input, and we do no special treatment for text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [] + for height, width in image_sizes: + height, width = smart_resize( + height, + width, + self.image_processor.spatial_factor, + self.image_processor.min_pixels, + self.image_processor.max_pixels, + ) + height = height // self.downsample_ratio + width = width // self.downsample_ratio + image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code + num_image_tokens.append(image_seq_length) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + def calculate_generate_size(self, ratio, image_area, spatial_factor): + width, height = map(int, ratio.split(":")) + current_area = width * height + target_ratio = (image_area / current_area) ** 0.5 + + token_height = int(round(height * target_ratio / spatial_factor)) + token_width = int(round(width * target_ratio / spatial_factor)) + return token_height, token_width + + def postprocess(self, images: ImageInput, **kwargs): + return self.image_processor.postprocess(images, **kwargs) + + +__all__ = ["Emu3Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c786feb9213fdd31640c0fdeaead5164026ad37a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_encoder_decoder import * + from .modeling_encoder_decoder import * + from .modeling_flax_encoder_decoder import * + from .modeling_tf_encoder_decoder import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..af57b2596cee99eefe0493cc4aea51c845036d2e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class EncoderDecoderConfig(PretrainedConfig): + r""" + [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is + used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + + - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the encoder config. + - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the decoder config. 
+ + Examples: + + ```python + >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel + + >>> # Initializing a BERT google-bert/bert-base-uncased style configuration + >>> config_encoder = BertConfig() + >>> config_decoder = BertConfig() + + >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) + + >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations + >>> model = EncoderDecoderModel(config=config) + + >>> # Accessing the model configuration + >>> config_encoder = model.config.encoder + >>> config_decoder = model.config.decoder + >>> # set decoder config to causal lm + >>> config_decoder.is_decoder = True + >>> config_decoder.add_cross_attention = True + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("my-model") + + >>> # loading model and config from pretrained folder + >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model") + >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config) + ```""" + + model_type = "encoder-decoder" + sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig} + has_no_defaults_at_init = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "encoder" not in kwargs or "decoder" not in kwargs: + raise ValueError( + f"A configuration of type {self.model_type} cannot be instantiated because " + f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}" + ) + encoder_config = kwargs.pop("encoder") + encoder_model_type = encoder_config.pop("model_type") + decoder_config = kwargs.pop("decoder") + decoder_model_type = decoder_config.pop("model_type") + + self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config) + self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config) + self.is_encoder_decoder = True + + @classmethod + def from_encoder_decoder_configs( + cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs + ) -> PretrainedConfig: + r""" + Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and + decoder model configuration. + + Returns: + [`EncoderDecoderConfig`]: An instance of a configuration object + """ + logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) + + +__all__ = ["EncoderDecoderConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..30e2370b2240d2f2f00182abea548cd6a72b5626 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -0,0 +1,609 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Encoder-Decoder architectures""" + +import gc +import inspect +import os +import tempfile +import warnings +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + + +DEPRECATION_WARNING = ( + "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the" + " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" + " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the" + " labels, no need to pass them yourself anymore." +) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@auto_docstring +class EncoderDecoderModel(PreTrainedModel, GenerationMixin): + r""" + [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one + of the base model classes of the library as encoder and another one as decoder when created with the + :meth*~transformers.AutoModel.from_pretrained* class method for the encoder and + :meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config: EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + _supports_flash_attn = True + _supports_sdpa = True + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + r""" + encoder (`PreTrainedModel`, *optional*): + The encoder model to use. + decoder (`PreTrainedModel`, *optional*): + The decoder model to use. 
+ """ + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if encoder is None: + from ..auto.modeling_auto import AutoModel + + encoder = AutoModel.from_config(config.encoder) + + if decoder is None: + from ..auto.modeling_auto import AutoModelForCausalLM + + decoder = AutoModelForCausalLM.from_config(config.decoder) + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed. 
Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + # tie encoder, decoder weights if config set accordingly + self.tie_weights() + + def tie_weights(self): + self.encoder.tie_weights() + self.decoder.tie_weights() + # tie encoder & decoder if needed + if self.config.tie_encoder_decoder: + # tie encoder and decoder base model + decoder_base_model_prefix = self.decoder.base_model_prefix + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + "encoder", + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights + + def _init_weights(self, module): + if module in self.encoder.modules(): + self.encoder._init_weights(module) + elif module in self.decoder.modules(): + self.decoder._init_weights(module) + + def get_encoder(self): + return self.encoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Example: + + ```python + >>> from transformers import EncoderDecoderModel + + >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + ```""" + + from_tf = kwargs.pop("from_tf", False) + if from_tf: + from transformers import TFEncoderDecoderModel + + # a workaround to load from tensorflow checkpoint + # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get + # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is + # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The + # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`, + # which should not occur when we want to save the components alone. 
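The scope-stripping described in the comment above is easiest to see on plain strings. Below is a minimal, framework-free sketch of the re-keying that the conversion loop further below performs; the variable names are hypothetical and are not taken from a real checkpoint:

```python
# Sketch only: how weights of the composite TFEncoderDecoderModel (variables named
# "[top model name]/encoder/<module>/...") are matched to a freshly built standalone
# encoder (variables named "[top model name]/<module>/..."). Names are made up.
composite_names = [
    "tf_encoder_decoder_model/encoder/vit/embeddings/kernel:0",
    "tf_encoder_decoder_model/encoder/vit/embeddings/bias:0",
]
standalone_names = [
    "tf_vit_model/vit/embeddings/kernel:0",
    "tf_vit_model/vit/embeddings/bias:0",
]

# Composite model: drop the top scope and the extra "encoder" scope (split()[2:]).
composite = {"/".join(n.split("/")[2:]): n for n in composite_names}
# Standalone model: drop only the top scope (split()[1:]).
standalone = {"/".join(n.split("/")[1:]): n for n in standalone_names}

for key, source in composite.items():
    # matching suffixes pair each standalone variable with the composite weight it receives
    print(f"{standalone[key]}  <-  {source}")
```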
+ # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see + # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245 + # (the change in `src/transformers/modeling_tf_utils.py`) + _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = _tf_model.config + + # Using `tf_model` instead + encoder = _tf_model.encoder.__class__(_tf_model.config.encoder) + decoder = _tf_model.decoder.__class__(_tf_model.config.decoder) + # Make sure models are built + encoder(encoder.dummy_inputs) + decoder(decoder.dummy_inputs) + + # Get the variable correspondence between `_tf_model` and `encoder` and `decoder` + encoder_variables = {} + for v in encoder.trainable_variables + encoder.non_trainable_variables: + encoder_variables["/".join(v.name.split("/")[1:])] = v + decoder_variables = {} + for v in decoder.trainable_variables + decoder.non_trainable_variables: + decoder_variables["/".join(v.name.split("/")[1:])] = v + + _encoder_variables = {} + for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables: + _encoder_variables["/".join(v.name.split("/")[2:])] = v + _decoder_variables = {} + for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables: + _decoder_variables["/".join(v.name.split("/")[2:])] = v + + # assign weight values to `encoder` and `decoder` from `_tf_model` + for name, v in encoder_variables.items(): + v.assign(_encoder_variables[name]) + for name, v in decoder_variables.items(): + v.assign(_decoder_variables[name]) + + tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) + + # Deal with `enc_to_dec_proj` + if hasattr(_tf_model, "enc_to_dec_proj"): + tf_model(tf_model.dummy_inputs) + tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel) + tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias) + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder_dir = os.path.join(tmpdirname, "encoder") + decoder_dir = os.path.join(tmpdirname, "decoder") + tf_model.encoder.save_pretrained(encoder_dir) + tf_model.decoder.save_pretrained(decoder_dir) + + if hasattr(tf_model, "enc_to_dec_proj"): + enc_to_dec_proj_weight = torch.transpose( + torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0 + ) + enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy()) + + del _tf_model + del tf_model + gc.collect() + + model = EncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True + ) + # This is only for copying some specific attributes of this particular model. + model.config = config + + if hasattr(model, "enc_to_dec_proj"): + model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() + model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() + + return model + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[str] = None, + decoder_pretrained_model_name_or_path: Optional[str] = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). 
To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import EncoderDecoderModel + + >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2bert") + >>> # load fine-tuned model + >>> model = EncoderDecoderModel.from_pretrained("./bert2bert") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder: + del kwargs["encoder_" + key] + for key in kwargs_decoder: + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. 
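The `encoder_*` / `decoder_*` prefix routing described above can be shown with a tiny standalone sketch; the kwargs below are hypothetical and only illustrate how arguments are split between the two sub-models and the parent configuration:

```python
# Minimal sketch of the prefix routing used here (hypothetical kwargs).
kwargs = {
    "encoder_output_attentions": True,    # -> encoder sub-model / sub-config
    "decoder_add_cross_attention": True,  # -> decoder sub-model / sub-config
    "tie_encoder_decoder": False,         # no prefix -> parent EncoderDecoderConfig
}

kwargs_encoder = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

# remove the routed arguments from the remaining (parent-level) kwargs
for key in list(kwargs):
    if key.startswith(("encoder_", "decoder_")):
        del kwargs[key]

print(kwargs_encoder)  # {'output_attentions': True}
print(kwargs_decoder)  # {'add_cross_attention': True}
print(kwargs)          # {'tie_encoder_decoder': False}
```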
+ encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the + right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Examples: + + ```python + >>> from transformers import EncoderDecoderModel, BertTokenizer + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained( + ... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased" + ... 
) # initialize Bert2Bert from pre-trained checkpoints + + >>> # training + >>> model.config.decoder_start_token_id = tokenizer.cls_token_id + >>> model.config.pad_token_id = tokenizer.pad_token_id + >>> model.config.vocab_size = model.config.decoder.vocab_size + + >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids + >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2bert") + >>> model = EncoderDecoderModel.from_pretrained("bert2bert") + + >>> # generation + >>> generated = model.generate(input_ids) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) + + if not return_dict: + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + 
encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" + " model.decoder.resize_token_embeddings(...))" + ) + + +__all__ = ["EncoderDecoderModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4a27c23c3c69ae928c73273c9397d5f5aad2b1c0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -0,0 +1,901 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Flax Encoder-Decoder architectures""" + +import os +from typing import Optional, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput +from ...modeling_flax_utils import FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] + function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream + generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 
Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Parameters: + config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, + config.encoder.max_position_embeddings - 1]`. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. +""" + +ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.encoder.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. +""" + +ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a + plain tuple. +""" + + +class FlaxEncoderDecoderModule(nn.Module): + config: EncoderDecoderConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + encoder_config = self.config.encoder + decoder_config = self.config.decoder + + # Copied from `modeling_hybrid_clip.py` with modifications. 
+ from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING + + encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class + decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class + + self.encoder = encoder_module(encoder_config, dtype=self.dtype) + self.decoder = decoder_module(decoder_config, dtype=self.dtype) + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Dense( + self.decoder.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), + dtype=self.dtype, + ) + else: + self.enc_to_dec_proj = None + + def _get_encoder_module(self): + return self.encoder + + def _get_projection_module(self): + return self.enc_to_dec_proj + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if self.enc_to_dec_proj is not None: + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class FlaxEncoderDecoderModel(FlaxPreTrainedModel): + r""" + [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with + the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as + decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the + encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. 
+ """ + + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + module_class = FlaxEncoderDecoderModule + + def __init__( + self, + config: EncoderDecoderConfig, + input_shape: Optional[tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = ((1, 1), (1, 1)) + + if not _do_init: + raise ValueError( + "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + encoder_input_shape, decoder_input_shape = input_shape + + # init input tensors + input_ids = jnp.zeros(encoder_input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape + if not decoder_batch_size == batch_size: + raise ValueError( + f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" + f" and {decoder_batch_size} for decoder." + ) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> input_ids = tokenizer.encode(text, return_tensors="np") + >>> encoder_outputs = model.encode(input_ids) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + outputs = self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + if return_dict: + outputs = FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + + @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer + >>> import jax.numpy as jnp + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(input_ids) + + >>> decoder_start_token_id = model.config.decoder.bos_token_id + >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward( + module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs + ): + projection_module = module._get_projection_module() + decoder_module = module._get_decoder_module() + + # optionally project encoder_hidden_states + if projection_module is not None: + encoder_hidden_states = projection_module(encoder_hidden_states) + + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def 
__call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer + + >>> # load a fine-tuned bert2gpt2 model + >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16") + >>> # load input & output tokenizer + >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + + >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members + >>> singing a racist chant. SAE's national chapter suspended the students, + >>> but University of Oklahoma President David Boren took it a step further, + >>> saying the university's affiliation with the fraternity is permanently done.''' + + >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids + + >>> # use GPT2's eos_token as the pad as well as eos token + >>> model.config.eos_token_id = model.config.decoder.eos_token_id + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences + + >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0] + >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members" + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + raise ValueError( + "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must" + " be specified as an input argument." 
+ ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": decoder_position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + *model_args, + **kwargs, + ) -> FlaxPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + Params: + encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. 
Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2gpt2") + >>> # load fine-tuned model + >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder: + del kwargs["encoder_" + key] + for key in kwargs_decoder: + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = FlaxAutoModel.from_pretrained( + encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder + ) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." 
+ ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + dtype = kwargs.pop("dtype", jnp.float32) + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # init model + model = cls(config, dtype=dtype) + model.params["encoder"] = encoder.params + model.params["decoder"] = decoder.params + + return model + + +__all__ = ["FlaxEncoderDecoderModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5343d200499e1f3b8ba26f8d70924c2999a2fc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -0,0 +1,661 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
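Before the TensorFlow counterpart below, here is a minimal sketch tying together the Flax `encode`/`decode` API defined above. Checkpoint names are taken from the docstring examples; the loop decodes greedily *without* `past_key_values` for simplicity (every step re-runs the full decoder), whereas `model.generate` uses the cached path set up in `prepare_inputs_for_generation`. This is an illustrative sketch, not part of the library file.

```python
import jax.numpy as jnp
from transformers import BertTokenizer, FlaxEncoderDecoderModel

# Warm-start a bert2gpt2 model as in the docstring examples above (the cross-attention
# weights are randomly initialized, so the generated tokens are not meaningful).
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-cased", "openai-community/gpt2"
)
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
encoder_outputs = model.encode(inputs.input_ids, attention_mask=inputs.attention_mask)

# Greedy decoding without the cache: simple but slow. The cached path additionally
# requires `decoder_position_ids`, as enforced in `decode` above.
decoder_input_ids = jnp.full(
    (inputs.input_ids.shape[0], 1), model.config.decoder.bos_token_id, dtype="i4"
)
for _ in range(10):
    outputs = model.decode(
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask=inputs.attention_mask,
    )
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1).astype("i4")[:, None]
    decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token], axis=-1)

print(decoder_input_ids)
```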
+"""Classes to support TF Encoder-Decoder architectures""" + +from __future__ import annotations + +import inspect +import re +import warnings + +import numpy as np +import tensorflow as tf + +from ...configuration_utils import PretrainedConfig +from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +DEPRECATION_WARNING = ( + "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the" + " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" + " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the" + " labels, no need to pass them yourself anymore." +) + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`] + function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream + generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + Provide for sequence to sequence training to the decoder. Indices can be obtained using + [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*): + This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output + of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `({0})`. + inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Labels for computing the masked language modeling loss for the decoder. 
Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + + - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. + - With a *decoder_* prefix which will be input as `**decoder_kwargs`` for the decoder forward function. +""" + + +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): + r""" + [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one + of the base model classes of the library as encoder and another one as decoder when created with the + [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class + method for the decoder. 
+ """ + + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + load_weight_prefix = "tf_encoder_decoder_model" + + def __init__( + self, + config: PretrainedConfig | None = None, + encoder: TFPreTrainedModel | None = None, + decoder: TFPreTrainedModel | None = None, + ): + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if encoder is None: + encoder = TFAutoModel.from_config(config.encoder, name="encoder") + + if decoder is None: + decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder") + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = keras.layers.Dense( + units=self.decoder.config.hidden_size, + kernel_initializer=get_initializer(config.encoder.initializer_range), + name="enc_to_dec_proj", + ) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed. 
Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + def get_encoder(self): + return self.encoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def tf_to_pt_weight_rename(self, tf_weight): + # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models + # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. + # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption + # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's + # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! + + # This override is only needed in the case where we're crossloading weights from PT. However, since weights are + # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. + # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it + # or not. + encoder_model_type = self.config.encoder.model_type + if "encoder" in tf_weight and "decoder" not in tf_weight: + return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) + else: + return (tf_weight,) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: str | None = None, + decoder_pretrained_model_name_or_path: str | None = None, + *model_args, + **kwargs, + ) -> TFPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case, + `encoder_from_pt` should be set to `True`. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case, + `decoder_from_pt` should be set to `True`. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. 
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import TFEncoderDecoderModel + + >>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2gpt2") + >>> # load fine-tuned model + >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder: + del kwargs["encoder_" + key] + for key in kwargs_decoder: + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + kwargs_encoder["name"] = "encoder" + kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix + encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." 
+ ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + kwargs_decoder["name"] = "decoder" + kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix + decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly. + if encoder.name != "encoder": + raise ValueError("encoder model must be created with the name `encoder`.") + if decoder.name != "decoder": + raise ValueError("decoder model must be created with the name `decoder`.") + + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) + + @unpack_inputs + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import TFEncoderDecoderModel, BertTokenizer + + >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> # forward + >>> input_ids = tokenizer.encode( + ... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf" + ... 
) # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) + + >>> # training + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2gpt2") + >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2") + + >>> # generation + >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # Let the user be responsible for the expected format. + if encoder_outputs is not None: + if return_dict and not isinstance(encoder_outputs, ModelOutput): + raise ValueError( + "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of " + f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`." + ) + + if encoder_outputs is None: + encoder_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "training": training, + } + + # Add arguments to encoder from `kwargs_encoder` + encoder_inputs.update(kwargs_encoder) + + # Handle the case where the inputs are passed as a single dict which contains `labels`. + # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this + # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`). + if "labels" in encoder_inputs: + labels = encoder_inputs.pop("labels") + + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. + if "decoder_input_ids" in encoder_inputs: + decoder_input_ids = encoder_inputs.pop("decoder_input_ids") + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. 
+ if "decoder_attention_mask" in encoder_inputs: + decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask") + + encoder_outputs = self.encoder(**encoder_inputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + decoder_inputs = { + "input_ids": decoder_input_ids, + "attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": attention_mask, + "inputs_embeds": decoder_inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "use_cache": use_cache, + "past_key_values": past_key_values, + "return_dict": return_dict, + "training": training, + } + + # Add arguments to decoder from `kwargs_decoder` + decoder_inputs.update(kwargs_decoder) + + decoder_outputs = self.decoder(**decoder_inputs) + + logits = decoder_outputs[0] + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + loss = self.hf_compute_loss(labels, logits) + + if not return_dict: + past_key_values = None + if use_cache: + past_key_values = decoder_outputs[1] + # The starting index of the remaining elements in `decoder_outputs` + start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) + + if not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs + output = tuple(x for x in output if x is not None) + return output + + return TFSeq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) + decoder_attention_mask = decoder_inputs.get("attention_mask", None) + past_key_values = decoder_inputs.get("past_key_values") + if past_key_values is None: + past_key_values = decoder_inputs.get("past") # e.g. 
on TF GPT2 + input_dict = { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete + "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), + "past_key_values": past_key_values, + "use_cache": use_cache, + } + return input_dict + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" + " model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past, beam_idx) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "enc_to_dec_proj", None) is not None: + with tf.name_scope(self.enc_to_dec_proj.name): + self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size]) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +__all__ = ["TFEncoderDecoderModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f4fe6327b312ff5f60ffb08c4b76566bf63f3f9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
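As a complement to the inference examples in the docstrings above, here is a short sketch of the training path of `TFEncoderDecoderModel.call`: when only `labels` are passed, `decoder_input_ids` are derived internally via `shift_tokens_right`, so `pad_token_id` and `decoder_start_token_id` must be set on the shared config. Checkpoint names follow the docstring examples and the target sentence is purely illustrative; this is not part of the library file.

```python
from transformers import BertTokenizer, GPT2Tokenizer, TFEncoderDecoderModel

# Warm-start bert2gpt2 as in the docstrings above; cross-attention weights are random.
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-cased", "openai-community/gpt2"
)
encoder_tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
decoder_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

# `shift_tokens_right` needs these two attributes on the top-level config.
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id  # GPT-2 has no pad token

enc = encoder_tokenizer("My friends are cool but they eat too many carbs.", return_tensors="tf")
labels = decoder_tokenizer("They should eat fewer carbs.", return_tensors="tf").input_ids

# `decoder_input_ids` are created from `labels` inside `call`, per the deprecation note above.
outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels)
print(outputs.loss)
```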
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_eomt import * + from .image_processing_eomt import * + from .image_processing_eomt_fast import * + from .modeling_eomt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/configuration_eomt.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/configuration_eomt.py new file mode 100644 index 0000000000000000000000000000000000000000..670250721150e60df3d5da9280197cdad461beef --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/configuration_eomt.py @@ -0,0 +1,168 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_eomt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig + + +class EomtConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the EoMT + [tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in each attention layer. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the MLP hidden dimensionality to the hidden size. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 640): + The size (resolution) of each input image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value for the LayerScale parameter. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The stochastic depth rate (drop path) used during training. + num_upscale_blocks (`int`, *optional*, defaults to 2): + Number of upsampling blocks used in the decoder or segmentation head. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability applied after attention projection. + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_blocks (`int`, *optional*, defaults to 4): + Number of feature blocks or stages in the architecture. + no_object_weight (`float`, *optional*, defaults to 0.1): + Loss weight for the 'no object' class in panoptic/instance segmentation. + class_weight (`float`, *optional*, defaults to 2.0): + Loss weight for classification targets. + mask_weight (`float`, *optional*, defaults to 5.0): + Loss weight for mask prediction. + dice_weight (`float`, *optional*, defaults to 5.0): + Loss weight for the dice loss component. + train_num_points (`int`, *optional*, defaults to 12544): + Number of points to sample for mask loss computation during training. + oversample_ratio (`float`, *optional*, defaults to 3.0): + Oversampling ratio used in point sampling for mask training. + importance_sample_ratio (`float`, *optional*, defaults to 0.75): + Ratio of points to sample based on importance during training. + num_queries (`int`, *optional*, defaults to 200): + Number of object queries in the Transformer. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of learnable register tokens added to the transformer input. 
+ + Example: + + ```python + >>> from transformers import EomtConfig, EomtForUniversalSegmentation + + >>> # Initialize configuration + >>> config = EomtConfig() + + >>> # Initialize model + >>> model = EomtForUniversalSegmentation(config) + + >>> # Access config + >>> config = model.config + ```""" + + model_type = "eomt" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=640, + patch_size=16, + num_channels=3, + layerscale_value=1.0, + drop_path_rate=0.0, + num_upscale_blocks=2, + attention_dropout=0.0, + use_swiglu_ffn=False, + num_blocks=4, + no_object_weight: float = 0.1, + class_weight: float = 2.0, + mask_weight: float = 5.0, + dice_weight: float = 5.0, + train_num_points: int = 12544, + oversample_ratio: float = 3.0, + importance_sample_ratio: float = 0.75, + num_queries=200, + num_register_tokens=4, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + + self.mlp_ratio = mlp_ratio + self.attention_dropout = attention_dropout + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.num_upscale_blocks = num_upscale_blocks + self.use_swiglu_ffn = use_swiglu_ffn + self.num_blocks = num_blocks + self.no_object_weight = no_object_weight + self.class_weight = class_weight + self.mask_weight = mask_weight + self.dice_weight = dice_weight + self.train_num_points = train_num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.num_queries = num_queries + self.num_register_tokens = num_register_tokens + + +__all__ = ["EomtConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt.py new file mode 100644 index 0000000000000000000000000000000000000000..2b786ce39e71a8710fc46f053a4570e90444830f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt.py @@ -0,0 +1,973 @@ +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
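A purely illustrative variation on the `EomtConfig` docstring example above, showing how the documented constructor arguments can be overridden to build a reduced configuration; the values are arbitrary, not a recommended or pretrained setup, and this sketch is not part of the library file.

```python
from transformers import EomtConfig, EomtForUniversalSegmentation

# Arbitrary small values, chosen only to exercise the arguments documented above.
config = EomtConfig(
    hidden_size=384,
    num_hidden_layers=12,
    num_attention_heads=6,
    image_size=320,
    patch_size=16,
    num_queries=100,
    num_register_tokens=4,
)

model = EomtForUniversalSegmentation(config)  # randomly initialized
print(sum(p.numel() for p in model.parameters()))  # rough parameter count of the reduced model
```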
+"""Image processor class for EoMT.""" + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + PaddingMode, + pad, + resize, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + logging, +) + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + import torch.nn.functional as F + + +# Adapted from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks +def convert_segmentation_map_to_binary_masks( + segmentation_map: np.ndarray, + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + ignore_index: Optional[int] = None, +): + if ignore_index is not None: + segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + # Get unique ids (class or instance ids based on input) + all_labels = np.unique(segmentation_map) + + # Drop background label if applicable + if ignore_index is not None: + all_labels = all_labels[all_labels != ignore_index] + + # Generate a binary mask for each object instance + binary_masks = [(segmentation_map == i) for i in all_labels] + + # Stack the binary masks + if binary_masks: + binary_masks = np.stack(binary_masks, axis=0) + else: + binary_masks = np.zeros((0, *segmentation_map.shape)) + + # Convert instance ids to class ids + if instance_id_to_semantic_id is not None: + labels = np.zeros(all_labels.shape[0]) + + for label in all_labels: + class_id = instance_id_to_semantic_id[label + 1 if ignore_index is not None else label] + labels[all_labels == label] = class_id - 1 if ignore_index is not None else class_id + else: + labels = all_labels + + return binary_masks.astype(np.float32), labels.astype(np.int64) + + +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. 
+ """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = round(raw_size * height / width) + else: + oh = round(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = round(raw_size * width / height) + else: + ow = round(size * width / height) + + return (oh, ow) + + +# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_mask = mask_probs[k] >= mask_threshold + original_area = original_mask.sum() + + final_mask = mask_k & original_mask + final_mask_area = final_mask.sum() + + mask_exists = mask_k_area > 0 and original_area > 0 and final_mask_area > 0 + + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, final_mask + + +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + stuff_classes, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_size: Optional[tuple[int, int]] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.long, device=mask_probs.device) - 1 + segments: list[dict] = [] + + # Compute per-pixel assignment based on weighted mask scores + mask_probs = mask_probs.sigmoid() + mask_labels = (pred_scores[:, None, None] * mask_probs).argmax(0) + + # Keep track of instances of each class + current_segment_id = 0 + stuff_memory_list: dict[str, int] = {} + + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + + # Check if mask exists and large enough to be a segment + mask_exists, final_mask 
= check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if not mask_exists: + continue + + if stuff_classes and pred_class in stuff_classes: + if pred_class in stuff_memory_list: + segmentation[final_mask] = stuff_memory_list[pred_class] + continue + else: + stuff_memory_list[pred_class] = current_segment_id + + segmentation[final_mask] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "score": segment_score, + } + ) + current_segment_id += 1 + return segmentation, segments + + +def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]: + """Returns the height and width from a size dict.""" + target_height = size_dict["shortest_edge"] + target_width = size_dict.get("longest_edge") or target_height + + return target_height, target_width + + +class EomtImageProcessor(BaseImageProcessor): + r""" + Constructs a EoMT image processor. The image processor can be used to prepare image(s) and optional targets + for the model. + + This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 640): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a + sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of + the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size * + height / width, size)`. + resample (`int`, *optional*, defaults to `Resampling.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`, + `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`, + `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set + to `True`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input to a certain `scale`. + rescale_factor (`float`, *optional*, defaults to `1/ 255`): + Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + do_split_image (`bool`, *optional*, defaults to `False`): + Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the + input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches. + Otherwise, the input images will be padded to the target size. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean. + image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the + ImageNet std. 
+        ignore_index (`int`, *optional*):
+            Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
+            denoted with 0 (background) will be replaced with `ignore_index`.
+        num_labels (`int`, *optional*):
+            The number of labels in the segmentation map.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255,
+        do_normalize: bool = True,
+        do_split_image: bool = False,
+        do_pad: bool = False,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        ignore_index: Optional[int] = None,
+        num_labels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        size = size if size is not None else {"shortest_edge": 640, "longest_edge": 640}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.do_split_image = do_split_image
+        self.do_pad = do_pad
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.ignore_index = ignore_index
+        self.num_labels = num_labels
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format=None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+        resized to keep the input aspect ratio.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict`):
+                Size dictionary with `"shortest_edge"` and `"longest_edge"` keys controlling the output size.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+ """ + image_size = get_image_size(image) + output_size = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + + image = resize( + image=image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + return_numpy=True, + **kwargs, + ) + + return image + + def _split_image(self, image: ImageInput, size: dict, image_index: int) -> tuple[list, list]: + """Slices an image into overlapping patches for semantic segmentation.""" + + patches, patch_offsets = [], [] + + image_size = get_image_size(image) + patch_size = size["shortest_edge"] + + longer_side = max(image_size) + num_patches = math.ceil(longer_side / patch_size) + total_overlap = num_patches * patch_size - longer_side + overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0 + + for i in range(num_patches): + start = int(i * (patch_size - overlap_per_patch)) + end = start + patch_size + + if image_size[0] > image_size[1]: + patch = image[:, start:end, :] + else: + patch = image[:, :, start:end] + + patches.append(patch) + patch_offsets.append([image_index, start, end]) + + return patches, patch_offsets + + def _pad(self, image: ImageInput, size: dict) -> np.ndarray: + """Pads the image to the target size using zero padding.""" + height, width = get_image_size(image) + + target_height, target_width = get_target_size(size) + pad_h = max(0, target_height - height) + pad_w = max(0, target_width - width) + + padding = ((0, pad_h), (0, pad_w)) + + # Channel axis is last; default padding format is compatible + padded_image = pad(image=image, padding=padding, mode=PaddingMode.CONSTANT, constant_values=0.0) + return padded_image + + def _preprocess_images( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_split_image: Optional[bool] = None, + do_pad: Optional[bool] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a batch of images.""" + images = [to_numpy_array(image) for image in images] + + if do_resize: + images = [ + self.resize( + image, + size=size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + ) + for image in images + ] + + processed_images, patch_offsets = [], [] + + if do_split_image: + for idx, img in enumerate(images): + patches, offsets = self._split_image(img, size, idx) + processed_images.extend(patches) + patch_offsets.extend(offsets) + + images = processed_images + + if do_pad: + images = [self._pad(img, size) for img in images] + + if do_rescale: + images = [self.rescale(img, scale=rescale_factor, input_data_format=input_data_format) for img in images] + + if do_normalize: + images = [ + self.normalize( + image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + for image in images + ] + + return images, patch_offsets + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: Optional[bool] = False, + do_pad: Optional[bool] = False, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + data_format: Union[str, ChannelDimension] = 
None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """Preprocesses a single mask."""
+        # Add channel dimension if missing - needed for certain transformations
+        if segmentation_map.ndim == 2:
+            added_channel_dim = True
+            segmentation_map = segmentation_map[None, ...]
+            input_data_format = ChannelDimension.FIRST
+        else:
+            added_channel_dim = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map)
+
+        if do_resize:
+            segmentation_map = self.resize(
+                segmentation_map,
+                size=size,
+                resample=resample,
+                data_format=data_format,
+            )
+
+        if do_pad:
+            segmentation_map = self._pad(segmentation_map, size)
+
+        # Remove extra channel dimension if added for processing
+        if added_channel_dim:
+            segmentation_map = segmentation_map.squeeze(0)
+        return torch.from_numpy(segmentation_map)
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
+        do_split_image: Optional[bool] = None,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        do_pad: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        ignore_index: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocesses images or a batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image or batch of images to preprocess.
+            segmentation_maps (`ImageInput`, *optional*):
+                The corresponding semantic segmentation maps with the pixel-wise annotations.
+            instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
+                A mapping between object instance ids and class ids.
+            do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
+                Whether to split the input images into overlapping patches for semantic segmentation.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the input images.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use when resizing.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the input images by `rescale_factor`.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Factor to scale image pixel values.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the input images.
+            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the
+                largest number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Mean for normalization. Single value or list for each channel.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation for normalization. Single value or list for each channel. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be `"pt"`, `"tf"`, `"np"`, or `"jax"`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + Channel format of the output image. Either `"channels_first"` or `"channels_last"`. + input_data_format (`ChannelDimension` or `str`, *optional*): + Channel format of the input image. + """ + + do_split_image = do_split_image if do_split_image is not None else self.do_split_image + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_pad = do_pad if do_pad is not None else self.do_pad + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+
+        pixel_values_list, patch_offsets = self._preprocess_images(
+            images=images,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            do_split_image=do_split_image,
+            do_pad=do_pad,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            data_format=data_format,
+            input_data_format=input_data_format,
+        )
+
+        if segmentation_maps is not None:
+            segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+            segmentation_maps = [to_numpy_array(mask) for mask in segmentation_maps]
+
+            segmentation_maps = [
+                self._preprocess_mask(
+                    segmentation_map,
+                    do_resize=do_resize,
+                    do_pad=do_pad,
+                    size=size,
+                    resample=PILImageResampling.NEAREST,
+                    data_format=data_format,
+                    input_data_format=input_data_format,
+                )
+                for segmentation_map in segmentation_maps
+            ]
+
+        encoded_inputs = self.encode_inputs(
+            pixel_values_list,
+            segmentation_maps,
+            instance_id_to_semantic_id,
+            ignore_index,
+            return_tensors,
+            input_data_format=data_format,
+        )
+
+        if do_split_image and patch_offsets:
+            encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
+
+        return encoded_inputs
+
+    def encode_inputs(
+        self,
+        pixel_values_list: list[ImageInput],
+        segmentation_maps: Optional[ImageInput] = None,
+        instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
+        ignore_index: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """
+        Encodes a batch of pixel values into a [`BatchFeature`] and, when segmentation maps are provided, converts
+        them into binary mask and class labels.
+
+        EoMT addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+        will be converted to lists of binary masks and their respective labels. For example, assuming
+        `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+        [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+        each mask.
+
+        Args:
+            pixel_values_list (`list[ImageInput]`):
+                List of images (pixel values) to encode. Each image should be a tensor of shape `(channels, height,
+                width)`.
+
+            segmentation_maps (`ImageInput`, *optional*):
+                The corresponding semantic segmentation maps with the pixel-wise annotations.
+
+            instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
+                A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
+                instance segmentation map where each pixel represents an instance id. Can be provided as a single
+                dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
+                instance ids in each image separately.
+
+            ignore_index (`int`, *optional*, defaults to `self.ignore_index`):
+                Label to be assigned to background pixels when converting the segmentation maps to binary masks.
+
+            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+                If set, will return tensors instead of NumPy arrays.
If set to `'pt'`, return PyTorch `torch.Tensor` + objects. + + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `annotations` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` if `class_labels[i][j]`. + """ + ignore_index = self.ignore_index if ignore_index is None else ignore_index + + pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(pixel_values_list[0]) + + encoded_inputs = BatchFeature({"pixel_values": pixel_values_list}, tensor_type=return_tensors) + + if segmentation_maps is not None: + mask_labels = [] + class_labels = [] + # Convert to list of binary masks and labels + for idx, segmentation_map in enumerate(segmentation_maps): + segmentation_map = to_numpy_array(segmentation_map) + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks( + segmentation_map, + instance_id, + ignore_index=ignore_index, + ) + + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + + return encoded_inputs + + def merge_image_patches( + self, + segmentation_logits: torch.Tensor, + patch_offsets: list[tuple[int, int, int]], + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[torch.Tensor]: + """ + Reconstructs full-size semantic segmentation logits from patch predictions. + + Args: + segmentation_logits (`torch.Tensor`): + A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits + for each image patch. + patch_offsets (`list[tuple[int, int, int]]`): + A list of tuples where each tuple contains: + - `image_index` (int): Index of the original image this patch belongs to. + - `start` (int): Start pixel index of the patch along the long dimension (height or width). + - `end` (int): End pixel index of the patch along the long dimension. + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): + A size dict which was used to resize. 
+ """ + num_classes = segmentation_logits.shape[1] + aggregated_logits = [] + patch_counts = [] + + for image_size in target_sizes: + height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) + patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) + + # Stitch patches back into full-sized logit maps + for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: + aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, patch_start:patch_end, :] += 1 + else: + aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, :, patch_start:patch_end] += 1 + + # Normalize and resize logits to original image size + reconstructed_logits = [] + for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)): + averaged_logits = logit_sum / count.clamp(min=1) + resized_logits = F.interpolate( + averaged_logits[None, ...], + size=target_sizes[idx], + mode="bilinear", + align_corners=False, + )[0] + + reconstructed_logits.append(resized_logits) + + return reconstructed_logits + + def unpad_image( + self, + segmentation_logits: torch.Tensor, + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[torch.Tensor]: + """Restores panoptic segmentation logits to their original image resolutions.""" + + resized_logits = [] + + for idx, original_size in enumerate(target_sizes): + target_height, target_width = get_size_with_aspect_ratio( + original_size, size["shortest_edge"], size["longest_edge"] + ) + cropped_logits = segmentation_logits[idx][:, :target_height, :target_width] + upsampled_logits = F.interpolate( + cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False + )[0] + resized_logits.append(upsampled_logits) + return resized_logits + + def post_process_semantic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + size: Optional[dict[str, int]] = None, + ) -> np.ndarray: + """Post-processes model outputs into final semantic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) + + preds = [logit.argmax(dim=0) for logit in output_logits] + return preds + + def post_process_panoptic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + stuff_classes: Optional[list[int]] = None, + size: Optional[dict[str, int]] = None, + ): + 
"""Post-processes model outputs into final panoptic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) + + results: list = [] + + for i in range(batch_size): + mask_probs, pred_scores, pred_labels = remove_low_and_no_objects( + mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels + ) + + # No mask found + if mask_probs.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + segmentation, segments = compute_segments( + mask_probs=mask_probs, + pred_scores=pred_scores, + pred_labels=pred_labels, + stuff_classes=stuff_classes, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + target_size=target_sizes[i] if target_sizes is not None else None, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + @filter_out_non_signature_kwargs() + def post_process_instance_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.5, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into Instance Segmentation Predictions.""" + + size = size if size is not None else self.size + + class_queries_logits = outputs.class_queries_logits + masks_queries_logits = outputs.masks_queries_logits + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + + device = masks_queries_logits.device + batch_size = class_queries_logits.shape[0] + num_queries = class_queries_logits.shape[-2] + + results = [] + + for i in range(batch_size): + mask_pred = mask_probs_batch[i] + mask_class = class_queries_logits[i] + + # Remove the null class `[..., :-1]` + scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1) + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores * mask_scores + + segmentation = torch.zeros(target_sizes[i], device=device) - 1 + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = 
["EomtImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt_fast.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..ca80231d3a7646411e1134e8e97986100da8b025 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/image_processing_eomt_fast.py @@ -0,0 +1,536 @@ +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for EoMT.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + filter_out_non_signature_kwargs, +) +from .image_processing_eomt import ( + compute_segments, + convert_segmentation_map_to_binary_masks, + get_size_with_aspect_ratio, + remove_low_and_no_objects, +) + + +class EomtImageProcessorFastKwargs(DefaultFastImageProcessorKwargs): + """ + do_split_image (`bool`, *optional*, defaults to `False`): + Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the + input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches. + Otherwise, the input images will be padded to the target size. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. 
+ """ + + do_split_image: bool + do_pad: bool + ignore_index: Optional[int] = None + + +def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]: + """Returns the height and width from a size dict.""" + target_height = size_dict["shortest_edge"] + target_width = size_dict["longest_edge"] or target_height + + return target_height, target_width + + +def reorder_patches_and_offsets( + patches: list[torch.Tensor], offsets: list[list[int]] +) -> tuple[list[torch.Tensor], list[list[int]]]: + """Sorts patches and offsets according to the original image index.""" + + combined = list(zip(offsets, patches)) + combined.sort(key=lambda x: x[0][0]) + sorted_offsets, sorted_patches = zip(*combined) + + return list(sorted_patches), list(sorted_offsets) + + +@auto_docstring +class EomtImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 640, "longest_edge": 640} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + do_split_image = False + do_pad = False + ignore_index = None + valid_kwargs = EomtImageProcessorFastKwargs + + def __init__(self, **kwargs: Unpack[EomtImageProcessorFastKwargs]): + super().__init__(**kwargs) + + def _split_image(self, images: torch.Tensor, size: dict, image_indices: int) -> tuple[list, list]: + """Slices an image into overlapping patches for semantic segmentation.""" + + patches, patch_offsets = [], [] + + _, _, height, width = images.shape + patch_size = size["shortest_edge"] + + longer_side = max(height, width) + num_patches = math.ceil(longer_side / patch_size) + total_overlap = num_patches * patch_size - longer_side + overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0 + + for i in range(num_patches): + start = int(i * (patch_size - overlap_per_patch)) + end = start + patch_size + + if height > width: + batch_patch = images[:, :, start:end, :] + else: + batch_patch = images[:, :, :, start:end] + + for batch_idx, single in enumerate(torch.unbind(batch_patch, dim=0)): + patches.append(single) + patch_offsets.append([image_indices[batch_idx], start, end]) + + return patches, patch_offsets + + def _pad(self, images: torch.Tensor, size: dict) -> torch.Tensor: + """Pads the image to the target size using zero padding.""" + _, _, height, width = images.shape + + target_height, target_width = get_target_size(size) + pad_h = max(0, target_height - height) + pad_w = max(0, target_width - width) + padding = (0, pad_w, 0, pad_h) + + padded_images = torch.nn.functional.pad(images, padding, mode="constant", value=0.0) + return padded_images + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[list[torch.Tensor]] = None, + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess for corresponding images. + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. 
+ """ + return super().preprocess(images, segmentation_maps, instance_id_to_semantic_id, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + instance_id_to_semantic_id: Optional[dict[int, int]], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + ignore_index = kwargs.pop("ignore_index", None) + images_kwargs = kwargs.copy() + processed_images, patch_offsets = self._preprocess(images, **images_kwargs) + outputs = BatchFeature({"pixel_values": processed_images}) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + # Nearest interpolation is used for segmentation maps instead of BILINEAR. + "interpolation": F.InterpolationMode.NEAREST_EXACT, + } + ) + + processed_segmentation_maps, _ = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(torch.int64) + # Convert to list of binary masks and labels + mask_labels, class_labels = [], [] + for idx, segmentation_map in enumerate(processed_segmentation_maps): + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks( + segmentation_map, + instance_id, + ignore_index=ignore_index, + ) + + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + outputs["mask_labels"] = mask_labels + outputs["class_labels"] = class_labels + + if patch_offsets: + outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] + + return outputs + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_split_image: bool, + do_pad: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ): + """Preprocesses the input images and masks if provided.""" + processed_images, patch_offsets = [], [] + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for batched resizing, Needed in case do_resize is False. 
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + original_indices = [ + original_idx for original_idx, (img_shape, _) in grouped_images_index.items() if img_shape == shape + ] + + if do_split_image: + patches, offsets = self._split_image(stacked_images, size, original_indices) + processed_images.extend(patches) + patch_offsets.extend(offsets) + + if do_pad: + stacked_images = self._pad(stacked_images, size) + processed_images_grouped[shape] = stacked_images + + if do_split_image: + images, patch_offsets = reorder_patches_and_offsets(processed_images, patch_offsets) + + if do_pad: + images = reorder_images(processed_images_grouped, grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + images = reorder_images(processed_images_grouped, grouped_images_index) + + processed_images = torch.stack(images, dim=0) if return_tensors else images + + return processed_images, patch_offsets + + def merge_image_patches( + self, + segmentation_logits: torch.Tensor, + patch_offsets: list[tuple[int, int, int]], + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[torch.Tensor]: + """ + Reconstructs full-size semantic segmentation logits from patch predictions. + + Args: + segmentation_logits (`torch.Tensor`): + A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits + for each image patch. + patch_offsets (`list[tuple[int, int, int]]`): + A list of tuples where each tuple contains: + - `image_index` (int): Index of the original image this patch belongs to. + - `start` (int): Start pixel index of the patch along the long dimension (height or width). + - `end` (int): End pixel index of the patch along the long dimension. + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): + A size dict which was used to resize. 
+ """ + num_classes = segmentation_logits.shape[1] + aggregated_logits = [] + patch_counts = [] + + for image_size in target_sizes: + height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) + patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) + + # Stitch patches back into full-sized logit maps + for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: + aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, patch_start:patch_end, :] += 1 + else: + aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, :, patch_start:patch_end] += 1 + + # Normalize and resize logits to original image size + reconstructed_logits = [] + for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)): + averaged_logits = logit_sum / count.clamp(min=1) + resized_logits = torch.nn.functional.interpolate( + averaged_logits[None, ...], + size=target_sizes[idx], + mode="bilinear", + align_corners=False, + )[0] + + reconstructed_logits.append(resized_logits) + + return reconstructed_logits + + def unpad_image( + self, + segmentation_logits: torch.Tensor, + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[torch.Tensor]: + """Restores panoptic segmentation logits to their original image resolutions.""" + + resized_logits = [] + + for idx, original_size in enumerate(target_sizes): + target_height, target_width = get_size_with_aspect_ratio( + original_size, size["shortest_edge"], size["longest_edge"] + ) + cropped_logits = segmentation_logits[idx][:, :target_height, :target_width] + upsampled_logits = torch.nn.functional.interpolate( + cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False + )[0] + resized_logits.append(upsampled_logits) + return resized_logits + + def post_process_semantic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + size: Optional[dict[str, int]] = None, + ) -> np.ndarray: + """Post-processes model outputs into final semantic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets + + output_size = get_target_size(size) + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) + + preds = [logit.argmax(dim=0) for logit in output_logits] + return preds + + def post_process_panoptic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + stuff_classes: Optional[list[int]] = None, + size: 
Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into final panoptic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + output_size = get_target_size(size) + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) + + results: list = [] + + for i in range(batch_size): + mask_probs, pred_scores, pred_labels = remove_low_and_no_objects( + mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels + ) + + # No mask found + if mask_probs.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + segmentation, segments = compute_segments( + mask_probs=mask_probs, + pred_scores=pred_scores, + pred_labels=pred_labels, + stuff_classes=stuff_classes, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + target_size=target_sizes[i] if target_sizes is not None else None, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + @filter_out_non_signature_kwargs() + def post_process_instance_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into Instance Segmentation Predictions.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits + class_queries_logits = outputs.class_queries_logits + + output_size = get_target_size(size) + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + + device = masks_queries_logits.device + batch_size = class_queries_logits.shape[0] + num_queries = class_queries_logits.shape[-2] + + results = [] + + for i in range(batch_size): + mask_pred = mask_probs_batch[i] + mask_class = class_queries_logits[i] + + # Remove the null class `[..., :-1]` + scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1) + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores * mask_scores + + segmentation = torch.zeros(target_sizes[i], device=device) - 1 + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + results.append({"segmentation": segmentation, 
"segments_info": segments}) + return results + + +__all__ = ["EomtImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/modeling_eomt.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/modeling_eomt.py new file mode 100644 index 0000000000000000000000000000000000000000..2143716a0b2f5ac45b840f3b8a9f71eb241293af --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/modeling_eomt.py @@ -0,0 +1,1226 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_eomt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +import math +from dataclasses import dataclass +from typing import Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...file_utils import ModelOutput, is_scipy_available, requires_backends +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available +from ...utils.generic import check_model_inputs +from .configuration_eomt import EomtConfig + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`EomtForUniversalSegmentationOutput`]. + + This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or + [`~EomtImageProcessor.post_process_instance_segmentation`] or + [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see + [`~EomtImageProcessor] for details regarding usage. + """ +) +class EomtForUniversalSegmentationOutput(ModelOutput): + r""" + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last layer. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self and Cross Attentions weights from transformer decoder. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segmentation. + """ + + loss: Optional[torch.FloatTensor] = None + class_queries_logits: Optional[torch.FloatTensor] = None + masks_queries_logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + patch_offsets: Optional[list[torch.Tensor]] = None + + +# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +def sample_point( + input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs +) -> torch.Tensor: + """ + A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. + + Args: + input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): + A tensor that contains features map on a height * width grid + point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: + 2)): + A tensor that contains [0, 1] * [0, 1] normalized point coordinates + add_dim (`bool`): + boolean value to keep track of added dimension + + Returns: + point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, + height_grid, width_grid): + A tensor that contains features for points in `point_coordinates`. + """ + if point_coordinates.dim() == 3: + add_dim = True + point_coordinates = point_coordinates.unsqueeze(2) + + # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation + point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) + if add_dim: + point_features = point_features.squeeze(3) + + return point_features + + +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. 
+ """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.matmul(inputs, labels.T) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + r""" + A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss between each pairs. + """ + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T) + loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) + loss = loss_pos + loss_neg + return loss + + +# Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/matcher.py +class EomtHungarianMatcher(nn.Module): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + """ + + def __init__( + self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 + ): + """Creates the matcher + + Params: + cost_class (`float`, *optional*, defaults to 1.0): + Relative weight of the classification error in the matching cost. + cost_mask (`float`, *optional*, defaults to 1.0): + This is the relative weight of the focal loss of the binary mask in the matching cost. + cost_dice (`float`, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost. + num_points (`int`, *optional*, defaults to 12544): + No. of points to sample on which the mask loss will be calculated. The same set of K points are + uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite + matching. + """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs can't be 0") + + self.num_points = num_points + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + @torch.no_grad() + def forward( + self, + masks_queries_logits: torch.Tensor, + class_queries_logits: torch.Tensor, + mask_labels: torch.Tensor, + class_labels: torch.Tensor, + ) -> list[tuple[Tensor]]: + """ + Params: + masks_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. + class_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. + class_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the + target) containing the class labels. 
+ mask_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes, height, width` containing the target masks. + + Returns: + matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) + where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes). + """ + indices: list[tuple[np.array]] = [] + + # iterate through batch size + batch_size = masks_queries_logits.shape[0] + for i in range(batch_size): + pred_probs = class_queries_logits[i].softmax(-1) + pred_mask = masks_queries_logits[i] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted. + cost_class = -pred_probs[:, class_labels[i]] + target_mask = mask_labels[i].to(pred_mask) + target_mask = target_mask[:, None] + pred_mask = pred_mask[:, None] + + # Sample ground truth and predicted masks + point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device) + + target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1) + target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1) + + pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1) + pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1) + + # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels) + cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) + # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels) + cost_dice = pair_wise_dice_loss(pred_mask, target_mask) + # final cost matrix + cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible`` + cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10)) + cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10)) + cost_matrix = torch.nan_to_num(cost_matrix, 0) + # do the assignment using the hungarian algorithm in scipy + assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu()) + indices.append(assigned_indices) + + # It could be stacked in one tensor + matched_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices + ] + return matched_indices + + +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ + + In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss. 
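To make the matching step above concrete, here is a toy run of the same Hungarian assignment: three per-pair cost terms are combined with their weights and handed to `scipy.optimize.linear_sum_assignment`, as `EomtHungarianMatcher.forward` does per image. All numbers are invented for illustration.

```python
import torch
from scipy.optimize import linear_sum_assignment

# Toy costs for 4 predicted queries and 2 ground-truth masks (lower is better).
cost_class = torch.tensor([[0.9, 0.2], [0.1, 0.8], [0.5, 0.5], [0.7, 0.3]])
cost_mask = torch.tensor([[0.8, 0.1], [0.2, 0.9], [0.6, 0.4], [0.3, 0.7]])
cost_dice = torch.tensor([[0.7, 0.2], [0.3, 0.8], [0.5, 0.5], [0.4, 0.6]])

# Weighted sum, mirroring cost_mask * ... + cost_class * ... + cost_dice * ... (weights all 1.0 here).
cost_matrix = 1.0 * cost_mask + 1.0 * cost_class + 1.0 * cost_dice

query_idx, target_idx = linear_sum_assignment(cost_matrix.numpy())
# Two queries get matched to the two targets, the remaining queries stay unmatched ("no object").
print(list(zip(query_idx.tolist(), target_idx.tolist())))   # e.g. [(0, 1), (1, 0)]
```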
+ """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor: + r""" + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss. + """ + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss = criterion(inputs, labels) + + loss = cross_entropy_loss.mean(1).sum() / num_masks + return loss + + +# Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/criterion.py +class EomtLoss(nn.Module): + def __init__(self, config: EomtConfig, weight_dict: dict[str, float]): + """ + The Eomt Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we + compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair + of matched ground-truth / prediction (supervise class and mask) + + Args: + config (`EomtConfig`): + The configuration for Eomt model also containing loss calculation specific parameters. + weight_dict (`dict[str, float]`): + A dictionary of weights to be applied to the different losses. + """ + super().__init__() + requires_backends(self, ["scipy"]) + self.num_labels = config.num_labels + self.weight_dict = weight_dict + + # Weight to apply to the null class + self.eos_coef = config.no_object_weight + empty_weight = torch.ones(self.num_labels + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + # pointwise mask loss parameters + self.num_points = config.train_num_points + self.oversample_ratio = config.oversample_ratio + self.importance_sample_ratio = config.importance_sample_ratio + + self.matcher = EomtHungarianMatcher( + cost_class=config.class_weight, + cost_dice=config.dice_weight, + cost_mask=config.mask_weight, + num_points=self.num_points, + ) + + def _max_by_axis(self, sizes: list[list[int]]) -> list[int]: + maxes = sizes[0] + for sublist in sizes[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + # Adapted from nested_tensor_from_tensor_list() in original implementation + def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + # compute final size + batch_shape = [len(tensors)] + max_size + batch_size, _, height, width = batch_shape + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array] + ) -> 
dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`list[torch.Tensor]`): + List of class labels of shape `(labels)`. + indices (`tuple[np.array])`: + The indices computed by the Hungarian matcher. + + Returns: + `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + """ + pred_logits = class_queries_logits + batch_size, num_queries, _ = pred_logits.shape + criterion = nn.CrossEntropyLoss(weight=self.empty_weight) + idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) + target_classes_o = torch.cat( + [target[j] for target, (_, j) in zip(class_labels, indices)] + ) # shape of (batch_size, num_queries) + target_classes = torch.full( + (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device + ) + target_classes[idx] = target_classes_o + # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) + pred_logits_transposed = pred_logits.transpose(1, 2) + loss_ce = criterion(pred_logits_transposed, target_classes) + losses = {"loss_cross_entropy": loss_ce} + return losses + + def loss_masks( + self, + masks_queries_logits: torch.Tensor, + mask_labels: list[torch.Tensor], + indices: tuple[np.array], + num_masks: int, + ) -> dict[str, torch.Tensor]: + """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + indices (`tuple[np.array])`: + The indices computed by the Hungarian matcher. + num_masks (`int)`: + The number of masks, used for normalization. + + Returns: + losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: + - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth. + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth, + masks. 
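The trickiest part of `loss_labels` above is the index bookkeeping: matched queries receive their ground-truth class, and every other query falls back to the extra "no object" class. A small hand-worked sketch with invented values:

```python
import torch

num_labels = 3                      # real classes; index 3 acts as the "no object" class
batch_size, num_queries = 2, 5

# Matcher output: for each image, (query indices, target indices).
indices = [
    (torch.tensor([0, 3]), torch.tensor([1, 0])),   # image 0 has 2 ground-truth objects
    (torch.tensor([2]), torch.tensor([0])),         # image 1 has 1 ground-truth object
]
class_labels = [torch.tensor([2, 1]), torch.tensor([0])]

# Same permutation trick as _get_predictions_permutation_indices.
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
query_idx = torch.cat([src for (src, _) in indices])

target_classes_o = torch.cat([labels[j] for labels, (_, j) in zip(class_labels, indices)])

target_classes = torch.full((batch_size, num_queries), fill_value=num_labels, dtype=torch.int64)
target_classes[batch_idx, query_idx] = target_classes_o
print(target_classes)
# tensor([[1, 3, 3, 2, 3],
#         [3, 3, 0, 3, 3]])
```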
+ """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] + + # No need to upsample predictions as we are using normalized coordinates + pred_masks = pred_masks[:, None] + target_masks = target_masks[:, None] + + # Sample point coordinates + with torch.no_grad(): + point_coordinates = self.sample_points_using_uncertainty( + pred_masks, + lambda logits: self.calculate_uncertainty(logits), + self.num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + + point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) + + point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) + + losses = { + "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), + "loss_dice": dice_loss(point_logits, point_labels, num_masks), + } + + del pred_masks + del target_masks + return losses + + def _get_predictions_permutation_indices(self, indices): + # Permute predictions following indices + batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = torch.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # Permute labels following indices + batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = torch.cat([tgt for (_, tgt) in indices]) + return batch_indices, target_indices + + def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: + """ + In Eomt paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' + for the foreground class in `classes`. + + Args: + logits (`torch.Tensor`): + A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: + the number of foreground classes. The values are logits. + + Returns: + scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most + uncertain locations having the highest uncertainty score. + """ + uncertainty_scores = -(torch.abs(logits)) + return uncertainty_scores + + def sample_points_using_uncertainty( + self, + logits: torch.Tensor, + uncertainty_function, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, + ) -> torch.Tensor: + """ + This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The + uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit + prediction as input. + + Args: + logits (`float`): + Logit predictions for P points. + uncertainty_function: + A function that takes logit predictions for P points and returns their uncertainties. + num_points (`int`): + The number of points P to sample. + oversample_ratio (`int`): + Oversampling parameter. + importance_sample_ratio (`float`): + Ratio of points that are sampled via importance sampling. + + Returns: + point_coordinates (`torch.Tensor`): + Coordinates for P sampled points. 
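The sampling procedure documented above (and implemented right below) oversamples random points, keeps the most uncertain ones per mask, and tops the set up with fresh random points. A self-contained sketch of the same steps, with stand-in point logits instead of calling `sample_point`:

```python
import torch

torch.manual_seed(0)
num_masks, num_points, oversample_ratio, importance_ratio = 4, 16, 3, 0.75
num_sampled = num_points * oversample_ratio                     # 48 oversampled points per mask

coords = torch.rand(num_masks, num_sampled, 2)                  # random normalized coordinates
point_logits = torch.randn(num_masks, 1, num_sampled)           # stand-in for sample_point(pred_masks, coords)

uncertainty = -point_logits.abs()                               # highest where the logit is near 0 (decision boundary)

num_uncertain = int(importance_ratio * num_points)              # 12 importance-sampled points
num_random = num_points - num_uncertain                         # 4 fresh random points

idx = torch.topk(uncertainty[:, 0, :], k=num_uncertain, dim=1).indices
idx = idx + num_sampled * torch.arange(num_masks).unsqueeze(1)  # offset into the flattened coords
coords = coords.view(-1, 2)[idx.view(-1)].view(num_masks, num_uncertain, 2)
coords = torch.cat([coords, torch.rand(num_masks, num_random, 2)], dim=1)

print(coords.shape)   # torch.Size([4, 16, 2])
```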
+ """ + + num_boxes = logits.shape[0] + num_points_sampled = int(num_points * oversample_ratio) + + # Get random point coordinates + point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) + # Get sampled prediction value for the point coordinates + point_logits = sample_point(logits, point_coordinates, align_corners=False) + # Calculate the uncertainties based on the sampled prediction values of the points + point_uncertainties = uncertainty_function(point_logits) + + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) + idx += shift[:, None] + point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + + if num_random_points > 0: + point_coordinates = torch.cat( + [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], + dim=1, + ) + return point_coordinates + + def forward( + self, + masks_queries_logits: torch.Tensor, + class_queries_logits: torch.Tensor, + mask_labels: list[torch.Tensor], + class_labels: list[torch.Tensor], + auxiliary_predictions: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. + class_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, num_labels)`. + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + class_labels (`list[torch.Tensor]`): + List of class labels of shape `(labels)`. + auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], then it contains the logits from + the inner layers of the EomtMaskedAttentionDecoder. + + Returns: + losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], the dictionary contains additional + losses for each auxiliary predictions. + """ + + # retrieve the matching between the outputs of the last layer and the labels + indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + # compute the average number of target masks for normalization purposes + num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) + # get all the losses + losses: dict[str, Tensor] = { + **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), + **self.loss_labels(class_queries_logits, class_labels, indices), + } + # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
+ if auxiliary_predictions is not None: + for idx, aux_outputs in enumerate(auxiliary_predictions): + masks_queries_logits = aux_outputs["masks_queries_logits"] + class_queries_logits = aux_outputs["class_queries_logits"] + loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} + losses.update(loss_dict) + + return losses + + def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + Computes the average number of target masks across the batch, for normalization purposes. + """ + num_masks = sum(len(classes) for classes in class_labels) + num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device) + world_size = 1 + if is_accelerate_available(): + if PartialState._shared_state != {}: + num_masks = reduce(num_masks) + world_size = PartialState().num_processes + + num_masks = torch.clamp(num_masks / world_size, min=1) + return num_masks + + +class EomtPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class EomtEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. 
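As a quick shape check for `EomtPatchEmbeddings` above (the values mirror the `EomtConfig` defaults of a 640-pixel image, 16-pixel patches and hidden size 1024, but any consistent sizes work):

```python
import torch
from torch import nn

image_size, patch_size, num_channels, hidden_size = 640, 16, 3, 1024
num_patches = (image_size // patch_size) ** 2          # 40 * 40 = 1600 patch tokens

projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
pixel_values = torch.rand(1, num_channels, image_size, image_size)

embeddings = projection(pixel_values).flatten(2).transpose(1, 2)
print(embeddings.shape)   # torch.Size([1, 1600, 1024]) == (batch, num_patches, hidden_size)

# The embeddings module below then prepends 1 [CLS] token and num_register_tokens register tokens,
# so with the default of 4 register tokens the encoder sequence length becomes 1605.
```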
+ """ + + def __init__(self, config: EomtConfig) -> None: + super().__init__() + + self.config = config + self.patch_size = config.patch_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + + self.patch_embeddings = EomtPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS] + self.position_embeddings = nn.Embedding(num_patches, config.hidden_size) + self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, _, _, _ = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + + embeddings = embeddings + self.position_embeddings(self.position_ids) + embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1) + + embeddings = self.dropout(embeddings) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EomtAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class EomtLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
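The `drop_path` function whose docstring ends above (its body follows below) implements stochastic depth: each *sample* in the batch is either dropped entirely or rescaled by 1/keep_prob so that the expected value is preserved. A small numerical illustration:

```python
import torch

torch.manual_seed(0)
drop_prob, keep_prob = 0.25, 0.75

x = torch.ones(8, 4)                                        # 8 samples in the batch
shape = (x.shape[0],) + (1,) * (x.ndim - 1)                 # one broadcastable mask value per sample: (8, 1)
random_tensor = (keep_prob + torch.rand(shape)).floor_()    # 1 with probability keep_prob, else 0

out = x.div(keep_prob) * random_tensor
print(out[:, 0])          # each row is either 0.0 (dropped) or 1.3333 (kept and rescaled by 1/0.75)
print(out.mean().item())  # roughly 1.0, matching the expectation of the input
```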
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class EomtDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class EomtMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class EomtSwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class EomtLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: EomtConfig) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = EomtAttention(config) + self.layer_scale1 = EomtLayerScale(config) + self.drop_path = EomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = EomtSwiGLUFFN(config) + else: + self.mlp = EomtMLP(config) + self.layer_scale2 = EomtLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states_norm = self.norm1(hidden_states) + self_attention_output, _ = self.attention(hidden_states_norm, head_mask) + self_attention_output = self.layer_scale1(self_attention_output) + + # first residual connection + hidden_states = self.drop_path(self_attention_output) + hidden_states + + # in Eomt, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + 
hidden_states + + return layer_output + + +class EomtLayerNorm2d(nn.LayerNorm): + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = hidden_state.permute(0, 2, 3, 1) + hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps) + hidden_state = hidden_state.permute(0, 3, 1, 2) + return hidden_state + + +class EomtScaleLayer(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + hidden_size = config.hidden_size + self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2) + self.activation = ACT2FN[config.hidden_act] + self.conv2 = nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=3, + padding=1, + groups=hidden_size, + bias=False, + ) + + self.layernorm2d = EomtLayerNorm2d(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.layernorm2d(hidden_states) + return hidden_states + + +class EomtScaleBlock(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + self.num_blocks = config.num_upscale_blocks + self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for block in self.block: + hidden_states = block(hidden_states) + return hidden_states + + +class EomtMaskHead(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + + hidden_size = config.hidden_size + self.fc1 = nn.Linear(hidden_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, hidden_size) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation(self.fc1(hidden_states)) + hidden_states = self.activation(self.fc2(hidden_states)) + hidden_states = self.fc3(hidden_states) + return hidden_states + + +@auto_docstring +class EomtPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
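Returning to `EomtScaleBlock` defined a little further up: each `EomtScaleLayer` starts with a stride-2 transposed convolution, so with the default two upscale blocks the 40x40 patch grid (640/16) grows to 160x160 before mask prediction. A rough shape walk-through that only tracks the spatial size (the real layer also applies an activation, a depthwise 3x3 convolution and `EomtLayerNorm2d`):

```python
import torch
from torch import nn

hidden_size, grid = 64, 40   # small hidden size keeps the example light; the real default is 1024

upscale = nn.Sequential(
    nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),   # 40x40 -> 80x80
    nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),   # 80x80 -> 160x160
)

features = torch.rand(1, hidden_size, grid, grid)
print(upscale(features).shape)   # torch.Size([1, 64, 160, 160])
```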
+ """ + + config: EomtConfig + base_model_prefix = "eomt" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _no_split_modules = ["EomtLayer"] + _supports_sdpa = True + _can_record_outputs = { + "hidden_states": EomtLayer, + "attentions": EomtAttention, + } + + def _init_weights(self, module: nn.Module) -> None: + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, EomtLayerScale): + if hasattr(module, "lambda1"): + module.lambda1.data.fill_(self.config.layerscale_value) + elif isinstance(module, EomtEmbeddings): + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), mean=0.0, std=std + ).to(module.cls_token.dtype) + module.register_tokens.data.zero_() + + +@auto_docstring( + custom_intro=""" + The EoMT Model with head on top for instance/semantic/panoptic segmentation. + """ +) +class EomtForUniversalSegmentation(EomtPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: EomtConfig): + super().__init__(config) + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.embeddings = EomtEmbeddings(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.query = nn.Embedding(config.num_queries, config.hidden_size) + self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)]) + + self.upscale_block = EomtScaleBlock(config) + self.mask_head = EomtMaskHead(config) + + self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1) + + self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.weight_dict: dict[str, float] = { + "loss_cross_entropy": config.class_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict) + + self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks)) + + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + auxiliary_predictions: dict[str, Tensor], + ) -> dict[str, Tensor]: + loss_dict: dict[str, Tensor] = self.criterion( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=auxiliary_predictions, + ) + + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict + + def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + @check_model_inputs() + @auto_docstring + def forward( + self, + pixel_values: Tensor, + mask_labels: Optional[list[Tensor]] = None, + class_labels: Optional[list[Tensor]] = None, + 
patch_offsets: Optional[list[Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> EomtForUniversalSegmentationOutput: + r""" + mask_labels (`list[torch.Tensor]`, *optional*): + list of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): + list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segmentation. + """ + + masks_queries_logits_per_layer, class_queries_logits_per_layer = (), () + attention_mask = None + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + for idx, layer_module in enumerate(self.layers): + if idx == self.num_hidden_layers - self.config.num_blocks: + query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device) + hidden_states = torch.cat((query, hidden_states), dim=1) + + if idx >= self.num_hidden_layers - self.config.num_blocks and ( + self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0 + ): + norm_hidden_states = self.layernorm(hidden_states) + masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states) + + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + attention_mask = torch.ones( + hidden_states.shape[0], + hidden_states.shape[1], + hidden_states.shape[1], + device=hidden_states.device, + dtype=torch.bool, + ) + + interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear") + interpolated_logits = interpolated_logits.view( + interpolated_logits.size(0), interpolated_logits.size(1), -1 + ) + + num_query_tokens = self.config.num_queries + encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens + + # Set attention mask for queries to focus on encoder tokens based on interpolated logits + attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0 + + # Disable attention mask for random query tokens. + attention_mask = self._disable_attention_mask( + attention_mask, + prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks], + num_query_tokens=num_query_tokens, + encoder_start_tokens=encoder_start_tokens, + device=attention_mask.device, + ) + + # Expand attention mask to 4d mask. 
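The block above is the heart of EoMT's masked attention: for the last `num_blocks` layers, each query may only attend to encoder (patch) tokens whose intermediate mask logit for that query is positive, while prefix tokens and query-to-query attention stay unrestricted; the boolean mask is then expanded to a 4D additive mask just below. A toy reconstruction with made-up sizes:

```python
import torch
import torch.nn.functional as F

batch, num_queries, num_prefix, grid = 2, 3, 5, 4           # toy sizes
seq_len = num_queries + num_prefix + grid * grid            # queries + [CLS]/register + patch tokens

masks_queries_logits = torch.randn(batch, num_queries, 13, 13)   # intermediate mask predictions

attention_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)

# Resize the mask logits to the patch grid and flatten to one value per patch token.
interpolated = F.interpolate(masks_queries_logits, size=(grid, grid), mode="bilinear")
interpolated = interpolated.view(batch, num_queries, -1)

encoder_start = num_queries + num_prefix
attention_mask[:, :num_queries, encoder_start:] = interpolated > 0   # queries see only "their" patches

print(attention_mask.shape)                               # torch.Size([2, 24, 24])
print(attention_mask[0, 0, encoder_start:].sum().item())  # number of patch tokens query 0 may attend to
```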
+ attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1) + attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9) + + hidden_states = layer_module(hidden_states, attention_mask) + + sequence_output = self.layernorm(hidden_states) + + masks_queries_logits, class_queries_logits = self.predict(sequence_output) + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + loss = None + if mask_labels is not None and class_labels is not None: + loss = 0.0 + for masks_queries_logits, class_queries_logits in zip( + masks_queries_logits_per_layer, class_queries_logits_per_layer + ): + loss_dict = self.get_loss_dict( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=None, + ) + loss += self.get_loss(loss_dict) + + return EomtForUniversalSegmentationOutput( + loss=loss, + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + last_hidden_state=sequence_output, + patch_offsets=patch_offsets, + ) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def predict(self, logits: torch.Tensor): + query_tokens = logits[:, : self.config.num_queries, :] + class_logits = self.class_predictor(query_tokens) + + prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :] + prefix_tokens = prefix_tokens.transpose(1, 2) + + prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size) + + query_tokens = self.mask_head(query_tokens) + prefix_tokens = self.upscale_block(prefix_tokens) + + mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens) + + return mask_logits, class_logits + + @staticmethod + def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device): + if prob < 1: + # Generate random queries to disable based on the probs + random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob + + # Disable attention to the query tokens, considering the prefix tokens + attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1 + + return attn_mask + + +__all__ = ["EomtPreTrainedModel", "EomtForUniversalSegmentation"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/eomt/modular_eomt.py b/venv/lib/python3.13/site-packages/transformers/models/eomt/modular_eomt.py new file mode 100644 index 0000000000000000000000000000000000000000..c8942c4d23a4abbd4e99770dfa4de243a11b8a93 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/eomt/modular_eomt.py @@ -0,0 +1,601 @@ +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
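To close out the modeling file above: `predict` splits the token sequence into query tokens (class logits via a linear head) and patch tokens (reshaped back to a 2D grid), then combines the two with an einsum to produce one mask logit map per query. A toy version with hypothetical sizes, where plain linear layers stand in for `EomtMaskHead` and the upscale block (so the masks stay at the 40x40 patch resolution instead of being upsampled to 160x160):

```python
import torch
from torch import nn

batch, num_queries, num_prefix, hidden, grid, num_labels = 2, 200, 5, 32, 40, 10

sequence_output = torch.randn(batch, num_queries + num_prefix + grid * grid, hidden)

class_predictor = nn.Linear(hidden, num_labels + 1)
mask_head = nn.Linear(hidden, hidden)                        # stand-in for the 3-layer EomtMaskHead

query_tokens = sequence_output[:, :num_queries, :]
class_logits = class_predictor(query_tokens)                 # (batch, num_queries, num_labels + 1)

patch_tokens = sequence_output[:, num_queries + num_prefix :, :]
pixel_features = patch_tokens.transpose(1, 2).reshape(batch, hidden, grid, grid)

mask_logits = torch.einsum("bqc, bchw -> bqhw", mask_head(query_tokens), pixel_features)
print(class_logits.shape, mask_logits.shape)                 # (2, 200, 11) (2, 200, 40, 40)
```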
+"""PyTorch EoMT model.""" + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + logging, +) +from ...utils.generic import check_model_inputs +from ..dinov2.modeling_dinov2 import ( + Dinov2Embeddings, + Dinov2Layer, + Dinov2LayerScale, + Dinov2PatchEmbeddings, +) +from ..mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentation, Mask2FormerLoss +from ..siglip.modeling_siglip import SiglipAttention +from ..vit.configuration_vit import ViTConfig + + +logger = logging.get_logger(__name__) + + +class EomtConfig(ViTConfig): + r""" + This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the EoMT + [tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in each attention layer. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the MLP hidden dimensionality to the hidden size. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 640): + The size (resolution) of each input image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value for the LayerScale parameter. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The stochastic depth rate (drop path) used during training. + num_upscale_blocks (`int`, *optional*, defaults to 2): + Number of upsampling blocks used in the decoder or segmentation head. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability applied after attention projection. + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_blocks (`int`, *optional*, defaults to 4): + Number of feature blocks or stages in the architecture. 
+ no_object_weight (`float`, *optional*, defaults to 0.1): + Loss weight for the 'no object' class in panoptic/instance segmentation. + class_weight (`float`, *optional*, defaults to 2.0): + Loss weight for classification targets. + mask_weight (`float`, *optional*, defaults to 5.0): + Loss weight for mask prediction. + dice_weight (`float`, *optional*, defaults to 5.0): + Loss weight for the dice loss component. + train_num_points (`int`, *optional*, defaults to 12544): + Number of points to sample for mask loss computation during training. + oversample_ratio (`float`, *optional*, defaults to 3.0): + Oversampling ratio used in point sampling for mask training. + importance_sample_ratio (`float`, *optional*, defaults to 0.75): + Ratio of points to sample based on importance during training. + num_queries (`int`, *optional*, defaults to 200): + Number of object queries in the Transformer. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of learnable register tokens added to the transformer input. + + Example: + + ```python + >>> from transformers import EomtConfig, EomtForUniversalSegmentation + + >>> # Initialize configuration + >>> config = EomtConfig() + + >>> # Initialize model + >>> model = EomtForUniversalSegmentation(config) + + >>> # Access config + >>> config = model.config + ```""" + + model_type = "eomt" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=640, + patch_size=16, + num_channels=3, + layerscale_value=1.0, + drop_path_rate=0.0, + num_upscale_blocks=2, + attention_dropout=0.0, + use_swiglu_ffn=False, + num_blocks=4, + no_object_weight: float = 0.1, + class_weight: float = 2.0, + mask_weight: float = 5.0, + dice_weight: float = 5.0, + train_num_points: int = 12544, + oversample_ratio: float = 3.0, + importance_sample_ratio: float = 0.75, + num_queries=200, + num_register_tokens=4, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + initializer_range=initializer_range, + layer_norm_eps=layer_norm_eps, + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + **kwargs, + ) + + del self.intermediate_size + del self.qkv_bias + del self.pooler_act + del self.pooler_output_size + del self.encoder_stride + del self.attention_probs_dropout_prob + + self.mlp_ratio = mlp_ratio + self.attention_dropout = attention_dropout + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.num_upscale_blocks = num_upscale_blocks + self.use_swiglu_ffn = use_swiglu_ffn + self.num_blocks = num_blocks + self.no_object_weight = no_object_weight + self.class_weight = class_weight + self.mask_weight = mask_weight + self.dice_weight = dice_weight + self.train_num_points = train_num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.num_queries = num_queries + self.num_register_tokens = num_register_tokens + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`EomtForUniversalSegmentationOutput`]. 
+
+    This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
+    [`~EomtImageProcessor.post_process_instance_segmentation`] or
+    [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see
+    [`~EomtImageProcessor`] for details regarding usage.
+    """
+)
+class EomtForUniversalSegmentationOutput(ModelOutput):
+    r"""
+    loss (`torch.Tensor`, *optional*):
+        The computed loss, returned when labels are present.
+    class_queries_logits (`torch.FloatTensor`):
+        A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
+        query. Note the `+ 1` is needed because we incorporate the null class.
+    masks_queries_logits (`torch.FloatTensor`):
+        A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
+        query.
+    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        Last hidden states (final feature map) of the last layer.
+    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+        shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of all layers of the model.
+    attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+        Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+        sequence_length)`. Self- and cross-attention weights from the transformer decoder.
+    patch_offsets (`list[torch.Tensor]`, *optional*):
+        List of tuples indicating the image index and the start and end positions of patches for semantic segmentation.
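As the intro above says, this output plugs straight into the image processor's post-processing helpers. A hedged usage sketch: the checkpoint id is the one quoted in the `EomtConfig` docstring, `AutoImageProcessor` is assumed to resolve to the EoMT processor, and the exact post-processing keyword arguments (for example `target_sizes`) may differ from what is shown here:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, EomtForUniversalSegmentation

checkpoint = "tue-mps/coco_panoptic_eomt_large_640"   # from the EomtConfig docstring
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = EomtForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")                     # hypothetical local image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Keyword arguments here are an assumption; consult the processor documentation for the exact signature.
panoptic = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(panoptic["segmentation"].shape)
```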
+ """ + + loss: Optional[torch.FloatTensor] = None + class_queries_logits: Optional[torch.FloatTensor] = None + masks_queries_logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + patch_offsets: Optional[list[torch.Tensor]] = None + + +class EomtLoss(Mask2FormerLoss): + pass + + +class EomtPatchEmbeddings(Dinov2PatchEmbeddings): + pass + + +class EomtEmbeddings(Dinov2Embeddings): + def __init__(self, config: EomtConfig) -> None: + nn.Module.__init__(self) + + self.config = config + self.patch_size = config.patch_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + + self.patch_embeddings = EomtPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS] + self.position_embeddings = nn.Embedding(num_patches, config.hidden_size) + self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self): + raise AttributeError("Not needed for Eomt Model") + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, _, _, _ = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + + embeddings = embeddings + self.position_embeddings(self.position_ids) + embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class EomtAttention(SiglipAttention): + pass + + +class EomtLayerScale(Dinov2LayerScale): + pass + + +class EomtLayer(Dinov2Layer): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states_norm = self.norm1(hidden_states) + self_attention_output, _ = self.attention(hidden_states_norm, head_mask) + self_attention_output = self.layer_scale1(self_attention_output) + + # first residual connection + hidden_states = self.drop_path(self_attention_output) + hidden_states + + # in Eomt, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + return layer_output + + +class EomtLayerNorm2d(nn.LayerNorm): + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = hidden_state.permute(0, 2, 3, 1) + hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps) + hidden_state = hidden_state.permute(0, 3, 1, 2) + return hidden_state + + +class EomtScaleLayer(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + hidden_size = config.hidden_size + self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2) + self.activation = ACT2FN[config.hidden_act] + self.conv2 = 
nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=3, + padding=1, + groups=hidden_size, + bias=False, + ) + + self.layernorm2d = EomtLayerNorm2d(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.layernorm2d(hidden_states) + return hidden_states + + +class EomtScaleBlock(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + self.num_blocks = config.num_upscale_blocks + self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for block in self.block: + hidden_states = block(hidden_states) + return hidden_states + + +class EomtMaskHead(nn.Module): + def __init__(self, config: EomtConfig): + super().__init__() + + hidden_size = config.hidden_size + self.fc1 = nn.Linear(hidden_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, hidden_size) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation(self.fc1(hidden_states)) + hidden_states = self.activation(self.fc2(hidden_states)) + hidden_states = self.fc3(hidden_states) + return hidden_states + + +@auto_docstring +class EomtPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config: EomtConfig + base_model_prefix = "eomt" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _no_split_modules = ["EomtLayer"] + _supports_sdpa = True + _can_record_outputs = { + "hidden_states": EomtLayer, + "attentions": EomtAttention, + } + + def _init_weights(self, module: nn.Module) -> None: + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, EomtLayerScale): + if hasattr(module, "lambda1"): + module.lambda1.data.fill_(self.config.layerscale_value) + elif isinstance(module, EomtEmbeddings): + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), mean=0.0, std=std + ).to(module.cls_token.dtype) + module.register_tokens.data.zero_() + + +@auto_docstring( + custom_intro=""" + The EoMT Model with head on top for instance/semantic/panoptic segmentation. 
+ """ +) +class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation): + def __init__(self, config: EomtConfig): + PreTrainedModel.__init__(self, config) + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.embeddings = EomtEmbeddings(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.query = nn.Embedding(config.num_queries, config.hidden_size) + self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)]) + + self.upscale_block = EomtScaleBlock(config) + self.mask_head = EomtMaskHead(config) + + self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1) + + self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.weight_dict: dict[str, float] = { + "loss_cross_entropy": config.class_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict) + + self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks)) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def get_auxiliary_logits(self): + raise AttributeError("Note needed for Eomt Model.") + + def predict(self, logits: torch.Tensor): + query_tokens = logits[:, : self.config.num_queries, :] + class_logits = self.class_predictor(query_tokens) + + prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :] + prefix_tokens = prefix_tokens.transpose(1, 2) + + prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size) + + query_tokens = self.mask_head(query_tokens) + prefix_tokens = self.upscale_block(prefix_tokens) + + mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens) + + return mask_logits, class_logits + + @staticmethod + def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device): + if prob < 1: + # Generate random queries to disable based on the probs + random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob + + # Disable attention to the query tokens, considering the prefix tokens + attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1 + + return attn_mask + + @check_model_inputs() + @auto_docstring + def forward( + self, + pixel_values: Tensor, + mask_labels: Optional[list[Tensor]] = None, + class_labels: Optional[list[Tensor]] = None, + patch_offsets: Optional[list[Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> EomtForUniversalSegmentationOutput: + r""" + mask_labels (`list[torch.Tensor]`, *optional*): + list of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): + list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segmentation. 
+ """ + + masks_queries_logits_per_layer, class_queries_logits_per_layer = (), () + attention_mask = None + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + for idx, layer_module in enumerate(self.layers): + if idx == self.num_hidden_layers - self.config.num_blocks: + query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device) + hidden_states = torch.cat((query, hidden_states), dim=1) + + if idx >= self.num_hidden_layers - self.config.num_blocks and ( + self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0 + ): + norm_hidden_states = self.layernorm(hidden_states) + masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states) + + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + attention_mask = torch.ones( + hidden_states.shape[0], + hidden_states.shape[1], + hidden_states.shape[1], + device=hidden_states.device, + dtype=torch.bool, + ) + + interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear") + interpolated_logits = interpolated_logits.view( + interpolated_logits.size(0), interpolated_logits.size(1), -1 + ) + + num_query_tokens = self.config.num_queries + encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens + + # Set attention mask for queries to focus on encoder tokens based on interpolated logits + attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0 + + # Disable attention mask for random query tokens. + attention_mask = self._disable_attention_mask( + attention_mask, + prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks], + num_query_tokens=num_query_tokens, + encoder_start_tokens=encoder_start_tokens, + device=attention_mask.device, + ) + + # Expand attention mask to 4d mask. 
+ attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1) + attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9) + + hidden_states = layer_module(hidden_states, attention_mask) + + sequence_output = self.layernorm(hidden_states) + + masks_queries_logits, class_queries_logits = self.predict(sequence_output) + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + loss = None + if mask_labels is not None and class_labels is not None: + loss = 0.0 + for masks_queries_logits, class_queries_logits in zip( + masks_queries_logits_per_layer, class_queries_logits_per_layer + ): + loss_dict = self.get_loss_dict( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=None, + ) + loss += self.get_loss(loss_dict) + + return EomtForUniversalSegmentationOutput( + loss=loss, + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + last_hidden_state=sequence_output, + patch_offsets=patch_offsets, + ) + + +__all__ = ["EomtConfig", "EomtPreTrainedModel", "EomtForUniversalSegmentation"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/ernie/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb8983063ddb0117e8b0d7cd6603aa6ac3056b6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie import * + from .modeling_ernie import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie/configuration_ernie.py b/venv/lib/python3.13/site-packages/transformers/models/ernie/configuration_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..abf300f0ce51696dbdab756a6095689c8344ad08 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie/configuration_ernie.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ERNIE model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ErnieConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ErnieModel`] or a [`TFErnieModel`]. It is used to + instantiate a ERNIE model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ERNIE + [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the ERNIE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`]. + task_type_vocab_size (`int`, *optional*, defaults to 3): + The vocabulary size of the `task_type_ids` for ERNIE2.0/ERNIE3.0 model + use_task_id (`bool`, *optional*, defaults to `False`): + Whether or not the model support `task_type_ids` + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. 
+ pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import ErnieConfig, ErnieModel + + >>> # Initializing a ERNIE nghuyong/ernie-3.0-base-zh style configuration + >>> configuration = ErnieConfig() + + >>> # Initializing a model (with random weights) from the nghuyong/ernie-3.0-base-zh style configuration + >>> model = ErnieModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + task_type_vocab_size=3, + use_task_id=False, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.task_type_vocab_size = task_type_vocab_size + self.use_task_id = use_task_id + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class ErnieOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ("task_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["ErnieConfig", "ErnieOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie/modeling_ernie.py b/venv/lib/python3.13/site-packages/transformers/models/ernie/modeling_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..7cbce6b2d20b452acc9480c4cd28a4f5039e4d62 --- /dev/null +++ 
b/venv/lib/python3.13/site-packages/transformers/models/ernie/modeling_ernie.py @@ -0,0 +1,1689 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ERNIE model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_ernie import ErnieConfig + + +logger = logging.get_logger(__name__) + + +class ErnieEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.use_task_id = config.use_task_id + if config.use_task_id: + self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + task_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = 
input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + # add `task_type_id` for ERNIE model + if self.use_task_id: + if task_type_ids is None: + task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + task_type_embeddings = self.task_type_embeddings(task_type_ids) + embeddings += task_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Ernie +class ErnieSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = 
hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
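+        # query_layer and key_layer are (batch_size, num_attention_heads, seq_len, attention_head_size); the matmul
+        # below produces raw scores of shape (batch_size, num_attention_heads, query_len, key_len), which are scaled
+        # by 1/sqrt(attention_head_size) further down before the softmax.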
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ErnieModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Ernie +class ErnieSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ERNIE_SELF_ATTENTION_CLASSES = { + "eager": ErnieSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE +class ErnieAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = ErnieSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Ernie +class ErnieIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 
config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Ernie +class ErnieOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie +class ErnieLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ErnieAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ErnieAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) + self.intermediate = ErnieIntermediate(config) + self.output = ErnieOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + 
) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Ernie +class ErnieEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ErnieLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Ernie +class ErniePooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Ernie +class ErniePredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Ernie +class ErnieLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = ErniePredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Ernie +class ErnieOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ErnieLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->Ernie +class ErnieOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->Ernie +class ErniePreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ErnieLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@auto_docstring +class ErniePreTrainedModel(PreTrainedModel): + config: ErnieConfig + base_model_prefix = "ernie" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`ErnieForPreTraining`]. + """ +) +# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie +class ErnieForPreTrainingOutput(ModelOutput): + r""" + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). 
+ """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring( + custom_intro=""" + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + """ +) +class ErnieModel(ErniePreTrainedModel): + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = ErnieEmbeddings(config) + self.encoder = ErnieEncoder(config) + + self.pooler = ErniePooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. 
We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
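+        # get_extended_attention_mask turns the 2D padding mask into an additive float mask broadcastable to
+        # (batch_size, num_heads, seq_len, seq_len): kept positions become 0.0 and masked positions a large negative
+        # value (for decoder configurations a causal mask is folded in as well).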
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Ernie Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """ +) +class ErnieForPreTraining(ErniePreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + self.cls = ErniePreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], ErnieForPreTrainingOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, ErnieForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh") + >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return ErnieForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Ernie Model with a `language modeling` head on top for CLM fine-tuning. 
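A rough usage sketch for this head follows. It reuses the `nghuyong/ernie-1.0-base-zh` checkpoint already cited in this file's examples; the generation itself is illustrative only, since ERNIE checkpoints are trained as encoders and a causal-LM head needs fine-tuning before its generations mean much.

```python
# Illustrative sketch; `is_decoder=True` is required, as the warning in __init__ below notes.
from transformers import AutoTokenizer, ErnieForCausalLM

tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
model = ErnieForCausalLM.from_pretrained("nghuyong/ernie-1.0-base-zh", is_decoder=True)

inputs = tokenizer("你好", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```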
+ """ +) +class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `ErnieForCausalLM` as a standalone, add `is_decoder=True.`") + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.cls = ErnieOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class ErnieForMaskedLM(ErniePreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `ErnieForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.cls = ErnieOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def can_generate(cls) -> bool: + """ + Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. + """ + return False + + +@auto_docstring( + custom_intro=""" + Ernie Model with a `next sentence prediction (classification)` head on top. 
+ """ +) +class ErnieForNextSentencePrediction(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForNextSentencePrediction.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + self.cls = ErnieOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Example: + + ```python + >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh") + >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Ernie Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """ +) +class ErnieForSequenceClassification(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.ernie = ErnieModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class ErnieForMultipleChoice(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + task_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class ErnieForTokenClassification(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie = ErnieModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
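+ + Example (a minimal usage sketch; the checkpoint below is the one used elsewhere in this file and its token-classification head is newly initialized, so the outputs are only illustrative): + + ```python + >>> from transformers import AutoTokenizer, ErnieForTokenClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh") + >>> model = ErnieForTokenClassification.from_pretrained("nghuyong/ernie-1.0-base-zh") + + >>> inputs = tokenizer("Paris is the capital of France.", return_tensors="pt") + >>> logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels) + >>> predicted_label_ids = logits.argmax(-1) + ```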
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class ErnieForQuestionAnswering(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. 
We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ErnieForCausalLM", + "ErnieForMaskedLM", + "ErnieForMultipleChoice", + "ErnieForNextSentencePrediction", + "ErnieForPreTraining", + "ErnieForQuestionAnswering", + "ErnieForSequenceClassification", + "ErnieForTokenClassification", + "ErnieModel", + "ErniePreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb30318fa25ec560d549b08421ea08f02b0dce6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie4_5_moe import * + from .modeling_ernie4_5_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..294ccfc638cf9683f15f94c3d2faee2a5f19cd44 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py @@ -0,0 +1,254 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ernie 4.5 MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Ernie4_5_MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Ernie4_5_MoeModel`]. It is used to instantiate a + Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 103424): + Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Ernie4_5_MoeModel`] + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 500000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in any of the projections (e.g. the mlp and attention projections). + moe_intermediate_size (`int`, *optional*, defaults to 1536): + Intermediate size of the routed expert. + moe_k (`int`, *optional*, defaults to 6): + Number of selected experts. + moe_num_experts (`int`, *optional*, defaults to 64): + Number of routed experts. + moe_num_shared_experts (`int`, *optional*, defaults to 2): + The number of experts that are shared for all MoE forwards. + moe_layer_start_index (`int`, *optional*, defaults to 1): + The first index at which MoE layers start to appear. + moe_layer_end_index (`int`, *optional*, defaults to -1): + The last possible index for a MoE layer. + moe_layer_interval (`int`, *optional*, defaults to 1): + The interval at which MoE layers appear. + moe_norm_min (`float`, *optional*, defaults to 1e-12): + Minimum division value during routing normalization. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import Ernie4_5_MoeModel, Ernie4_5_MoeConfig + + >>> # Initializing an Ernie4_5_MoE style configuration + >>> configuration = Ernie4_5_MoeConfig() + + >>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration + >>> model = Ernie4_5_MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie4_5_moe" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_experts": "moe_num_experts", "num_experts_per_tok": "moe_k"} + + # Default tensor parallel plan for base model `Ernie4_5_MoE` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + # sequence parallel is pretty slow + # "norm.weight": "sequence_parallel", + # "layers.*.input_layernorm.weight": "sequence_parallel", + # "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "layers.*.mlp.shared_experts.gate_proj": "local_colwise", + "layers.*.mlp.shared_experts.up_proj": "local_colwise", + "layers.*.mlp.shared_experts.down_proj": "local_rowwise", + "layers.*.mlp.experts.*.gate_proj": "local_colwise", + "layers.*.mlp.experts.*.up_proj": "local_colwise", + "layers.*.mlp.experts.*.down_proj": "local_rowwise", + "layers.*.mlp.experts": "local", + "layers.*.mlp.gate_proj": "local_colwise", + "layers.*.mlp.up_proj": "local_colwise", + "layers.*.mlp.down_proj": "local_rowwise", + "layers.*.mlp": "gather", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=103424, + pad_token_id=0, +
bos_token_id=1, + eos_token_id=2, + hidden_size=2560, + intermediate_size=12288, + num_hidden_layers=28, + num_attention_heads=20, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=True, + rope_theta=500000.0, + rope_scaling=None, + use_bias=False, + moe_intermediate_size=1536, + moe_k=6, + moe_num_experts=64, + moe_num_shared_experts=2, + moe_layer_start_index=1, + moe_layer_end_index=-1, + moe_layer_interval=1, + moe_norm_min=1e-12, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_bias = use_bias + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.moe_intermediate_size = moe_intermediate_size + self.moe_k = moe_k + self.moe_num_experts = moe_num_experts + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index + self.moe_layer_interval = moe_layer_interval + self.moe_norm_min = moe_norm_min + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Ernie4_5_MoeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..3911a39901f4f04bdce4bb390e77c30b3778e1e1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -0,0 +1,749 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ernie4_5_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class Ernie4_5_MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Ernie4_5_MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Ernie4_5_MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Ernie4_5_MoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Ernie4_5_MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = 
self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + # keeping it in full precision + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # glm rope style (with full dim) and full precision + original_dtype = q.dtype + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) + k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) + + return q_embed.to(original_dtype), k_embed.to(original_dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Ernie4_5_MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.attention_dropout = 0.0 + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = 
eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Ernie4_5_MoeStatics(nn.Module): + """ + Stores MoE (Mixture of Experts) statistics + - Bias for the gating + - Additionally, usage per expert in the original codebase + """ + + def __init__(self, config): + super().__init__() + + num_experts_groups = 1 + num_experts = config.moe_num_experts + + self.e_score_correction_bias = nn.Parameter( + torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), + requires_grad=False, + ) + + def forward(self, hidden_states): + # NOTE: This is a workaround to enable TP with a module that only has parameters + # + # Otherwise, it stays as `DTensor` when called in the "super" forward + # 1. All other tensors are local (`torch.Tensor`) + # 2. Isolate does not work on `nn.Module` which only has parameters + return hidden_states + self.e_score_correction_bias.squeeze() + + +class Ernie4_5_MoeSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + + Ernie 4.5 MoE's original formula is based on case (2) with + (optional) shared experts and a corrections bias during gating. 
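+ + A minimal, self-contained sketch of the gating math used in `forward` below (toy sizes; the correction bias is zero here, as it is at initialization): + + ```python + >>> import torch + >>> import torch.nn.functional as F + + >>> # 5 tokens routed over 8 experts with top-2 selection + >>> routing_weights = F.softmax(torch.randn(5, 8), dim=1, dtype=torch.float) + >>> correction_bias = torch.zeros(8) + >>> _, selected_experts = torch.topk(routing_weights + correction_bias, 2, dim=-1) + >>> routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) + >>> routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=1e-12) + ```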
+ """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + + # correction bias (yes it seems to be a typo with statics <> statistics) + self.moe_statics = Ernie4_5_MoeStatics(config) + + # gating + self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) + self.experts = nn.ModuleList( + [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] + ) + self.norm_min = config.moe_norm_min + + # (optional) shared experts for all forwards + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # (Optional) shared experts + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states.float()) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) + routing_weights = routing_weights / torch.clamp( + routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min + ) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + # Add (optional) shared experts to the result + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5_MoeAttention(config, layer_idx) + + if ( + ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= config.moe_layer_start_index + and layer_idx <= config.moe_layer_end_index + ): + self.mlp = Ernie4_5_MoeSparseMoeBlock(config) + else: + self.mlp = Ernie4_5_MoeMLP(config) + + self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class Ernie4_5_MoePreTrainedModel(PreTrainedModel): + config: Ernie4_5_MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5_MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1), + "hidden_states": Ernie4_5_MoeDecoderLayer, + "attentions": Ernie4_5_MoeAttention, + } + _keep_in_fp32_modules_strict = ["gate", "moe_statics"] + # Not supporting multi-token prediction (MTP) atm + _keys_to_ignore_on_load_unexpected = ["mtp"] + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Ernie4_5_MoeStatics): + module.e_score_correction_bias.data.zero_() + + +@auto_docstring +class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel): + def __init__(self, config: Ernie4_5_MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + 
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
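+
+    Example (illustrative only; the layer count, token count and expert count are arbitrary):
+
+    ```python
+    >>> import torch
+
+    >>> # two MoE layers, 4 tokens each, 8 experts
+    >>> gate_logits = tuple(torch.randn(4, 8) for _ in range(2))
+    >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2)
+    >>> # scalar tensor; approaches `top_k` when the routing is perfectly balanced
+    ```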
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Ernie4_5_MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = config.moe_k + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeModel", "Ernie4_5_MoePreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..7c949bd455fb82343f6c20c6c90cd5aa3218c5c3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -0,0 +1,346 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
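+
+# Illustrative sketch (not part of the library code above): which decoder layers of the
+# Ernie 4.5 MoE stack become sparse MoE blocks. A layer uses `Ernie4_5_MoeSparseMoeBlock`
+# when `(layer_idx + 1)` is a multiple of `moe_layer_interval` and `layer_idx` lies within
+# `[moe_layer_start_index, moe_layer_end_index]`; every other layer falls back to the dense
+# `Ernie4_5_MoeMLP`. The config values below are hypothetical, and the block only runs when
+# this file is executed directly.
+if __name__ == "__main__":
+    num_hidden_layers, moe_layer_interval = 12, 2
+    moe_layer_start_index, moe_layer_end_index = 1, 11
+
+    def is_moe_layer(layer_idx: int) -> bool:
+        return (
+            (layer_idx + 1) % moe_layer_interval == 0
+            and moe_layer_start_index <= layer_idx <= moe_layer_end_index
+        )
+
+    print([layer_idx for layer_idx in range(num_hidden_layers) if is_moe_layer(layer_idx)])
+    # -> [1, 3, 5, 7, 9, 11]; the remaining layers keep a dense MLP
+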
+"""PyTorch Ernie 4.5 MoE model.""" + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import OutputRecorder, check_model_inputs +from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401 +from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm +from ..mixtral.modeling_mixtral import ( + MixtralForCausalLM, + MixtralPreTrainedModel, +) +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP +from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig + + +logger = logging.get_logger(__name__) + + +class Ernie4_5_MoeRMSNorm(LlamaRMSNorm): + pass + + +class Ernie4_5_MoeMLP(Qwen3MoeMLP): + def __init__(self, config, intermediate_size=None): + super().__init__(config, intermediate_size) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + +class Ernie4_5_MoeRotaryEmbedding(Ernie4_5RotaryEmbedding): + def __init__(self, config: Ernie4_5_MoeConfig, device=None): + super().__init__(config, device) + + +class Ernie4_5_MoeAttention(LlamaAttention): + def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.attention_dropout = 0.0 + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + +class Ernie4_5_MoeStatics(nn.Module): + """ + Stores MoE (Mixture of Experts) statistics + - Bias for the gating + - Additionally, usage per expert in the original codebase + """ + + def __init__(self, config): + super().__init__() + + num_experts_groups = 1 + num_experts = config.moe_num_experts + + self.e_score_correction_bias = nn.Parameter( + torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), + requires_grad=False, + ) + + def forward(self, hidden_states): + # NOTE: This is a workaround to enable TP with a module that only has parameters + # + # Otherwise, it stays as `DTensor` when called in the "super" forward + # 1. All other tensors are local (`torch.Tensor`) + # 2. Isolate does not work on `nn.Module` which only has parameters + return hidden_states + self.e_score_correction_bias.squeeze() + + +class Ernie4_5_MoeSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). 
It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + + Ernie 4.5 MoE's original formula is based on case (2) with + (optional) shared experts and a corrections bias during gating. + """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + + # correction bias (yes it seems to be a typo with statics <> statistics) + self.moe_statics = Ernie4_5_MoeStatics(config) + + # gating + self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) + self.experts = nn.ModuleList( + [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] + ) + self.norm_min = config.moe_norm_min + + # (optional) shared experts for all forwards + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # (Optional) shared experts + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states.float()) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) + routing_weights = routing_weights / torch.clamp( + routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min + ) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + # Add (optional) shared experts to the result + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer): + def __init__(self, config, layer_idx): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5_MoeAttention(config, layer_idx) + + if ( + ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= config.moe_layer_start_index + and layer_idx <= config.moe_layer_end_index + ): + self.mlp = Ernie4_5_MoeSparseMoeBlock(config) + else: + self.mlp = Ernie4_5_MoeMLP(config) + + self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) + + +@auto_docstring +class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel): + config: Ernie4_5_MoeConfig + _no_split_modules = ["Ernie4_5_MoeDecoderLayer"] + _keep_in_fp32_modules_strict = ["gate", "moe_statics"] + # Not supporting multi-token prediction (MTP) atm + _keys_to_ignore_on_load_unexpected = ["mtp"] + _can_record_outputs = { + "router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1), + "hidden_states": Ernie4_5_MoeDecoderLayer, + "attentions": Ernie4_5_MoeAttention, + } + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, Ernie4_5_MoeStatics): + module.e_score_correction_bias.data.zero_() + + +@auto_docstring +class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel): + def __init__(self, config: Ernie4_5_MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds 
is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Ernie4_5_MoeForCausalLM(MixtralForCausalLM): + def __init__(self, config): + PreTrainedModel.__init__(self, config) + self.model = Ernie4_5_MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = config.moe_k + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Ernie4_5_MoeForCausalLM", + "Ernie4_5_MoeModel", + "Ernie4_5_MoePreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/esm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8eac54d6ddcbdae2b8ca3771ae5540522f6f29da --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/esm/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
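+
+# Illustrative sketch (not part of this package): the expert-selection math used by the
+# `Ernie4_5_MoeSparseMoeBlock` defined in the Ernie 4.5 MoE modules above. Routing
+# probabilities are computed in float32, the correction bias only influences *which* experts
+# are picked (not the mixing weights), and the gathered top-k weights are renormalized with a
+# clamped denominator. The shapes and the zero bias are hypothetical, and the block only runs
+# when this file is executed directly.
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F
+
+    num_tokens, num_experts, top_k, norm_min = 6, 64, 8, 1e-12
+    router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32)
+    correction_bias = torch.zeros(num_experts)  # stands in for `moe_statics.e_score_correction_bias`
+
+    routing_weights = F.softmax(router_logits, dim=1)
+    # Select experts on the biased scores, but gather the unbiased probabilities as weights.
+    _, selected_experts = torch.topk(routing_weights + correction_bias, top_k, dim=-1)
+    top_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+    top_weights = top_weights / torch.clamp(top_weights.sum(dim=-1, keepdim=True), min=norm_min)
+
+    print(selected_experts.shape, top_weights.sum(dim=-1))  # (num_tokens, top_k); rows sum to 1.0
+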
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_esm import * + from .modeling_esm import * + from .modeling_esmfold import * + from .modeling_tf_esm import * + from .tokenization_esm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/configuration_esm.py b/venv/lib/python3.13/site-packages/transformers/models/esm/configuration_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..fabfb4ebd6d34a7f212af5e74a90c18d4a038156 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/esm/configuration_esm.py @@ -0,0 +1,365 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ESM model configuration""" + +from dataclasses import asdict, dataclass +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +# TODO Update this + + +class EsmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ESM + [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*): + Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ESMModel`]. + mask_token_id (`int`, *optional*): + The index of the mask token in the vocabulary. This must be included in the config because of the + "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + pad_token_id (`int`, *optional*): + The index of the padding token in the vocabulary. This must be included in the config because certain parts + of the ESM code use this instead of the attention mask. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`. + For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + emb_layer_norm_before (`bool`, *optional*): + Whether to apply layer normalization after embeddings but before the main stem of the network. + token_dropout (`bool`, defaults to `False`): + When this is enabled, masked tokens are treated as if they had been dropped out by input dropout. 
+ + Examples: + + ```python + >>> from transformers import EsmModel, EsmConfig + + >>> # Initializing a ESM facebook/esm-1b style configuration + >>> configuration = EsmConfig(vocab_size=33) + + >>> # Initializing a model from the configuration + >>> model = EsmModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "esm" + + def __init__( + self, + vocab_size=None, + mask_token_id=None, + pad_token_id=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1026, + initializer_range=0.02, + layer_norm_eps=1e-12, + position_embedding_type="absolute", + use_cache=True, + emb_layer_norm_before=None, + token_dropout=False, + is_folding_model=False, + esmfold_config=None, + vocab_list=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.emb_layer_norm_before = emb_layer_norm_before + self.token_dropout = token_dropout + self.is_folding_model = is_folding_model + if is_folding_model: + if esmfold_config is None: + logger.info("No esmfold_config supplied for folding model, using default values.") + esmfold_config = EsmFoldConfig() + elif isinstance(esmfold_config, dict): + esmfold_config = EsmFoldConfig(**esmfold_config) + self.esmfold_config = esmfold_config + if vocab_list is None: + logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!") + self.vocab_list = get_default_vocab_list() + else: + self.vocab_list = vocab_list + else: + self.esmfold_config = None + self.vocab_list = None + if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False): + raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = super().to_dict() + if isinstance(self.esmfold_config, EsmFoldConfig): + output["esmfold_config"] = self.esmfold_config.to_dict() + return output + + +@dataclass +class EsmFoldConfig: + esm_type: Optional[str] = None + fp16_esm: bool = True + use_esm_attn_map: bool = False + esm_ablate_pairwise: bool = False + esm_ablate_sequence: bool = False + esm_input_dropout: float = 0 + + embed_aa: bool = True + bypass_lm: bool = False + + lddt_head_hid_dim: int = 128 + trunk: "TrunkConfig" = None + + def __post_init__(self): + if self.trunk is None: + self.trunk = TrunkConfig() + elif isinstance(self.trunk, dict): + self.trunk = TrunkConfig(**self.trunk) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = asdict(self)
+        output["trunk"] = self.trunk.to_dict()
+        return output
+
+
+@dataclass
+class TrunkConfig:
+    num_blocks: int = 48
+    sequence_state_dim: int = 1024
+    pairwise_state_dim: int = 128
+    sequence_head_width: int = 32
+    pairwise_head_width: int = 32
+    position_bins: int = 32
+    dropout: float = 0
+    layer_drop: float = 0
+    cpu_grad_checkpoint: bool = False
+    max_recycles: int = 4
+    chunk_size: Optional[int] = 128
+    structure_module: "StructureModuleConfig" = None
+
+    def __post_init__(self):
+        if self.structure_module is None:
+            self.structure_module = StructureModuleConfig()
+        elif isinstance(self.structure_module, dict):
+            self.structure_module = StructureModuleConfig(**self.structure_module)
+
+        if self.max_recycles <= 0:
+            raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
+        if self.sequence_state_dim % self.sequence_head_width != 0:
+            raise ValueError(
+                "`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
+                f" {self.sequence_state_dim} and {self.sequence_head_width}."
+            )
+        if self.pairwise_state_dim % self.pairwise_head_width != 0:
+            raise ValueError(
+                "`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
+                f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
+            )
+
+        sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
+        pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
+
+        if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
+            raise ValueError(
+                "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width`, got"
+                f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
+            )
+        if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
+            raise ValueError(
+                "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width`, got"
+                f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
+            )
+        if self.pairwise_state_dim % 2 != 0:
+            raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
+
+        if self.dropout >= 0.4:
+            raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = asdict(self)
+        output["structure_module"] = self.structure_module.to_dict()
+        return output
+
+
+@dataclass
+class StructureModuleConfig:
+    """
+    Args:
+        sequence_dim:
+            Single representation channel dimension
+        pairwise_dim:
+            Pair representation channel dimension
+        ipa_dim:
+            IPA hidden channel dimension
+        resnet_dim:
+            Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
+        num_heads_ipa:
+            Number of IPA heads
+        num_qk_points:
+            Number of query/key points to generate during IPA
+        num_v_points:
+            Number of value points to generate during IPA
+        dropout_rate:
+            Dropout rate used throughout the layer
+        num_blocks:
+            Number of structure module blocks
+        num_transition_layers:
+            Number of layers in the single representation transition (Alg. 23 lines 8-9)
+        num_resnet_blocks:
+            Number of blocks in the angle resnet
+        num_angles:
+            Number of angles to generate in the angle resnet
+        trans_scale_factor:
+            Scale of single representation transition hidden dimension
+        epsilon:
+            Small number used in angle resnet normalization
+        inf:
+            Large number used for attention masking
+    """
+
+    sequence_dim: int = 384
+    pairwise_dim: int = 128
+    ipa_dim: int = 16
+    resnet_dim: int = 128
+    num_heads_ipa: int = 12
+    num_qk_points: int = 4
+    num_v_points: int = 8
+    dropout_rate: float = 0.1
+    num_blocks: int = 8
+    num_transition_layers: int = 1
+    num_resnet_blocks: int = 2
+    num_angles: int = 7
+    trans_scale_factor: int = 10
+    epsilon: float = 1e-8
+    inf: float = 1e5
+
+    def to_dict(self):
+        return asdict(self)
+
+
+def get_default_vocab_list():
+    return (
+        "<cls>",
+        "<pad>",
+        "<eos>",
+        "<unk>",
+        "L",
+        "A",
+        "G",
+        "V",
+        "S",
+        "E",
+        "R",
+        "T",
+        "I",
+        "D",
+        "P",
+        "K",
+        "Q",
+        "N",
+        "F",
+        "Y",
+        "M",
+        "H",
+        "W",
+        "C",
+        "X",
+        "B",
+        "U",
+        "Z",
+        "O",
+        ".",
+        "-",
+        "<null_1>",
+        "<mask>",
+    )
+
+
+__all__ = ["EsmConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esm.py b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c8f549d6943b07d06add5bb846ecfb07960859c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esm.py
@@ -0,0 +1,1058 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
+
+
+def gelu(x):
+    """
+    This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
+ """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose(-1, -2) + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized + + +class RotaryEmbedding(torch.nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype), + ) + + +class EsmContactPredictionHead(nn.Module): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + bias=True, + eos_idx: int = 2, + ): + super().__init__() + self.in_features = in_features + self.eos_idx = eos_idx + self.regression = nn.Linear(in_features, 1, bias) + self.activation = nn.Sigmoid() + + def forward(self, tokens, attentions): + # remove eos token attentions + eos_mask = tokens.ne(self.eos_idx).to(attentions) + eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) + attentions = attentions * eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + # remove cls token attentions + attentions = attentions[..., 1:, 1:] + batch_size, layers, heads, seqlen, _ = attentions.size() + attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) + + # features: batch x channels x tokens x tokens (symmetric) + attentions = attentions.to( + self.regression.weight.device + ) # attentions always float32, may need to convert to float16 + attentions = average_product_correct(symmetrize(attentions)) + attentions = attentions.permute(0, 2, 3, 1) + return self.activation(self.regression(attentions).squeeze(3)) + + +class EsmEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional 
embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + if config.emb_layer_norm_before: + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.layer_norm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + if self.position_embedding_type == "absolute": + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths + embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( + embeddings.dtype + ) + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. 
We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], +): + # ESM applies relative position embeddings and we don't copy from Llama + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if hasattr(module, "position_embedding_type") and module.position_embedding_type in [ + "relative_key", + "relative_key_query", + ]: + seq_length = query.shape[2] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + + if module.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + elif module.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) + relative_position_scores = relative_position_scores_query + relative_position_scores_key + + attn_weights = attn_weights + relative_position_scores + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EsmSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False): + super().__init__() + self.config = config + + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = config.attention_probs_dropout_prob + self.position_embedding_type = position_embedding_type or getattr( + config, 
"position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) + + self.scaling = 1.0 # For BC we apply scaling before RoPE + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + self.is_causal = self.is_decoder and not is_cross_attention + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + batch_size, seq_length = hidden_states.shape[:-1] + hidden_shape = (batch_size, seq_length, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.position_embedding_type in ["relative_key", "relative_key_query"]: + raise ValueError( + f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. 
" + "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`" + ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + return attn_output, attn_weights + + +class EsmSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EsmAttention(nn.Module): + def __init__(self, config, layer_idx=None, is_cross_attention=False): + super().__init__() + self.self = EsmSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention) + self.output = EsmSelfOutput(config) + self.pruned_heads = set() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + hidden_states_ln = self.LayerNorm(hidden_states) + attn_output, _ = self.self( + hidden_states_ln, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + attn_output = self.output(attn_output, hidden_states) + return attn_output + + +class EsmIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = gelu(hidden_states) + return hidden_states + + +class EsmOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EsmLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + 
self.seq_len_dim = 1 + self.attention = EsmAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = EsmAttention(config, is_cross_attention=True) + self.intermediate = EsmIntermediate(config) + self.output = EsmOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + **kwargs, + ) + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + attention_output = self.crossattention( + attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + layer_output = self.feed_forward_chunk(attention_output) + return layer_output + + def feed_forward_chunk(self, attention_output): + attention_output_ln = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(attention_output_ln) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class EsmEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) + self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class EsmPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class EsmPreTrainedModel(PreTrainedModel): + config: EsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = True + accepts_loss_kwargs = False + _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": EsmLayer, + "attentions": [OutputRecorder(EsmSelfAttention, index=1, layer_name="attention")], + "cross_attentions": [ + OutputRecorder(EsmSelfAttention, index=1, layer_name="crossattention"), + ], + } + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, EsmLMHead): + module.bias.data.zero_() + + def get_output_embeddings(self): + # NOTE: get_output_embeddings() must return None to prevent accidental weight tying. + # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400 + return None + + +@auto_docstring +class EsmModel(EsmPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 
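+
+    Example (a minimal usage sketch of the encoder-only case; the checkpoint name is simply one of the public
+    ESM-2 checkpoints and is used here for illustration):
+
+    ```python
+    >>> from transformers import AutoTokenizer, EsmModel
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    >>> model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
+
+    >>> inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
+    >>> outputs = model(**inputs)
+    >>> list(outputs.last_hidden_state.shape)  # [batch_size, sequence_length, hidden_size]
+    ```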
+ """ + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = EsmEmbeddings(config) + self.encoder = EsmEncoder(config) + + self.pooler = EsmPooler(config) if add_pooling_layer else None + + self.contact_head = EsmContactPredictionHead( + in_features=config.num_hidden_layers * config.num_attention_heads, bias=True + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length), hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ) + + if self.config._attn_implementation != "flash_attention_2": + batch_size, seq_length = inputs_embeds.shape[:-1] + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device) + + attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape=(batch_size, seq_length) + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + inputs_embeds, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + def predict_contacts(self, tokens, attention_mask): + attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions + attns = torch.stack(attns, dim=1) # Matches the original model layout + # In the original model, attentions for padding tokens are completely zeroed out. + # This makes no difference most of the time because the other tokens won't attend to them, + # but it does for the contact prediction task, which takes attentions as input, + # so we have to mimic that here. + attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) + attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) + return self.contact_head(tokens, attns) + + +@auto_docstring +class EsmForMaskedLM(EsmPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.esm = EsmModel(config, add_pooling_layer=False) + self.lm_head = EsmLMHead(config) + + self.init_weights() + + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(prediction_scores.device) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask=attention_mask) + + +class EsmLMHead(nn.Module): + """ESM Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + self.bias + return x + + +@auto_docstring( + custom_intro=""" + ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """ +) +class EsmForSequenceClassification(EsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.esm = EsmModel(config, add_pooling_layer=False) + self.classifier = EsmClassificationHead(config) + + self.init_weights() + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class EsmForTokenClassification(EsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = EsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class EsmClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "EsmForMaskedLM", + "EsmForSequenceClassification", + "EsmForTokenClassification", + "EsmModel", + "EsmPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esmfold.py b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esmfold.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc1f0dbdc701991a4109ddbc617eb1b3769c6a1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_esmfold.py @@ -0,0 +1,2309 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import sys +from collections.abc import Sequence +from dataclasses import dataclass +from functools import partial +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from ...integrations.deepspeed import is_deepspeed_available +from ...modeling_outputs import ModelOutput +from ...utils import ( + ContextManagers, + auto_docstring, + is_scipy_available, + logging, +) +from .modeling_esm import EsmModel, EsmPreTrainedModel +from .openfold_utils import ( + OFProtein, + Rigid, + Rotation, + atom14_to_atom37, + chunk_layer, + compute_predicted_aligned_error, + compute_tm, + frames_and_literature_positions_to_atom14_pos, + make_atom14_masks, + residue_constants, + to_pdb, + torsion_angles_to_frames, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`EsmForProteinFoldingOutput`]. + """ +) +class EsmForProteinFoldingOutput(ModelOutput): + r""" + frames (`torch.FloatTensor`): + Output frames. + sidechain_frames (`torch.FloatTensor`): + Output sidechain frames. + unnormalized_angles (`torch.FloatTensor`): + Predicted unnormalized backbone and side chain torsion angles. + angles (`torch.FloatTensor`): + Predicted backbone and side chain torsion angles. + positions (`torch.FloatTensor`): + Predicted positions of the backbone and side chain atoms. + states (`torch.FloatTensor`): + Hidden states from the protein folding trunk. + s_s (`torch.FloatTensor`): + Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem. + s_z (`torch.FloatTensor`): + Pairwise residue embeddings. + distogram_logits (`torch.FloatTensor`): + Input logits to the distogram used to compute residue distances. + lm_logits (`torch.FloatTensor`): + Logits output by the ESM-2 protein language model stem. + aatype (`torch.FloatTensor`): + Input amino acids (AlphaFold2 indices). + atom14_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom14 representation. + residx_atom14_to_atom37 (`torch.FloatTensor`): + Mapping between atoms in the atom14 and atom37 representations. + residx_atom37_to_atom14 (`torch.FloatTensor`): + Mapping between atoms in the atom37 and atom14 representations. + atom37_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom37 representation. + residue_index (`torch.FloatTensor`): + The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be + a sequence of integers from 0 to `sequence_length`. + lddt_head (`torch.FloatTensor`): + Raw outputs from the lddt head used to compute plddt. + plddt (`torch.FloatTensor`): + Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is + uncertain, or where the protein structure is disordered. + ptm_logits (`torch.FloatTensor`): + Raw logits used for computing ptm. + ptm (`torch.FloatTensor`): + TM-score output representing the model's high-level confidence in the overall structure. + aligned_confidence_probs (`torch.FloatTensor`): + Per-residue confidence scores for the aligned structure. + predicted_aligned_error (`torch.FloatTensor`): + Predicted error between the model's prediction and the ground truth. + max_predicted_aligned_error (`torch.FloatTensor`): + Per-sample maximum predicted error. 
+ """ + + frames: Optional[torch.FloatTensor] = None + sidechain_frames: Optional[torch.FloatTensor] = None + unnormalized_angles: Optional[torch.FloatTensor] = None + angles: Optional[torch.FloatTensor] = None + positions: Optional[torch.FloatTensor] = None + states: Optional[torch.FloatTensor] = None + s_s: Optional[torch.FloatTensor] = None + s_z: Optional[torch.FloatTensor] = None + distogram_logits: Optional[torch.FloatTensor] = None + lm_logits: Optional[torch.FloatTensor] = None + aatype: Optional[torch.FloatTensor] = None + atom14_atom_exists: Optional[torch.FloatTensor] = None + residx_atom14_to_atom37: Optional[torch.FloatTensor] = None + residx_atom37_to_atom14: Optional[torch.FloatTensor] = None + atom37_atom_exists: Optional[torch.FloatTensor] = None + residue_index: Optional[torch.FloatTensor] = None + lddt_head: Optional[torch.FloatTensor] = None + plddt: Optional[torch.FloatTensor] = None + ptm_logits: Optional[torch.FloatTensor] = None + ptm: Optional[torch.FloatTensor] = None + aligned_confidence_probs: Optional[torch.FloatTensor] = None + predicted_aligned_error: Optional[torch.FloatTensor] = None + max_predicted_aligned_error: Optional[torch.FloatTensor] = None + + +def is_fp16_enabled(device_type): + # Autocast world + autocast_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + fp16_enabled = autocast_dtype == torch.float16 + fp16_enabled = fp16_enabled and torch.is_autocast_enabled() + + return fp16_enabled + + +def is_deepspeed_initialized(): + if is_deepspeed_available(): + return False + else: + try: + import deepspeed + + # This is not available in all DeepSpeed versions. + return deepspeed.utils.is_initialized() + except Exception: + return False + + +def collate_dense_tensors(samples: list[torch.Tensor], pad_v: float = 0) -> torch.Tensor: + """ + Takes a list of tensors with the following dimensions: + [(d_11, ..., d_1K), + (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)] + and stack + pads them into a single tensor of: + (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) + """ + if len(samples) == 0: + return torch.Tensor() + if len({x.dim() for x in samples}) != 1: + raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}") + (device,) = tuple({x.device for x in samples}) # assumes all on same device + max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] + result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device) + result.fill_(pad_v) + for i in range(len(samples)): + result_i = result[i] + t = samples[i] + result_i[tuple(slice(0, k) for k in t.shape)] = t + return result + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def permute_final_dims(tensor: torch.Tensor, inds: list[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if isinstance(v, dict): + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + scale = scale / max(1, shape[1]) + + if not is_scipy_available(): + logger.warning( + "This init requires scipy, but scipy was not found, default to an approximation that might not be" + " 
equivalent." + ) + std = math.sqrt(scale) + torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) + + else: + from scipy.stats import truncnorm + + std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) + samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) + samples = np.reshape(samples, shape) + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class EsmFoldLinear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal + distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": + Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. Overrides init if not None. + """ + super().__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + self.init = init + self.init_fn = init_fn + + if init not in ["default", "relu", "glorot", "gating", "normal", "final"]: + raise ValueError("Invalid init string.") + + +class EsmFoldLayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super().__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.autocast(device_type="cuda", enabled=False): + out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) + else: + out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) + + return out + + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of type bfloat16 + """ + d = t.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.autocast(device_type="cuda", enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +class EsmFoldAttention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. 
+ """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super().__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if self.linear_g is not None: + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[list[torch.Tensor]] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + lma_q_chunk_size: int = 1024, + lma_kv_chunk_size: int = 4096, + use_flash: bool = False, + flash_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_memory_efficient_kernel: + Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. + If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead + use_lma: + Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a + stock PyTorch implementation is used instead + lma_q_chunk_size: + Query chunk size (for LMA) + lma_kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): + raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided") + + if use_flash and biases is not None: + raise ValueError("use_flash is incompatible with the bias option. 
For masking, use flash_mask instead") + + attn_options = [use_memory_efficient_kernel, use_lma, use_flash] + if sum(attn_options) > 1: + raise ValueError("Choose at most one alternative attention algorithm") + + if biases is None: + biases = [] + + # [*, H, Q/K, C_hidden] + query, key, value = self._prep_qkv(q_x, kv_x) + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + output = torch.matmul(query, key) + for b in biases: + output += b + output = softmax_no_cast(output, -1) + + # [*, H, Q, C_hidden] + output = torch.matmul(output, value) + output = output.transpose(-2, -3) + output = self._wrap_up(output, q_x) + + return output + + +class EsmFoldTriangleAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + biases: list[torch.Tensor], + chunk_size: int, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + "triangle! triangle!" + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + + return chunk_layer( + partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + _out=x if inplace_safe else None, + ) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk( + x, + biases, + chunk_size, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma, + inplace_safe=inplace_safe, + ) + else: + x = self.mha( + q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma + ) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class EsmFoldTriangleMultiplicativeUpdate(nn.Module): + """ + Implements Algorithms 11 and 12. 
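+
+    In the notation of `_combine_projections` below: the "outgoing" update (Algorithm 11) computes, for every pair
+    (i, j) and channel c, `p[i, j, c] = sum_k a[i, k, c] * b[j, k, c]`, while the "incoming" update (Algorithm 12)
+    sums over the first index instead, `p[i, j, c] = sum_k a[k, i, c] * b[k, j, c]`; here `a` and `b` are the gated,
+    masked projections of the pair representation `z` (summary added for orientation).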
+ """ + + def __init__(self, config, _outgoing=True): + super().__init__() + c_hidden = config.pairwise_state_dim + self._outgoing = _outgoing + + self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final") + + self.layer_norm_in = LayerNorm(c_hidden) + self.layer_norm_out = LayerNorm(c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections( + self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None + ) -> torch.Tensor: + if self._outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = permute_final_dims(b, (2, 1, 0)) + else: + a = permute_final_dims(a, (2, 1, 0)) + b = permute_final_dims(b, (2, 0, 1)) + + if _inplace_chunk_size is not None: + # To be replaced by torch vmap + for i in range(0, a.shape[-3], _inplace_chunk_size): + a_chunk = a[..., i : i + _inplace_chunk_size, :, :] + b_chunk = b[..., i : i + _inplace_chunk_size, :, :] + a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul( + a_chunk, + b_chunk, + ) + + p = a + else: + p = torch.matmul(a, b) + + return permute_final_dims(p, (1, 2, 0)) + + def _inference_forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_chunk_size: Optional[int] = None, + with_add: bool = True, + ): + """ + Args: + z: + A [*, N, N, C_z] pair representation + mask: + A [*, N, N] pair mask + inplace_chunk_size: + Size of chunks used in the main computation. Increase to trade memory for speed. + with_add: + If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). + Returns: + A reference to the overwritten z + + More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the + addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten + values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size. + Useful for inference on extremely long sequences. + + It works as follows. We will make reference to variables used in the default forward implementation below. + Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the + "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask, + and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for + N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate + tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the + tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over + pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains + inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring + total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks + directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at + the end of the ith iteration of the main loop. 
Despite this overwriting, the ith column is always one column + ahead of previously overwritten columns and can be recovered directly from z. After the first iteration, + however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache, + a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For + 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith + iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead. + Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the + z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache. + After the final iteration, z has been completely overwritten and contains the triangular multiplicative update. + If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case, + peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small + variables. + """ + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + def compute_projection_helper(pair, mask, a=True): + if a: + linear_g = self.linear_a_g + linear_p = self.linear_a_p + else: + linear_g = self.linear_b_g + linear_p = self.linear_b_p + + pair = self.layer_norm_in(pair) + p = linear_g(pair) + p.sigmoid_() + p *= linear_p(pair) + p *= mask + p = permute_final_dims(p, (2, 0, 1)) + return p + + def compute_projection(pair, mask, a=True, chunked=True): + need_transpose = self._outgoing ^ a + if not chunked: + p = compute_projection_helper(pair, mask, a) + if need_transpose: + p = p.transpose(-1, -2) + else: + # This computation is chunked so as not to exceed our 2.5x + # budget with a large intermediate tensor + linear_g = self.linear_a_g if a else self.linear_b_g + c = linear_g.bias.shape[-1] + out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1] + p = pair.new_zeros(out_shape) + for i in range(0, pair.shape[-3], inplace_chunk_size): + pair_chunk = pair[..., i : i + inplace_chunk_size, :, :] + pair_chunk = compute_projection_helper( + pair[..., i : i + inplace_chunk_size, :, :], + mask[..., i : i + inplace_chunk_size, :, :], + a, + ) + if need_transpose: + pair_chunk = pair_chunk.transpose(-1, -2) + p[..., i : i + inplace_chunk_size] = pair_chunk + else: + p[..., i : i + inplace_chunk_size, :] = pair_chunk + + del pair_chunk + + return p + + # We start by fully manifesting a. In addition to the input, this + # brings total memory consumption to 2x z (disregarding size of chunks) + # [*, N, N, c] + a = compute_projection(z, mask, True, chunked=True) + + if inplace_chunk_size is not None: + n = a.shape[-1] + half_n = n // 2 + n % 2 + row_dim = -3 + col_dim = -2 + b_chunk_dim = row_dim if self._outgoing else col_dim + + def empty_slicer(t): + return [slice(None) for _ in t.shape] + + def slice_tensor(t, start, end, dim): + # Slices start:end from the dim dimension of t + s = empty_slicer(t) + s[dim] = slice(start, end) + return t[s] + + def flip_z_cache_(z_cache, z): + # "Reorient" the z_cache (see below), filling it with quadrants + # 3---recovered from the z_cache---and 4---recovered from z--- + # of the input tensor z. 
+ quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim) + z_cache = z_cache.transpose(row_dim, col_dim) + + # If n is odd, we need to shrink the z_cache by one row + z_cache = z_cache[..., : (n // 2), :, :] + + # Move the 3rd quadrant of z into the + first_half_slicer = empty_slicer(z_cache) + first_half_slicer[col_dim] = slice(0, half_n) + z_cache[first_half_slicer] = quadrant_3 + + # Get the fourth quadrant of z + quadrant_4 = slice_tensor(z, half_n, None, row_dim) + quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim) + + # Insert said quadrant into the rotated z-cache + quadrant_3_slicer = empty_slicer(z_cache) + quadrant_3_slicer[col_dim] = slice(half_n, None) + + z_cache[quadrant_3_slicer] = quadrant_4 + + return z_cache + + # Initialize the z cache to the left half of z. + z_cache_shape = list(z.shape) + z_cache_shape[col_dim] = half_n + z_cache = z.new_zeros(z_cache_shape) + z_cache_slicer = empty_slicer(z_cache) + z_cache_slicer[col_dim] = slice(0, half_n) + z_cache.copy_(z[z_cache_slicer]) + z_cache_rotated = False + + # We need to reorient the z-cache at the halfway point, and we + # don't want a single chunk to straddle that point. We contract one + # of the chunks in the middle to address that problem. + i_range = list(range(0, half_n, inplace_chunk_size)) + initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])] + after_half = list(range(half_n, n, inplace_chunk_size)) + after_half_offsets = [inplace_chunk_size for _ in after_half] + combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets) + for i, offset in combined_range_with_offsets: + if not z_cache_rotated and i >= half_n: + z_cache = flip_z_cache_(z_cache, z) + z_cache_rotated = True + + z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim) + mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim) + + z_chunk_b = z_chunk_b.clone() + if b_chunk_dim == col_dim: + z_chunk_b = slice_tensor(z, i, i + offset, col_dim) + else: # b_chunk_dim == row_dim + # In this case, the b-dimension (b_chunk_dim) is partially + # overwritten at the end of each iteration. We need to + # restore the missing component from the z-cache. + if not z_cache_rotated: + z_chunk_slicer = empty_slicer(z_chunk_b) + z_chunk_slicer[col_dim] = slice(0, half_n) + z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim) + else: + z_cache_offset = i - half_n + z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim) + + b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False) + del z_chunk_b + + x_chunk = torch.matmul(a, b_chunk) + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + + # The g dimension (col_dim) is parallel to and ahead of the + # overwrites in z. We can extract the g chunk normally. 
+ z_chunk_g = slice_tensor(z, i, i + offset, col_dim) + g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g)) + g_chunk.sigmoid_() + del z_chunk_g + + x_chunk *= g_chunk + + # Write the columns into z in-place + z_slicer = empty_slicer(z) + z_slicer[col_dim] = slice(i, i + offset) + if with_add: + z[z_slicer] += x_chunk + else: + z[z_slicer] = x_chunk + else: + b = compute_projection(z, mask, False, False) + x = torch.matmul(a, b) + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.linear_g(z) + g.sigmoid_() + x *= g + if with_add: + z += x + else: + z = x + + return z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False, + _inplace_chunk_size: Optional[int] = 256, + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + if inplace_safe: + x = self._inference_forward( + z, + mask, + inplace_chunk_size=_inplace_chunk_size, + with_add=_add_with_inplace, + ) + return x + + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = mask + a = a * self.sigmoid(self.linear_a_g(z)) + a = a * self.linear_a_p(z) + b = mask + b = b * self.sigmoid(self.linear_b_g(z)) + b = b * self.linear_b_p(z) + + device_type = a.device.type if a.device.type != "mps" else "cpu" + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): + x = self._combine_projections(a.float(), b.float()) + else: + x = self._combine_projections(a, b) + + del a, b + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + x = x * g + + return x + + +class EsmFoldPreTrainedModel(EsmPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + # Subclass `EsMPreTrainedModel` to deal with special init + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, EsmFoldLinear): + with torch.no_grad(): + if module.init_fn is not None: + module.init_fn(module.weight, module.bias) + elif module.init == "default": + trunc_normal_init_(module.weight, scale=1.0) + elif module.init == "relu": + trunc_normal_init_(module.weight, scale=2.0) + elif module.init == "glorot": + nn.init.xavier_uniform_(module.weight, gain=1) + elif module.init == "gating": + module.weight.fill_(0.0) + if module.bias: + module.bias.fill_(1.0) + elif module.init == "normal": + torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") + elif module.init == "final": + module.weight.fill_(0.0) + elif isinstance(module, EsmFoldInvariantPointAttention): + ipa_point_weights_init_(module.head_weights) + elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): + torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) + + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) + torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.bias) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) + else: + super()._init_weights(module) + + +class EsmFoldSelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, head_width, gated=False): + super().__init__() + assert embed_dim == num_heads * head_width + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_width = head_width + + self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.gated = gated + if gated: + self.g_proj = nn.Linear(embed_dim, embed_dim) + torch.nn.init.zeros_(self.g_proj.weight) + torch.nn.init.ones_(self.g_proj.bias) + + self.rescale_factor = self.head_width**-0.5 + + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, x, mask=None, bias=None, indices=None): + """ + Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, + use mask. + + Inputs: + x: batch of input sequences (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (.. + x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads) + + Outputs: + sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) + """ + + t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) + t = t.permute(0, 2, 1, 3) + q, k, v = t.chunk(3, dim=-1) + + q = self.rescale_factor * q + a = torch.einsum("...qc,...kc->...qk", q, k) + + # Add external attention bias. + if bias is not None: + a = a + bias.permute(0, 3, 1, 2) + + # Do not attend to padding tokens. 
+ if mask is not None: + mask = mask[:, None, None] + a = a.masked_fill(mask == False, -np.inf) # noqa: E712 + + a = nn.functional.softmax(a, dim=-1) + + y = torch.einsum("...hqk,...hkc->...qhc", a, v) + y = y.reshape(*y.shape[:2], -1) + + if self.gated: + y = self.g_proj(x).sigmoid() * y + y = self.o_proj(y) + + return y, a.permute(0, 3, 1, 2) + + +class EsmFoldDropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask along a particular dimension. + """ + + def __init__(self, r: float, batch_dim: Union[int, list[int]]): + super().__init__() + + self.r = r + if isinstance(batch_dim, int): + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + return x * self.dropout(x.new_ones(shape)) + + +class EsmFoldSequenceToPair(nn.Module): + def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): + super().__init__() + + self.layernorm = nn.LayerNorm(sequence_state_dim) + self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) + self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) + + torch.nn.init.zeros_(self.proj.bias) + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, sequence_state): + """ + Inputs: + sequence_state: B x L x sequence_state_dim + + Output: + pairwise_state: B x L x L x pairwise_state_dim + + Intermediate state: + B x L x L x 2*inner_dim + """ + + assert len(sequence_state.shape) == 3 + + s = self.layernorm(sequence_state) + s = self.proj(s) + q, k = s.chunk(2, dim=-1) + + prod = q[:, None, :, :] * k[:, :, None, :] + diff = q[:, None, :, :] - k[:, :, None, :] + + x = torch.cat([prod, diff], dim=-1) + x = self.o_proj(x) + + return x + + +class EsmFoldPairToSequence(nn.Module): + def __init__(self, pairwise_state_dim, num_heads): + super().__init__() + + self.layernorm = nn.LayerNorm(pairwise_state_dim) + self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) + + def forward(self, pairwise_state): + """ + Inputs: + pairwise_state: B x L x L x pairwise_state_dim + + Output: + pairwise_bias: B x L x L x num_heads + """ + assert len(pairwise_state.shape) == 4 + z = self.layernorm(pairwise_state) + pairwise_bias = self.linear(z) + return pairwise_bias + + +class EsmFoldResidueMLP(nn.Module): + def __init__(self, embed_dim, inner_dim, dropout=0): + super().__init__() + + self.mlp = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return x + self.mlp(x) + + +class EsmFoldTriangularSelfAttentionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + sequence_state_dim = config.sequence_state_dim + pairwise_state_dim = config.pairwise_state_dim + sequence_num_heads = sequence_state_dim // config.sequence_head_width + pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width + + self.layernorm_1 = nn.LayerNorm(sequence_state_dim) + + self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim) + self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads) + + self.seq_attention = EsmFoldSelfAttention( + sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True + ) + self.tri_mul_out = 
EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True) + self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False) + + self.tri_att_start = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True + ) + self.tri_att_end = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False + ) + + self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout) + self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout) + + self.drop = nn.Dropout(config.dropout) + self.row_drop = EsmFoldDropout(config.dropout * 2, 2) + self.col_drop = EsmFoldDropout(config.dropout * 2, 1) + + def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): + """ + Inputs: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean + tensor of valid positions + + Output: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim + """ + if len(sequence_state.shape) != 3: + raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.") + if len(pairwise_state.shape) != 4: + raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.") + if mask is not None and len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + + batch_dim, seq_dim, sequence_state_dim = sequence_state.shape + pairwise_state_dim = pairwise_state.shape[3] + + if sequence_state_dim != self.config.sequence_state_dim: + raise ValueError( + "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got " + f"{sequence_state_dim} != {self.config.sequence_state_dim}." + ) + if pairwise_state_dim != self.config.pairwise_state_dim: + raise ValueError( + "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got " + f"{pairwise_state_dim} != {self.config.pairwise_state_dim}." + ) + if batch_dim != pairwise_state.shape[0]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != " + f"{pairwise_state.shape[0]}." + ) + if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != " + f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}." + ) + + # Update sequence state + bias = self.pair_to_sequence(pairwise_state) + + # Self attention with bias + mlp. + y = self.layernorm_1(sequence_state) + y, _ = self.seq_attention(y, mask=mask, bias=bias) + sequence_state = sequence_state + self.drop(y) + sequence_state = self.mlp_seq(sequence_state) + + # Update pairwise state + pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) + + # Axial attention with triangular bias. 
+ tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None + pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.row_drop( + self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + pairwise_state = pairwise_state + self.col_drop( + self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + + # MLP over pairs. + pairwise_state = self.mlp_pair(pairwise_state) + + return sequence_state, pairwise_state + + +class EsmCategoricalMixture: + def __init__(self, param, bins=50, start=0, end=1): + # All tensors are of shape ..., bins. + self.logits = param + bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype) + self.v_bins = (bins[:-1] + bins[1:]) / 2 + + def log_prob(self, true): + # Shapes are: + # self.probs: ... x bins + # true : ... + true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) + nll = self.logits.log_softmax(-1) + return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) + + def mean(self): + return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) + + +def categorical_lddt(logits, bins=50): + # Logits are ..., 37, bins. + return EsmCategoricalMixture(logits, bins=bins).mean() + + +def get_axial_mask(mask): + """ + Helper to convert B x L mask of valid positions to axial mask used in row column attentions. + + Input: + mask: B x L tensor of booleans + + Output: + mask: B x L x L tensor of booleans + """ + + if mask is None: + return None + + if len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + batch_dim, seq_dim = mask.shape + m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) + m = m.reshape(batch_dim * seq_dim, seq_dim) + return m + + +class EsmFoldRelativePosition(nn.Module): + def __init__(self, config): + super().__init__() + self.bins = config.position_bins + + # Note an additional offset is used so that the 0th position + # is reserved for masked pairs. + self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim) + + def forward(self, residue_index, mask=None): + """ + Input: + residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans + + Output: + pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings + """ + if residue_index.dtype != torch.long: + raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.") + if mask is not None and residue_index.shape != mask.shape: + raise ValueError( + f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}." + ) + + diff = residue_index[:, None, :] - residue_index[:, :, None] + diff = diff.clamp(-self.bins, self.bins) + diff = diff + self.bins + 1 # Add 1 to adjust for padding index. 
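+        # Worked example of the binning above (illustrative numbers): with `config.position_bins = 32`, residues
+        # 5 positions apart give diff = 5 -> clamped to 5 -> embedding index 5 + 32 + 1 = 38; separations beyond
+        # +/-32 saturate at the outermost bins, and index 0 stays reserved for masked pairs (handled just below).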
+ + if mask is not None: + mask = mask[:, None, :] * mask[:, :, None] + diff[mask == False] = 0 # noqa: E712 + + output = self.embedding(diff) + return output + + +class EsmFoldAngleResnetBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class EsmFoldAngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + + self.layers = nn.ModuleList() + for _ in range(config.num_resnet_blocks): + layer = EsmFoldAngleResnetBlock(config) + self.layers.append(layer) + + self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2) + + self.relu = nn.ReLU() + + def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.config.epsilon, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class EsmFoldInvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_dim + c_z = config.pairwise_dim + self.hidden_dim = config.ipa_dim + self.num_heads = config.num_heads_ipa + self.num_qk_points = config.num_qk_points + self.num_v_points = config.num_v_points + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. 
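+ # Per attention head, IPA builds three kinds of queries/keys/values: scalar
+ # channels of width ipa_dim, query/key points (num_qk_points 3-D coordinates),
+ # and value points (num_v_points coordinates). The point projections are
+ # produced in each residue's local frame and mapped to global coordinates with
+ # the rigid transforms `r` inside forward().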
+ hc = config.ipa_dim * config.num_heads_ipa + self.linear_q = EsmFoldLinear(c_s, hc) + self.linear_kv = EsmFoldLinear(c_s, 2 * hc) + + hpq = config.num_heads_ipa * config.num_qk_points * 3 + self.linear_q_points = EsmFoldLinear(c_s, hpq) + + hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3 + self.linear_kv_points = EsmFoldLinear(c_s, hpkv) + + self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa) + + self.head_weights = nn.Parameter(torch.zeros(config.num_heads_ipa)) + + concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4) + self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.hidden_dim, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if _offload_inference: + assert sys.getrefcount(z[0]) == 2 + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + device_type = q.device.type if q.device.type != "mps" else "cpu" + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): + a = torch.matmul( + permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.hidden_dim)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att**2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, 
dim=-1)) + head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.config.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if _offload_inference: + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype) + ) + + return s + + +class EsmFoldBackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, config): + super().__init__() + + self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final") + + def forward(self, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class EsmFoldStructureModuleTransitionLayer(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class EsmFoldStructureModuleTransition(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList() + for _ in range(config.num_transition_layers): + l = EsmFoldStructureModuleTransitionLayer(config) + self.layers.append(l) + + self.dropout = nn.Dropout(config.dropout_rate) + self.layer_norm = LayerNorm(config.sequence_dim) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class EsmFoldStructureModule(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # Buffers to be lazily initialized later + # self.default_frames + # self.group_idx + # self.atom_mask + # self.lit_positions + + self.layer_norm_s = LayerNorm(config.sequence_dim) + self.layer_norm_z = 
LayerNorm(config.pairwise_dim) + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim) + + self.ipa = EsmFoldInvariantPointAttention(config) + + self.ipa_dropout = nn.Dropout(config.dropout_rate) + self.layer_norm_ipa = LayerNorm(config.sequence_dim) + + self.transition = EsmFoldStructureModuleTransition(config) + self.bb_update = EsmFoldBackboneUpdate(config) + self.angle_resnet = EsmFoldAngleResnet(config) + + def forward( + self, + evoformer_output_dict, + aatype, + mask=None, + _offload_inference=False, + ): + """ + Args: + evoformer_output_dict: + Dictionary containing: + "single": + [*, N_res, C_s] single representation + "pair": + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + s = evoformer_output_dict["single"] + + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(evoformer_output_dict["pair"]) + + z_reference_list = None + if _offload_inference: + assert sys.getrefcount(evoformer_output_dict["pair"]) == 2 + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() + z_reference_list = [z] + z = None + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.config.num_blocks): + # [*, N, C_s] + s = s + self.ipa( + s, + z, + rigids, + mask, + _offload_inference=_offload_inference, + _z_reference_list=z_reference_list, + ) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype) + + scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z, z_reference_list + + if _offload_inference: + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + residue_constants.restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + residue_constants.restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not 
hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + residue_constants.restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + residue_constants.restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N] + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) + + +class EsmFoldingTrunk(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_state_dim + c_z = config.pairwise_state_dim + + self.pairwise_positional_embedding = EsmFoldRelativePosition(config) + + self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)]) + + self.recycle_bins = 15 + self.recycle_s_norm = nn.LayerNorm(c_s) + self.recycle_z_norm = nn.LayerNorm(c_z) + self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) + self.recycle_disto.weight[0].detach().zero_() + + self.structure_module = EsmFoldStructureModule(config.structure_module) + self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim) + self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim) + + self.chunk_size = config.chunk_size + + def set_chunk_size(self, chunk_size): + # This parameter means the axial attention will be computed + # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). + # It's equivalent to running a for loop over chunks of the dimension we're iterative over, + # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks. + self.chunk_size = chunk_size + + def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): + """ + Inputs: + seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B + x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues + + Output: + predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object + """ + + device = seq_feats.device + s_s_0 = seq_feats + s_z_0 = pair_feats + + if no_recycles is None: + no_recycles = self.config.max_recycles + else: + if no_recycles < 0: + raise ValueError("Number of recycles must not be negative.") + no_recycles += 1 # First 'recycle' is just the standard forward pass through the model. 
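+ # Recycling: every iteration feeds the layer-normed sequence/pair states and a
+ # distogram of the previously predicted CB positions back into the trunk as an
+ # additive correction to the initial features. Gradients only flow through the
+ # final iteration; earlier passes run under torch.no_grad() below.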
+ + def trunk_iter(s, z, residx, mask): + z = z + self.pairwise_positional_embedding(residx, mask=mask) + + for block in self.blocks: + s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size) + return s, z + + s_s = s_s_0 + s_z = s_z_0 + recycle_s = torch.zeros_like(s_s) + recycle_z = torch.zeros_like(s_z) + recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64) + + for recycle_idx in range(no_recycles): + with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): + # === Recycling === + recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device) + recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device) + recycle_z += self.recycle_disto(recycle_bins.detach()).to(device) + + s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) + + # === Structure module === + structure = self.structure_module( + {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, + true_aa, + mask.float(), + ) + + recycle_s = s_s + recycle_z = s_z + # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold. + recycle_bins = EsmFoldingTrunk.distogram( + structure["positions"][-1][:, :, :3], + 3.375, + 21.375, + self.recycle_bins, + ) + + structure["s_s"] = s_s + structure["s_z"] = s_z + + return structure + + @staticmethod + def distogram(coords, min_bin, max_bin, num_bins): + # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. + boundaries = torch.linspace( + min_bin, + max_bin, + num_bins - 1, + device=coords.device, + ) + boundaries = boundaries**2 + N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] + # Infer CB coordinates. + b = CA - N + c = C - CA + a = b.cross(c, dim=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) + bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] + return bins + + +# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare +# the outputs for downstream use. + + +@auto_docstring( + custom_intro=""" + ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed + by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to + the rest of the model combined! It outputs a dictionary containing predicted structural information about the input + protein(s). 
+ """ +) +class EsmForProteinFolding(EsmPreTrainedModel): + _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] + _supports_flash_attn = False + _supports_sdpa = False + _supports_attention_backend = False + + _can_record_outputs = None + + def __init__(self, config): + super().__init__(config) + + self.config = config + + self.distogram_bins = 64 + + self.esm = EsmModel(config, add_pooling_layer=False) + + self.esm.requires_grad_(False) + if self.config.esmfold_config.fp16_esm: + self.esm.half() + + self.esm_feats = self.config.hidden_size + self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads + self.esm_layers = self.config.num_hidden_layers + self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list)) + self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1)) + + trunk_config = self.config.esmfold_config.trunk + c_s = trunk_config.sequence_state_dim + c_z = trunk_config.pairwise_state_dim + self.esm_s_mlp = nn.Sequential( + LayerNorm(self.esm_feats), + nn.Linear(self.esm_feats, c_s), + nn.ReLU(), + nn.Linear(c_s, c_s), + ) + + # 0 is padding, N is unknown residues, N + 1 is mask. + self.n_tokens_embed = residue_constants.restype_num + 3 + self.pad_idx = 0 + self.unk_idx = self.n_tokens_embed - 2 + self.mask_idx = self.n_tokens_embed - 1 + self.esm_dict_cls_idx = self.config.vocab_list.index("") + self.esm_dict_mask_idx = self.config.vocab_list.index("") + self.esm_dict_eos_idx = self.config.vocab_list.index("") + self.esm_dict_padding_idx = self.config.vocab_list.index("") + if self.config.esmfold_config.embed_aa: + self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) + + self.trunk = EsmFoldingTrunk(trunk_config) + + self.distogram_head = nn.Linear(c_z, self.distogram_bins) + self.ptm_head = nn.Linear(c_z, self.distogram_bins) + self.lm_head = nn.Linear(c_s, self.n_tokens_embed) + self.lddt_bins = 50 + structure_module_config = trunk_config.structure_module + self.lddt_head = nn.Sequential( + nn.LayerNorm(structure_module_config.sequence_dim), + nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), + ) + + @staticmethod + def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor: + # Remember that t is shifted from residue_constants by 1 (0 is padding). + esm_reorder = [vocab_list.index("")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x] + return torch.tensor(esm_reorder) + + @auto_docstring + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + masking_pattern: Optional[torch.Tensor] = None, + num_recycles: Optional[int] = None, + output_hidden_states: Optional[bool] = False, + ) -> EsmForProteinFoldingOutput: + r""" + masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`. + num_recycles (`int`, *optional*, defaults to `None`): + Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling" + consists of passing the output of the folding trunk back in as input to the trunk. 
During training, the + number of recycles should vary with each batch, to ensure that the model learns to output valid predictions + after each recycle. During inference, num_recycles should be set to the highest value that the model was + trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is + used. + + Example: + + ```python + >>> from transformers import AutoTokenizer, EsmForProteinFolding + + >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") + >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide + >>> outputs = model(**inputs) + >>> folded_positions = outputs.positions + ``` + + """ + cfg = self.config.esmfold_config + + aa = input_ids # B x L + B = aa.shape[0] + L = aa.shape[1] + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones_like(aa, device=device) + if position_ids is None: + position_ids = torch.arange(L, device=device).expand_as(input_ids) + + # === ESM === + esmaa = self.af2_idx_to_esm_idx(aa, attention_mask) + + if masking_pattern is not None: + masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern) + else: + masked_aa = aa + mlm_targets = None + + # We get sequence and pair representations from whatever version of ESM / + # configuration we are using. The sequence representation esm_s is always + # present. The pair embedding esm_z may be present depending on the + # configuration of the model. If esm_z is not used by the model then it + # is returned as None here. + esm_s = self.compute_language_model_representations(esmaa) + + # Convert esm_s and esm_z, if present, to the precision used by the trunk and + # the structure module. These tensors may be a lower precision if, for example, + # we're running the language model in fp16 precision. + esm_s = esm_s.to(self.esm_s_combine.dtype) + + if cfg.esm_ablate_sequence: + esm_s = esm_s * 0 + + esm_s = esm_s.detach() + + # === preprocessing === + esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) + s_s_0 = self.esm_s_mlp(esm_s) + + s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim) + + if self.config.esmfold_config.embed_aa: + s_s_0 += self.embedding(masked_aa) + + structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles) + # Documenting what we expect: + structure = { + k: v + for k, v in structure.items() + if k + in [ + "s_z", + "s_s", + "frames", + "sidechain_frames", + "unnormalized_angles", + "angles", + "positions", + "states", + ] + } + + # Add BERT mask for the loss to use, if available. + if mlm_targets: + structure["mlm_targets"] = mlm_targets + + disto_logits = self.distogram_head(structure["s_z"]) + disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 + structure["distogram_logits"] = disto_logits + + lm_logits = self.lm_head(structure["s_s"]) + structure["lm_logits"] = lm_logits + + structure["aatype"] = aa + make_atom14_masks(structure) + # Of course, this doesn't respect the true mask because it doesn't know about it... 
+ # We're not going to properly mask change of index tensors: + # "residx_atom14_to_atom37", + # "residx_atom37_to_atom14", + for k in [ + "atom14_atom_exists", + "atom37_atom_exists", + ]: + structure[k] *= attention_mask.unsqueeze(-1) + structure["residue_index"] = position_ids + + lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins) + structure["lddt_head"] = lddt_head + plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins) + structure["plddt"] = plddt + + ptm_logits = self.ptm_head(structure["s_z"]) + structure["ptm_logits"] = ptm_logits + structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins) + structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)) + + return EsmForProteinFoldingOutput(**structure) + + def af2_idx_to_esm_idx(self, aa, mask): + # avoid indexing on different devices + if self.af2_to_esm.device != aa.device: + self.af2_to_esm = self.af2_to_esm.to(aa.device) + aa = (aa + 1).masked_fill(mask != 1, 0) + return self.af2_to_esm[aa] + + def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor: + device = next(self.parameters()).device + B, L = esmaa.shape # B = batch size, L = sequence length. + + if self.config.esmfold_config.bypass_lm: + esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device) + return esm_s + + bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx + bos = esmaa.new_full((B, 1), bosi) + eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx) + esmaa = torch.cat([bos, esmaa, eos], dim=1) + # Use the first padding index as eos during inference. + esmaa[range(B), (esmaa != 1).sum(1)] = eosi + + # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map) + # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models, + # esm_z is always None + esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"] + esm_s = torch.stack(esm_hidden_states, dim=2) + + esm_s = esm_s[:, 1:-1] # B, L, nLayers, C + + return esm_s + + def bert_mask(self, aa, esmaa, mask, pattern): + new_aa = aa.clone() + target = aa.clone() + new_esmaa = esmaa.clone() + new_aa[pattern == 1] = self.mask_idx + target[pattern != 1] = 0 + new_esmaa[pattern == 1] = self.esm_dict_mask_idx + return new_aa, new_esmaa, target + + @torch.no_grad() + def infer( + self, + seqs: Union[str, list[str]], + position_ids=None, + ): + if isinstance(seqs, str): + lst = [seqs] + else: + lst = seqs + # Returns the raw outputs of the model given an input sequence. 
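+ # Each sequence is one-hot encoded with residue_constants.sequence_to_onehot
+ # (unknown residues map to X), argmaxed back to aatype indices, and the batch is
+ # padded to a common length by collate_dense_tensors; position_ids default to
+ # 0..L-1 for every sequence.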
+ device = next(self.parameters()).device + aatype = collate_dense_tensors( + [ + torch.from_numpy( + residue_constants.sequence_to_onehot( + sequence=seq, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + ) + .to(device) + .argmax(dim=1) + for seq in lst + ] + ) # B=1 x L + mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) + position_ids = ( + torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) + if position_ids is None + else position_ids.to(device) + ) + if position_ids.ndim == 1: + position_ids = position_ids.unsqueeze(0) + return self.forward( + aatype, + mask, + position_ids=position_ids, + ) + + @staticmethod + def output_to_pdb(output: dict) -> list[str]: + """Returns the pbd (file) string from the model given the model output.""" + output = {k: v.to("cpu").numpy() for k, v in output.items()} + pdbs = [] + final_atom_positions = atom14_to_atom37(output["positions"][-1], output) + final_atom_mask = output["atom37_atom_exists"] + for i in range(output["aatype"].shape[0]): + aa = output["aatype"][i] + pred_pos = final_atom_positions[i] + mask = final_atom_mask[i] + resid = output["residue_index"][i] + 1 + pred = OFProtein( + aatype=aa, + atom_positions=pred_pos, + atom_mask=mask, + residue_index=resid, + b_factors=output["plddt"][i], + ) + pdbs.append(to_pdb(pred)) + return pdbs + + def infer_pdb(self, seqs, *args, **kwargs) -> str: + """Returns the pdb (file) string from the model given an input sequence.""" + assert isinstance(seqs, str) + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output)[0] + + def infer_pdbs(self, seqs: list[str], *args, **kwargs) -> list[str]: + """Returns the pdb (file) string from the model given an input sequence.""" + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output) + + +__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_tf_esm.py b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_tf_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..3fd066868f0e86cd2e130ad170eb38c4791f4f88 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/esm/modeling_tf_esm.py @@ -0,0 +1,1574 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
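+# The module below is the TensorFlow 2 / Keras port of the ESM-2 encoder: embeddings
+# (with optional rotary position embeddings), the transformer encoder stack, a pooler,
+# and the output heads whose TF output classes are imported further down, mirroring
+# the PyTorch implementation in modeling_esm.py.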
+"""PyTorch ESM model.""" + +from __future__ import annotations + +import os + +import numpy as np +import tensorflow as tf + +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFMaskedLMOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + shape_list, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, stable_softmax +from ...utils import logging +from .configuration_esm import EsmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D" +_CONFIG_FOR_DOC = "EsmConfig" + + +def rotate_half(x): + x1, x2 = tf.split(x, 2, axis=-1) + return tf.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, :, : tf.shape(x)[-2], :] + sin = sin[:, :, : tf.shape(x)[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + tf.linalg.matrix_transpose(x) # Transposes last two dimensions only + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = tf.reduce_sum(x, -1, keepdims=True) + a2 = tf.reduce_sum(x, -2, keepdims=True) + a12 = tf.reduce_sum(x, (-1, -2), keepdims=True) + + avg = a1 * a2 + avg = avg / a12 + normalized = x - avg + return normalized + + +class TFRotaryEmbedding(keras.layers.Layer): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int, name=None): + super().__init__(name=name) + # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation + # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at + # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the + # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that + # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our + # models give different outputs from the original. 
+ self.dim = dim + + def build(self, input_shape): + super().build(input_shape) + self.inv_freq = self.add_weight( + "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False + ) + self.inv_freq.assign( + 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) + ) + + def _compute_cos_sin(self, x, seq_dimension=2): + seq_len = tf.shape(x)[seq_dimension] + + t = tf.range(seq_len, dtype=self.inv_freq.dtype) + freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication + emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :] + + return tf.cos(emb), tf.sin(emb) + + def call(self, q: tf.Tensor, k: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]: + cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, cos_emb, sin_emb), + apply_rotary_pos_emb(k, cos_emb, sin_emb), + ) + + +class TFEsmContactPredictionHead(keras.layers.Layer): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + bias=True, + eos_idx: int = 2, + name=None, + ): + super().__init__(name=name) + self.eos_idx = eos_idx + self.in_features = in_features + self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "regression", None) is not None: + with tf.name_scope(self.regression.name): + self.regression.build((None, self.in_features)) + + def call(self, tokens, attentions): + # remove eos token attentions + eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype) + eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2) + attentions = attentions * eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + # remove cls token attentions + attentions = attentions[..., 1:, 1:] + batch_size, layers, heads, seqlen, _ = shape_list(attentions) + attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen)) + + # features: batch x channels x tokens x tokens (symmetric) + attentions = average_product_correct(symmetrize(attentions)) + attentions = tf.transpose(attentions, perm=(0, 2, 3, 1)) + return tf.squeeze(self.regression(attentions), 3) + + +class TFEsmEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
+ """ + + def __init__(self, config, name=None): + super().__init__(name=name) + self.word_embeddings = keras.layers.Embedding( + config.vocab_size, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="word_embeddings", + ) + self.position_embeddings = keras.layers.Embedding( + config.max_position_embeddings, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="position_embeddings", + ) + + if config.emb_layer_norm_before: + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + else: + self.layer_norm = None + # Matt: I think this line was copied incorrectly from BERT, disabling for now + # self.dropout = Dropout(config.hidden_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.position_ids = tf.range(config.max_position_embeddings)[None, :] + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + self.config = config + + def call( + self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout: + embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32) + masked_tokens = input_ids == self.mask_token_id + mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths + embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. 
+ # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: tf.Tensor + + Returns: tf.Tensor + """ + input_shape = shape_list(inputs_embeds)[:-1] + sequence_length = input_shape[1] + + position_ids = tf.range( + start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64 + ) + return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "word_embeddings", None) is not None: + with tf.name_scope(self.word_embeddings.name): + self.word_embeddings.build(None) + if getattr(self, "position_embeddings", None) is not None: + with tf.name_scope(self.position_embeddings.name): + self.position_embeddings.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + +class TFEsmSelfAttention(keras.layers.Layer): + def __init__(self, config, position_embedding_type=None, name=None): + super().__init__(name=name) + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = keras.layers.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size, + embeddings_initializer=get_initializer(config.initializer_range), + ) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings") + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, perm=(0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + output_attentions: bool | None = False, + training: bool = 
False, + ) -> tuple[tf.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
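+ # Shapes here: query_layer and key_layer are (batch, num_heads, seq_len, head_dim),
+ # so the matmul below yields (batch, num_heads, query_len, key_len) attention logits.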
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = shape_list(hidden_states)[1] + position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1) + position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = attention_probs @ value_layer + + context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "rotary_embeddings", None) is not None: + with tf.name_scope(self.rotary_embeddings.name): + self.rotary_embeddings.build(None) + + +class TFEsmSelfOutput(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = 
True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmAttention(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.self = TFEsmSelfAttention(config, name="self") + self.output_layer = TFEsmSelfOutput(config, name="output") + self.pruned_heads = set() + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training, + ) + attention_output = self.output_layer(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "output_layer", None) is not None: + with tf.name_scope(self.output_layer.name): + self.output_layer.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFEsmIntermediate(keras.layers.Layer): + def __init__(self, config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = tf.nn.gelu(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmOutput(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +class TFEsmLayer(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = TFEsmAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = 
config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFEsmAttention(config) + self.intermediate = TFEsmIntermediate(config, name="intermediate") + self.output_layer = TFEsmOutput(config, name="output") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layernorm_output = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(hidden_states=layernorm_output) + layer_output = self.output_layer( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "output_layer", None) is not None: + with tf.name_scope(self.output_layer.name): + self.output_layer.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, 
self.config.hidden_size]) + + +class TFEsmEncoder(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.config = config + self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.emb_layer_norm_after = keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="emb_layer_norm_after" + ) + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + training=False, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "emb_layer_norm_after", None) is not None: + with tf.name_scope(self.emb_layer_norm_after.name): + self.emb_layer_norm_after.build([None, None, self.config.hidden_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm +class TFEsmPooler(keras.layers.Layer): + def __init__(self, config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
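+        # Note: hidden_states is (batch_size, seq_len, hidden_size); taking index 0 along the
+        # sequence axis yields a (batch_size, hidden_size) tensor for the first token, which the
+        # tanh-activated dense layer below turns into the pooled output.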
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EsmConfig + base_model_prefix = "esm" + + +ESM_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a + regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior. + + Parameters: + config ([`EsmConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ESM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmMainLayer(keras.layers.Layer): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config, add_pooling_layer=True, name=None, **kwargs): + super().__init__(name=name, **kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFEsmEmbeddings(config, name="embeddings") + self.encoder = TFEsmEncoder(config, name="encoder") + self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None + + self.contact_head = TFEsmContactPredictionHead( + in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head" + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "contact_head", None) is not None: + with tf.name_scope(self.contact_head.name): + self.contact_head.build(None) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.word_embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + 
raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions + attns = tf.stack(attns, axis=1) # Matches the original model layout + # In the original model, attentions for padding tokens are completely zeroed out. 
+ # This makes no difference most of the time because the other tokens won't attend to them, + # but it does for the contact prediction task, which takes attentions as input, + # so we have to mimic that here. + attention_mask = tf.cast(attention_mask, attns.dtype) + attns *= attention_mask[:, None, None, None] + attns *= attention_mask[:, None, None, :, None] + return self.contact_head(tokens, attns) + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmModel(TFEsmPreTrainedModel): + def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
Set to `False` during training, `True` during generation + """ + outputs = self.esm( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + + +@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) +class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.lm_head = TFEsmLMHead(config, name="lm_head") + if config.tie_word_embeddings: + # Ensure word embeddings are built so that we actually have something to tie + with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")): + self.esm.embeddings.word_embeddings.build((None, None)) + self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0] + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + def get_lm_head(self): + return self.lm_head + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`dict[str, any]`, *optional*, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return TFMaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +class TFEsmLMHead(keras.layers.Layer): + """ESM Head for masked language modeling.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + if config.tie_word_embeddings: + self.decoder = None + else: + self.decoder = keras.layers.Dense( + config.vocab_size, + kernel_initializer=get_initializer(config.initializer_range), + name="decoder", + use_bias=False, + ) + self.config = config + + def build(self, input_shape=None): + # Separate bias to match the PT model and allow weight cross-loading to work + # Put it in the build so it gets the right name when adding it as a weight + if self.built: + return + self.built = True + self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings: + with tf.name_scope(self.decoder.name): + self.decoder.build([None, None, self.config.hidden_size]) + + def get_bias(self): + return {"bias": self.bias} + + def call(self, features): + x = self.dense(features) + x = tf.nn.gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + if self.config.tie_word_embeddings: + x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias + else: + x = self.decoder(x) + self.bias + return x + + +@add_start_docstrings( + """ + ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + ESM_START_DOCSTRING, +) +class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.classifier = TFEsmClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ESM_START_DOCSTRING, +) +class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense(config.num_labels, name="classifier") + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +class TFEsmClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.out_proj = keras.layers.Dense( + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + activation="linear", + name="out_proj", + ) + self.config = config + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. 
+        to [CLS])
+        x = self.dropout(x, training=training)
+        x = self.dense(x)
+        x = self.dropout(x, training=training)
+        x = self.out_proj(x)
+        return x
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.config.hidden_size])
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        x: tf.Tensor x:
+
+    Returns: tf.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = tf.cast(input_ids != padding_idx, tf.int64)
+    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
+    return incremental_indices + padding_idx
+
+
+__all__ = [
+    "TFEsmForMaskedLM",
+    "TFEsmForSequenceClassification",
+    "TFEsmForTokenClassification",
+    "TFEsmModel",
+    "TFEsmPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/esm/tokenization_esm.py b/venv/lib/python3.13/site-packages/transformers/models/esm/tokenization_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9705f7dbd33216a327eab04415ec57fe8e858d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/esm/tokenization_esm.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ESM."""
+
+import os
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+def load_vocab_file(vocab_file):
+    with open(vocab_file, "r") as f:
+        lines = f.read().splitlines()
+        return [l.strip() for l in lines]
+
+
+class EsmTokenizer(PreTrainedTokenizer):
+    """
+    Constructs an ESM tokenizer.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        cls_token="<cls>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        eos_token="<eos>",
+        **kwargs,
+    ):
+        self.all_tokens = load_vocab_file(vocab_file)
+        self._id_to_token = dict(enumerate(self.all_tokens))
+        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
+        super().__init__(
+            unk_token=unk_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            eos_token=eos_token,
+            **kwargs,
+        )
+
+        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
+ # none of them are special, but they all need special splitting. + + self.unique_no_split_tokens = self.all_tokens + self._update_trie(self.unique_no_split_tokens) + + def _convert_id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def _convert_token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def _tokenize(self, text, **kwargs): + return text.split() + + def get_vocab(self): + base_vocab = self._token_to_id.copy() + base_vocab.update(self.added_tokens_encoder) + return base_vocab + + def token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + cls = [self.cls_token_id] + sep = [self.eos_token_id] # No sep token in ESM vocabulary + if token_ids_1 is None: + if self.eos_token_id is None: + return cls + token_ids_0 + else: + return cls + token_ids_0 + sep + elif self.eos_token_id is None: + raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!") + return cls + token_ids_0 + sep + token_ids_1 + sep # Multiple inputs always have an EOS token + + def get_special_tokens_mask( + self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`list[int]`): + List of ids of the first sequence. + token_ids_1 (`list[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return [1 if token in self.all_special_ids else 0 for token in token_ids_0] + mask = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + mask += [0] * len(token_ids_1) + [1] + return mask + + def save_vocabulary(self, save_directory, filename_prefix): + vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt") + with open(vocab_file, "w") as f: + f.write("\n".join(self.all_tokens)) + return (vocab_file,) + + @property + def vocab_size(self) -> int: + return len(self.all_tokens) + + +__all__ = ["EsmTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/evolla/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/evolla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09be74f03397074c49bf0df9a34b67f695856131 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/evolla/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_evolla import * + from .modeling_evolla import * + from .processing_evolla import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/evolla/configuration_evolla.py b/venv/lib/python3.13/site-packages/transformers/models/evolla/configuration_evolla.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d0361e95fb996075f2c14e7531b14164d5545e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/evolla/configuration_evolla.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evolla model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SaProtConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`EvollaSaProtProteinEncoder`]. It is used to instantiate a + SaProt model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 446): + Vocabulary size of the protein sequence model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`EvollaModel`]. + mask_token_id (`int`, *optional*, defaults to 4): + The id of the *mask* token in the protein sequence model. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the *padding* token in the protein sequence model. + hidden_size (`int`, *optional*, defaults to 1280): + Dimensionality of the protein sequence model layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 33): + Number of hidden layers in the protein sequence model. + num_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the protein sequence model. 
+ intermediate_size (`int`, *optional*, defaults to 5120): + Dimensionality of the intermediate layers in the protein sequence model. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the hidden layers in the protein sequence model. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities in the protein sequence model. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that the protein sequence model might ever be used with. Typically set this to + something large just in case (e.g., 512 or 1024 or 2048). + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for the layer normalization layer in the protein sequence model. + position_embedding_type (`str`, *optional*, defaults to `"rotary"`): + The type of position embedding to use in the protein sequence model. Currently only `"rotary"` is supported. + emb_layer_norm_before (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization before the position embedding in the protein sequence model. + token_dropout (`bool`, *optional*, defaults to `True`): + Whether to apply dropout to the tokens in the protein sequence model.""" + + def __init__( + self, + vocab_size=446, + mask_token_id=4, + pad_token_id=1, + hidden_size=1280, + num_hidden_layers=33, + num_attention_heads=20, + intermediate_size=5120, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1026, + initializer_range=0.02, + layer_norm_eps=1e-05, + position_embedding_type="rotary", + emb_layer_norm_before=False, + token_dropout=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.emb_layer_norm_before = emb_layer_norm_before + self.token_dropout = token_dropout + + +class EvollaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EvollaModel`]. It is used to instantiate an + Evolla model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Evolla-10B. + + e.g. [westlake-repl/Evolla-10B-hf](https://huggingface.co/westlake-repl/Evolla-10B-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + protein_encoder_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SaProtConfig`]. + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the Evolla llama model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`EvollaModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the llama layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 14336): + Dimensionality of the intermediate layers in the llama model. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the llama model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the llama model. + num_key_value_heads (`int`, *optional*, defaults to 8): + Number of key-value pairs for each attention layer in the llama model. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the llama model. If string, `"gelu"`, `"relu"`, + `"selu"` and `"silu"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for the RMS-norm layer in the llama model. + rope_theta (`float`, *optional*, defaults to 500000.0): + The threshold value for the RoPE layer in the llama model. + rope_scaling (`float`, *optional*): + The scaling factor for the RoPE layer in the llama model. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention layer. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layer. + aligner_ffn_mult (`int`, *optional*, defaults to 4): + The FFN multiplier for the aligner layer. + aligner_enable_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the aligner layer. + aligner_attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities in the aligner layer. + aligner_num_add_layers (`int`, *optional*, defaults to 8): + The number of additional layers for the aligner layer. + resampler_depth (`int`, *optional*, defaults to 6): + The depth of the resampler layer in the llama model. + resampler_dim_head (`int`, *optional*, defaults to 64): + The dimension of the heads in the resampler layer in the llama model. + resampler_heads (`int`, *optional*, defaults to 8): + The number of heads in the resampler layer in the llama model. + resampler_num_latents (`int`, *optional*, defaults to 64): + The number of latents in the resampler layer in the llama model. + resampler_ff_mult (`int`, *optional*, defaults to 4): + The FFN multiplier for the resampler layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*, defaults to 128000): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*, defaults to 128009): + The id of the *end-of-sequence* token. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the input and output word embeddings. 
+ + Example: + + ```python + >>> from transformers import EvollaModel, EvollaConfig + + >>> # Initializing a Evolla evolla-10b style configuration + >>> configuration = EvollaConfig() + + >>> # Initializing a model from the evolla-10b style configuration + >>> model = EvollaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "EvollaModel" + sub_configs = {"protein_encoder_config": SaProtConfig} + + def __init__( + self, + protein_encoder_config=None, + vocab_size=128256, # llama vocab size + hidden_size=4096, # llama hidden size + intermediate_size=14336, # llama intermediate size + num_hidden_layers=32, # llama num layers + num_attention_heads=32, # llama num heads + num_key_value_heads=8, # llama num key-value heads + hidden_act="silu", # llama activation function + max_position_embeddings=8192, # llama rope max length + rms_norm_eps=1e-05, + rope_theta=500000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + aligner_ffn_mult=4, + aligner_enable_bias=True, + aligner_attention_probs_dropout_prob=0.1, + aligner_num_add_layers=8, + resampler_depth=6, + resampler_dim_head=64, + resampler_heads=8, + resampler_num_latents=64, + resampler_ff_mult=4, + initializer_range=0.02, + pad_token_id=None, + bos_token_id=128000, + eos_token_id=128009, + use_cache=False, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.aligner_ffn_mult = aligner_ffn_mult + self.aligner_enable_bias = aligner_enable_bias + self.aligner_attention_probs_dropout_prob = aligner_attention_probs_dropout_prob + self.aligner_num_add_layers = aligner_num_add_layers + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.resampler_depth = resampler_depth + self.resampler_dim_head = resampler_dim_head + self.resampler_heads = resampler_heads + self.resampler_num_latents = resampler_num_latents + self.resampler_ff_mult = resampler_ff_mult + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # Subconfig + if protein_encoder_config is None: + protein_encoder_config = {} + logger.info("`protein_encoder_config` is `None`. 
Initializing the `SaProtConfig` with default values.") + self.protein_encoder_config = SaProtConfig(**protein_encoder_config) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["EvollaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/evolla/modeling_evolla.py b/venv/lib/python3.13/site-packages/transformers/models/evolla/modeling_evolla.py new file mode 100644 index 0000000000000000000000000000000000000000..75d68ba09cda4ac34d199cd26f2b001a5b80507b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/evolla/modeling_evolla.py @@ -0,0 +1,1585 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/evolla/modular_evolla.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_evolla.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPast, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + ModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype +from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_evolla import EvollaConfig, SaProtConfig + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
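+    # Worked example with padding_idx=1: input_ids [[5, 6, 7, 1]] gives mask [[1, 1, 1, 0]],
+    # cumsum * mask = [[1, 2, 3, 0]], and the returned position ids are [[2, 3, 4, 1]],
+    # i.e. real tokens count up from padding_idx + 1 while padding positions keep padding_idx.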
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +class EvollaSaProtEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + if config.emb_layer_norm_before: + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.layer_norm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + if self.position_embedding_type == "absolute": + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + # remove the position_ids in EsmEmbeddings + self.position_ids = None + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: EVOLLA_SA_PROT has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). 
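+        # Concretely: embeddings at mask-token positions are zeroed out, then every embedding is
+        # rescaled by (1 - 0.15 * 0.8) / (1 - observed_mask_fraction), e.g. a sequence with no
+        # masked tokens is scaled by 0.88 and a sequence with 12% masked tokens is left unscaled.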
+ if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths + embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( + embeddings.dtype + ) + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +def rotate_half_esm(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_esm(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half_esm(x) * sin) + + +class EvollaSaProtRotaryEmbedding(nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. 
+ """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype), + apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype), + ) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], +): + # EVOLLA_SA_PROT applies relative position embeddings and we don't copy from Llama + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if hasattr(module, "position_embedding_type") and module.position_embedding_type in [ + "relative_key", + "relative_key_query", + ]: + seq_length = query.shape[2] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + + if module.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + elif module.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) + relative_position_scores = relative_position_scores_query + relative_position_scores_key + + attn_weights = attn_weights + relative_position_scores + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return 
attn_output, attn_weights + + +class EvollaSaProtSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False): + super().__init__() + self.config = config + + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = config.attention_probs_dropout_prob + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + self.scaling = 1.0 + self.is_causal = self.is_decoder and not is_cross_attention + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + batch_size, seq_length = hidden_states.shape[:-1] + hidden_shape = (batch_size, seq_length, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # EVOLLA_SA_PROT code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.position_embedding_type in ["relative_key", "relative_key_query"]: + raise ValueError( + f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. 
" + "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`" + ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + return attn_output, attn_weights + + +class EvollaSaProtSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EvollaSaProtAttention(nn.Module): + def __init__(self, config, layer_idx=None, is_cross_attention=False): + super().__init__() + self.self = EvollaSaProtSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention) + self.output = EvollaSaProtSelfOutput(config) + self.pruned_heads = set() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + hidden_states_ln = self.LayerNorm(hidden_states) + attn_output, _ = self.self( + hidden_states_ln, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + attn_output = self.output(attn_output, hidden_states) + return attn_output + + +def gelu(x): + """ + This is the gelu implementation from the original EVOLLA_SA_PROT repo. Using F.gelu yields subtly wrong results. 
+ """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class EvollaSaProtIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = gelu(hidden_states) + return hidden_states + + +class EvollaSaProtOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EvollaSaProtLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = EvollaSaProtAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = EvollaSaProtAttention(config, is_cross_attention=True) + self.intermediate = EvollaSaProtIntermediate(config) + self.output = EvollaSaProtOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + **kwargs, + ) + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + attention_output = self.crossattention( + attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + layer_output = self.feed_forward_chunk(attention_output) + return layer_output + + def feed_forward_chunk(self, attention_output): + attention_output_ln = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(attention_output_ln) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class EvollaSaProtEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)]) + self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states) + + +class EvollaSaProtPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class EvollaSaProtPreTrainedModel(PreTrainedModel): + config: SaProtConfig + _no_split_modules = ["EvollaSaProtLayer"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": EvollaSaProtLayer, + "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")], + "cross_attentions": [ + OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"), + ], + } + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): + def __init__(self, config: SaProtConfig): + super().__init__(config) + self.embeddings = EvollaSaProtEmbeddings(config) + self.encoder = EvollaSaProtEncoder(config) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask) + sequence_output = encoder_outputs[0] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: tuple[int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = get_parameter_dtype(self) + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely.
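+ # For example (illustrative dtype and mask): in float16, torch.finfo(dtype).min is -65504, so the
+ # two lines below turn a padding mask of [1, 1, 0] into additive biases of [0.0, 0.0, -65504.0];
+ # adding those to the attention logits drives the masked position's softmax weight to effectively zero.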
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + +class EvollaSequenceCompressorAttention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents, mask): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D); n2: num of latent tokens + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk( + 2, dim=-1 + ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads + + q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3) + k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3) + v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3) + q = q * self.scale # batch_size, num_heads, num_latents, dim_head + + # attention + sim = torch.matmul(q, k.transpose(-1, -2)) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + bs, nh, skd, okd = sim.shape + ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd) + mask_exp = mask[:, None, None, :] + ones_exp = ones[None, :, :, None] + mask = mask_exp * ones_exp + + sim = sim.masked_fill((1 - mask).bool(), -1e4) + attn = sim.softmax(dim=-1) + out = torch.matmul(attn, v) + out = out.permute(0, 2, 1, 3) + + # [batch, seq, head, features] -> [batch, seq, head*features] + out = out.reshape(out.size(0), out.size(1), -1) + + return self.to_out(out) + + +class EvollaFeedForward(nn.Module): + def __init__(self, dim, mult=4): + super().__init__() + inner_dim = int(dim * mult) + + self.norm = nn.LayerNorm(dim) + self.fc1 = nn.Linear(dim, inner_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + return self.fc2(self.activation(self.fc1(self.norm(x)))) + + +class EvollaSequenceCompressorResampler(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + protein_repr_dim = config.protein_encoder_config.hidden_size + self.num_latents = config.resampler_num_latents + self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True) + self.layers = nn.ModuleList([]) + for _ in range(config.resampler_depth): + self.layers.append( + nn.ModuleList( + [ + EvollaSequenceCompressorAttention( + dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads + ), + EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(config.hidden_size) + self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size) + + def forward(self, embeds, mask): + b = embeds.shape[0] + + bs, _ = mask.shape # bs, max_protein_length + latent_mask = torch.ones(bs, self.num_latents).to(mask.device) + mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents + + # blocks + ones = torch.ones(b).to(self.latents.device) + latents = self.latents[None] * 
ones.view(-1, 1, 1) # [b,n,d] + latents = latents.to(embeds.dtype) + for attn, ff in self.layers: + latents = attn(embeds, latents, mask) + latents + latents = ff(latents) + latents + + transformed_feature = self.protein_projector(latents) + + return self.norm(transformed_feature) + + +@dataclass +@auto_docstring +class EvollaProteinEncoderModelOutput(ModelOutput): + sequence_compressor_output: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EvollaProteinEncoder(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config) + self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config) + + @can_return_tuple + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs): + protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + protein_embeds = protein_output.last_hidden_state + sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask) + + return EvollaProteinEncoderModelOutput( + sequence_compressor_output=sequence_repr, + last_hidden_state=protein_output.last_hidden_state, + ) + + +class EvollaSequenceAlignerCrossAttention(nn.Module): + def __init__( + self, + config, + protein_encoder_dim: Optional[int] = None, + structure_encoder_dim: Optional[int] = None, + msa_encoder_dim: Optional[int] = None, + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.scale = self.num_attention_heads**-0.5 + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob + enable_bias = config.aligner_enable_bias + ffn_mult = config.aligner_ffn_mult + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + if protein_encoder_dim is not None: + self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + else: + self.key_protein = None + self.value_protein = None + + if structure_encoder_dim is not None: + self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + else: + self.key_structure = None + self.value_structure = None + + if msa_encoder_dim is not None: + self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + else: + self.key_msa = None + self.value_msa = None + + self.attention_norm = EvollaRMSNorm(self.hidden_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias) + + self.ff = EvollaFeedForward(self.hidden_size, ffn_mult) + self.gate_attention = nn.Parameter(torch.tensor([0.0])) + self.gate_ffw = nn.Parameter(torch.tensor([0.0])) + + def cross_attention( + self, + query_states, + protein_key_value_states, + structure_key_value_states, + msa_key_value_states, + query_attn_mask, + protein_kv_attn_mask, + structure_kv_attn_mask, + msa_kv_attn_mask, + ): + """ + query_states: text + key_value_states: protein + query_states: [bs, 
query_seq_len, dim] + key_value_states: [bs, kv_seq_len, dim] + query_attn_mask: [bs, query_seq_len] + kv_attn_mask: [bs, kv_seq_len] + """ + + # Concatenate protein and structure + kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask] + kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None] + if not kv_attn_mask: + raise ValueError("At least one modality should be provided for cross attention.") + kv_attn_mask = torch.cat(kv_attn_mask, dim=1) + + query_layer = self.attention_norm(query_states) + + # Warning: This place might cause issues, refers to + # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13 + # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable + # Apply linear transformation to input_query, input_key, and input_value + query_layer = self.query(query_layer) # [bs, querylength, dim] + + if self.key_protein is not None and self.value_protein is not None: + protein_key_value_states = protein_key_value_states.to(query_states) + key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim] + value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim] + else: + key_layer_protein = None + value_layer_protein = None + + if self.key_structure is not None and self.value_structure is not None: + structure_key_value_states = structure_key_value_states.to(query_states) + key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim] + value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim] + else: + key_layer_structure = None + value_layer_structure = None + + if self.key_msa is not None and self.value_msa is not None: + msa_key_value_states = msa_key_value_states.to(query_states) + key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim] + value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim] + else: + key_layer_msa = None + value_layer_msa = None + + key_layer = [key_layer_protein, key_layer_structure, key_layer_msa] + key_layer = [_ for _ in key_layer if _ is not None] + key_layer = torch.cat(key_layer, dim=1) + + value_layer = [value_layer_protein, value_layer_structure, value_layer_msa] + value_layer = [_ for _ in value_layer if _ is not None] + value_layer = torch.cat(value_layer, dim=1) + + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3) + + new_key_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3) + + new_value_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3) + + query_layer = query_layer * self.scale + + # attention_mask: [bs, 1, querylength, keylength] + if query_attn_mask is None: + query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device) + attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :] + # Compute the scaled dot-product attention scores + attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength] + attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # 
To stabilize score + attention_scores = attn_weights.masked_fill( + (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min + ) # [bs, numheads, querylength, keylength] + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads] + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + return context_layer + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + query_states, + protein_kv_states, + structure_kv_states, + msa_kv_states, + query_attn_mask, + protein_kv_attn_mask=None, + structure_kv_attn_mask=None, + msa_kv_attn_mask=None, + protein_batch_mask=None, + structure_batch_mask=None, + msa_batch_mask=None, + past_key_values=None, + ): + if protein_kv_states is not None: + bs, protein_kv_seq_len, dim = protein_kv_states.shape + if protein_kv_attn_mask is None: + protein_kv_attn_mask = ( + torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device) + * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T + ).to(protein_kv_states.device) + else: + protein_kv_attn_mask = None + + if structure_kv_states is not None: + bs, structure_kv_seq_len, dim = structure_kv_states.shape + if structure_kv_attn_mask is None: + structure_kv_attn_mask = ( + torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device) + * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T + ).to(structure_kv_states.device) + else: + structure_kv_attn_mask = None + + if msa_kv_states is not None: + bs, msa_kv_seq_len, dim = msa_kv_states.shape + if msa_kv_attn_mask is None: + msa_kv_attn_mask = ( + torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device) + * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T + ).to(msa_kv_states.device) + else: + msa_kv_attn_mask = None + hidden_states = query_states + # only when there's at least one valid modality, crossattention will be performed + if ( + (protein_kv_states is not None and protein_kv_attn_mask.any()) + or (structure_kv_states is not None and structure_kv_attn_mask.any()) + or (msa_kv_states is not None and msa_kv_attn_mask.any()) + ): + residual = hidden_states + hidden_states = self.cross_attention( + query_states=hidden_states, + protein_key_value_states=protein_kv_states, + structure_key_value_states=structure_kv_states, + msa_key_value_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_kv_attn_mask=protein_kv_attn_mask, + structure_kv_attn_mask=structure_kv_attn_mask, + msa_kv_attn_mask=msa_kv_attn_mask, + ) # [bs, query_seq_len, dim] + # tanh gate + hidden_states = torch.tanh(self.gate_attention) * hidden_states + + hidden_states = residual + hidden_states # input_query + + residual = hidden_states + hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw) + hidden_states = residual + hidden_states + + return hidden_states + + +@use_kernel_forward_from_hub("RMSNorm") +class EvollaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + EvollaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + 
hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class EvollaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: EvollaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class EvollaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class EvollaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = 
ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EvollaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx) + + self.mlp = EvollaMLP(config) + self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0: + self.adapter = EvollaSequenceAlignerCrossAttention( + config, + protein_encoder_dim=config.hidden_size, + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + protein_kv_states: Optional[torch.Tensor] = None, + structure_kv_states: Optional[torch.Tensor] = None, + msa_kv_states: Optional[torch.Tensor] = None, + protein_batch_mask: Optional[torch.Tensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + query_attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if hasattr(self, "adapter"): + hidden_states = self.adapter( + query_states=hidden_states, + protein_kv_states=protein_kv_states, + structure_kv_states=structure_kv_states, + msa_kv_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + ) + + return hidden_states + + +@auto_docstring +class EvollaPreTrainedModel(PreTrainedModel): + config: EvollaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "EvollaDecoderLayer", + "EvollaSequenceCompressorResampler", + "EvollaSequenceAlignerCrossAttention", + ] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_sdpa = True + _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder` + + _can_compile_fullgraph = True + _supports_attention_backend = False + 
_can_record_outputs = { + "hidden_states": EvollaDecoderLayer, + "attentions": EvollaAttention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + super()._init_weights(module) + if isinstance(module, EvollaSequenceAlignerCrossAttention): + module.gate_attention.zero_() + module.gate_ffw.zero_() + module.attention_norm.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceCompressorResampler): + module.latents.data.normal_(mean=0.0, std=std) + + +class EvollaModel(EvollaPreTrainedModel): + def __init__(self, config: EvollaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.protein_encoder = EvollaProteinEncoder(config=config) + self.layers = nn.ModuleList( + [ + EvollaDecoderLayer( + config=config, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = EvollaRotaryEmbedding(config=config) + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + protein_input_ids: Optional[torch.LongTensor] = None, + protein_attention_mask: Optional[torch.Tensor] = None, + structure_feats: Optional[torch.FloatTensor] = None, + msa_feats: Optional[torch.FloatTensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + structure_feats (torch.FloatTensor): + The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + msa_feats (torch.FloatTensor): + The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + structure_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummy input for now. + msa_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + protein_feats = None + protein_batch_mask = None + # If provided, actually compute them + if protein_input_ids is not None and protein_attention_mask is not None: + protein_outputs = self.protein_encoder( + input_ids=protein_input_ids, + attention_mask=protein_attention_mask, + ) + protein_feats = protein_outputs.sequence_compressor_output + protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + protein_kv_states=protein_feats, + structure_kv_states=structure_feats, + msa_kv_states=msa_feats, + protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + query_attn_mask=attention_mask, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + return output + + +class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.model = EvollaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + return self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, # text input ids + attention_mask: Optional[torch.Tensor] = None, # text attention mask + inputs_embeds: Optional[torch.FloatTensor] = None, # text input embeddings + labels: Optional[torch.LongTensor] = None, + protein_input_ids: Optional[torch.LongTensor] = None, + protein_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs, + ): + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. 
+ + Example: + + ```python + >>> from transformers import EvollaProcessor, EvollaForProteinText2Text + >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf") + >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf") + + >>> protein_information = { + "aa_seq": "your amino acid sequence", + "foldseek": "your foldseek sequence", + } + >>> question = "What is the function of this protein?" + >>> message = [ + {"role": "system", "content": "You are an AI expert that can answer any questions about protein."}, + {"role": "user", "content": question}, + ] + + >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest") + >>> outputs = model.generate(**inputs) + + >>> print(processor.batch_decode(outputs, skip_special_tokens=True)) + ```""" + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + protein_input_ids=protein_input_ids, + protein_attention_mask=protein_attention_mask, + use_cache=use_cache, + **kwargs, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + lm_outputs = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return lm_outputs + + +__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/evolla/modular_evolla.py b/venv/lib/python3.13/site-packages/transformers/models/evolla/modular_evolla.py new file mode 100644 index 0000000000000000000000000000000000000000..496a284935abd6031f03ccfec8bfe208d70c1a5f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/evolla/modular_evolla.py @@ -0,0 +1,1023 @@ +# coding=utf-8 +# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
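+ # Note: this modular file is the one to edit for the Evolla model; the flat modeling_evolla.py
+ # shown above is auto-generated from it. The classes below are composed by subclassing the existing
+ # ESM (protein encoder) and Llama (text decoder) building blocks imported a few lines further down.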
+ +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + ModelOutput, +) +from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype +from ...utils import ( + auto_docstring, + can_return_tuple, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from ..esm.modeling_esm import ( + EsmAttention, + EsmEmbeddings, + EsmEncoder, + EsmIntermediate, + EsmLayer, + EsmOutput, + EsmPooler, + EsmSelfAttention, + EsmSelfOutput, +) +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from .configuration_evolla import EvollaConfig, SaProtConfig + + +logger = logging.get_logger(__name__) + + +class EvollaSaProtEmbeddings(EsmEmbeddings): + def __init__(self, config): + super().__init__(config) + # remove the position_ids in EsmEmbeddings + self.position_ids = None + + +def rotate_half_esm(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_esm(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half_esm(x) * sin) + + +class EvollaSaProtRotaryEmbedding(nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. 
+ """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype), + apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype), + ) + + +class EvollaSaProtSelfAttention(EsmSelfAttention): + def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False): + nn.Module.__init__(self) + self.config = config + + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = config.attention_probs_dropout_prob + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + self.scaling = 1.0 + self.is_causal = self.is_decoder and not is_cross_attention + + +class EvollaSaProtSelfOutput(EsmSelfOutput): + pass + + +class EvollaSaProtAttention(EsmAttention): + pass + + +class EvollaSaProtIntermediate(EsmIntermediate): + pass + + +class EvollaSaProtOutput(EsmOutput): + pass + + +class EvollaSaProtLayer(EsmLayer): + pass + + +class EvollaSaProtEncoder(EsmEncoder): + pass + + +class EvollaSaProtPooler(EsmPooler): + pass + + +@auto_docstring +class EvollaSaProtPreTrainedModel(PreTrainedModel): + config: SaProtConfig + 
_no_split_modules = ["EvollaSaProtLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": EvollaSaProtLayer,
+        "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
+        "cross_attentions": [
+            OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
+        ],
+    }
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
+    def __init__(self, config: SaProtConfig):
+        super().__init__(config)
+        self.embeddings = EvollaSaProtEmbeddings(config)
+        self.encoder = EvollaSaProtEncoder(config)
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+        class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @check_model_inputs()
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        input_shape = input_ids.size()
+        batch_size, seq_length = input_shape
+
+        device = input_ids.device
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+        inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask)
+        sequence_output = encoder_outputs[0]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: tuple[int],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (`Tuple[int]`):
+                The shape of the input to the model.
+
+        Returns:
+            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
+        """
+        if dtype is None:
+            dtype = get_parameter_dtype(self)
+
+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if self.config.is_decoder:
+                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
+        return extended_attention_mask
+
+
+class EvollaSequenceCompressorAttention(nn.Module):
+    def __init__(self, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents, mask):
+        """
+        Args:
+            x (torch.Tensor): protein sequence features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D); n2: num of latent tokens
+            mask (torch.Tensor): padding mask for `x`
+                shape (b, n1)
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        h = self.heads
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(
+            2, dim=-1
+        )  # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
+
+        q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3)
+        k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3)
+        v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3)
+        q = q * self.scale  # batch_size, num_heads, num_latents, dim_head
+
+        # attention
+        sim = torch.matmul(q, k.transpose(-1, -2))
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        bs, nh, skd, okd = sim.shape
+        ones = torch.ones(nh, skd).to(mask.device)  # Create a tensor of ones with shape (nh, skd)
+        mask_exp = mask[:, None, None, :]
+        ones_exp = ones[None, :, :, None]
+        mask = mask_exp * ones_exp
+
+        sim = sim.masked_fill((1 - mask).bool(), -1e4)
+        attn = sim.softmax(dim=-1)
+        out = torch.matmul(attn, v)
+        out = out.permute(0, 2, 1, 3)
+
+        # [batch,
seq, head, features] -> [batch, seq, head*features] + out = out.reshape(out.size(0), out.size(1), -1) + + return self.to_out(out) + + +class EvollaFeedForward(nn.Module): + def __init__(self, dim, mult=4): + super().__init__() + inner_dim = int(dim * mult) + + self.norm = nn.LayerNorm(dim) + self.fc1 = nn.Linear(dim, inner_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + return self.fc2(self.activation(self.fc1(self.norm(x)))) + + +class EvollaSequenceCompressorResampler(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + protein_repr_dim = config.protein_encoder_config.hidden_size + self.num_latents = config.resampler_num_latents + self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True) + self.layers = nn.ModuleList([]) + for _ in range(config.resampler_depth): + self.layers.append( + nn.ModuleList( + [ + EvollaSequenceCompressorAttention( + dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads + ), + EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(config.hidden_size) + self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size) + + def forward(self, embeds, mask): + b = embeds.shape[0] + + bs, _ = mask.shape # bs, max_protein_length + latent_mask = torch.ones(bs, self.num_latents).to(mask.device) + mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents + + # blocks + ones = torch.ones(b).to(self.latents.device) + latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d] + latents = latents.to(embeds.dtype) + for attn, ff in self.layers: + latents = attn(embeds, latents, mask) + latents + latents = ff(latents) + latents + + transformed_feature = self.protein_projector(latents) + + return self.norm(transformed_feature) + + +@dataclass +@auto_docstring +class EvollaProteinEncoderModelOutput(ModelOutput): + sequence_compressor_output: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EvollaProteinEncoder(nn.Module): + def __init__(self, config: EvollaConfig): + super().__init__() + self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config) + self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config) + + @can_return_tuple + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs): + protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + protein_embeds = protein_output.last_hidden_state + sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask) + + return EvollaProteinEncoderModelOutput( + sequence_compressor_output=sequence_repr, + last_hidden_state=protein_output.last_hidden_state, + ) + + +class EvollaSequenceAlignerCrossAttention(nn.Module): + def __init__( + self, + config, + protein_encoder_dim: Optional[int] = None, + structure_encoder_dim: Optional[int] = None, + msa_encoder_dim: Optional[int] = None, + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.scale = self.num_attention_heads**-0.5 + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * 
self.attention_head_size + + attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob + enable_bias = config.aligner_enable_bias + ffn_mult = config.aligner_ffn_mult + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + if protein_encoder_dim is not None: + self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size) + else: + self.key_protein = None + self.value_protein = None + + if structure_encoder_dim is not None: + self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size) + else: + self.key_structure = None + self.value_structure = None + + if msa_encoder_dim is not None: + self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size) + else: + self.key_msa = None + self.value_msa = None + + self.attention_norm = EvollaRMSNorm(self.hidden_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias) + + self.ff = EvollaFeedForward(self.hidden_size, ffn_mult) + self.gate_attention = nn.Parameter(torch.tensor([0.0])) + self.gate_ffw = nn.Parameter(torch.tensor([0.0])) + + def cross_attention( + self, + query_states, + protein_key_value_states, + structure_key_value_states, + msa_key_value_states, + query_attn_mask, + protein_kv_attn_mask, + structure_kv_attn_mask, + msa_kv_attn_mask, + ): + """ + query_states: text + key_value_states: protein + query_states: [bs, query_seq_len, dim] + key_value_states: [bs, kv_seq_len, dim] + query_attn_mask: [bs, query_seq_len] + kv_attn_mask: [bs, kv_seq_len] + """ + + # Concatenate protein and structure + kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask] + kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None] + if not kv_attn_mask: + raise ValueError("At least one modality should be provided for cross attention.") + kv_attn_mask = torch.cat(kv_attn_mask, dim=1) + + query_layer = self.attention_norm(query_states) + + # Warning: This place might cause issues, refers to + # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13 + # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable + # Apply linear transformation to input_query, input_key, and input_value + query_layer = self.query(query_layer) # [bs, querylength, dim] + + if self.key_protein is not None and self.value_protein is not None: + protein_key_value_states = protein_key_value_states.to(query_states) + key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim] + value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim] + else: + key_layer_protein = None + value_layer_protein = None + + if self.key_structure is not None and self.value_structure is not None: + structure_key_value_states = structure_key_value_states.to(query_states) + key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim] + value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim] + else: + key_layer_structure = None + value_layer_structure = None + + if self.key_msa is not None and self.value_msa is not None: + msa_key_value_states = msa_key_value_states.to(query_states) + key_layer_msa = 
self.key_msa(msa_key_value_states) # [bs, keylength, dim] + value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim] + else: + key_layer_msa = None + value_layer_msa = None + + key_layer = [key_layer_protein, key_layer_structure, key_layer_msa] + key_layer = [_ for _ in key_layer if _ is not None] + key_layer = torch.cat(key_layer, dim=1) + + value_layer = [value_layer_protein, value_layer_structure, value_layer_msa] + value_layer = [_ for _ in value_layer if _ is not None] + value_layer = torch.cat(value_layer, dim=1) + + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3) + + new_key_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3) + + new_value_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3) + + query_layer = query_layer * self.scale + + # attention_mask: [bs, 1, querylength, keylength] + if query_attn_mask is None: + query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device) + attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :] + # Compute the scaled dot-product attention scores + attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength] + attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score + attention_scores = attn_weights.masked_fill( + (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min + ) # [bs, numheads, querylength, keylength] + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads] + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + return context_layer + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + query_states, + protein_kv_states, + structure_kv_states, + msa_kv_states, + query_attn_mask, + protein_kv_attn_mask=None, + structure_kv_attn_mask=None, + msa_kv_attn_mask=None, + protein_batch_mask=None, + structure_batch_mask=None, + msa_batch_mask=None, + past_key_values=None, + ): + if protein_kv_states is not None: + bs, protein_kv_seq_len, dim = protein_kv_states.shape + if protein_kv_attn_mask is None: + protein_kv_attn_mask = ( + torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device) + * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T + ).to(protein_kv_states.device) + else: + protein_kv_attn_mask = None + + if structure_kv_states is not None: + bs, structure_kv_seq_len, dim = structure_kv_states.shape + if structure_kv_attn_mask is None: + structure_kv_attn_mask = ( + torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device) + * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T + ).to(structure_kv_states.device) + else: + structure_kv_attn_mask = None + + if msa_kv_states is not None: + bs, msa_kv_seq_len, dim = 
msa_kv_states.shape + if msa_kv_attn_mask is None: + msa_kv_attn_mask = ( + torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device) + * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T + ).to(msa_kv_states.device) + else: + msa_kv_attn_mask = None + hidden_states = query_states + # only when there's at least one valid modality, crossattention will be performed + if ( + (protein_kv_states is not None and protein_kv_attn_mask.any()) + or (structure_kv_states is not None and structure_kv_attn_mask.any()) + or (msa_kv_states is not None and msa_kv_attn_mask.any()) + ): + residual = hidden_states + hidden_states = self.cross_attention( + query_states=hidden_states, + protein_key_value_states=protein_kv_states, + structure_key_value_states=structure_kv_states, + msa_key_value_states=msa_kv_states, + query_attn_mask=query_attn_mask, + protein_kv_attn_mask=protein_kv_attn_mask, + structure_kv_attn_mask=structure_kv_attn_mask, + msa_kv_attn_mask=msa_kv_attn_mask, + ) # [bs, query_seq_len, dim] + # tanh gate + hidden_states = torch.tanh(self.gate_attention) * hidden_states + + hidden_states = residual + hidden_states # input_query + + residual = hidden_states + hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw) + hidden_states = residual + hidden_states + + return hidden_states + + +class EvollaRMSNorm(LlamaRMSNorm): + pass + + +class EvollaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class EvollaMLP(LlamaMLP): + pass + + +class EvollaAttention(LlamaAttention): + pass + + +class EvollaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: EvollaConfig, layer_idx: int): + super().__init__(config, layer_idx) + if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0: + self.adapter = EvollaSequenceAlignerCrossAttention( + config, + protein_encoder_dim=config.hidden_size, + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + protein_kv_states: Optional[torch.Tensor] = None, + structure_kv_states: Optional[torch.Tensor] = None, + msa_kv_states: Optional[torch.Tensor] = None, + protein_batch_mask: Optional[torch.Tensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + query_attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if hasattr(self, "adapter"): + hidden_states = self.adapter( + query_states=hidden_states, + protein_kv_states=protein_kv_states, + structure_kv_states=structure_kv_states, + msa_kv_states=msa_kv_states, + query_attn_mask=query_attn_mask, + 
protein_batch_mask=protein_batch_mask, + structure_batch_mask=structure_batch_mask, + msa_batch_mask=msa_batch_mask, + ) + + return hidden_states + + +class EvollaPreTrainedModel(LlamaPreTrainedModel): + _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_attention_backend = False + _no_split_modules = [ + "EvollaDecoderLayer", + "EvollaSequenceCompressorResampler", + "EvollaSequenceAlignerCrossAttention", + ] + + def _init_weights(self, module): + std = self.config.initializer_range + PreTrainedModel._init_weights(self, module) + if isinstance(module, EvollaSequenceAlignerCrossAttention): + module.gate_attention.zero_() + module.gate_ffw.zero_() + module.attention_norm.weight.data.fill_(1.0) + elif isinstance(module, EvollaSequenceCompressorResampler): + module.latents.data.normal_(mean=0.0, std=std) + + +class EvollaModel(EvollaPreTrainedModel): + def __init__(self, config: EvollaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.protein_encoder = EvollaProteinEncoder(config=config) + self.layers = nn.ModuleList( + [ + EvollaDecoderLayer( + config=config, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = EvollaRotaryEmbedding(config=config) + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + protein_input_ids: Optional[torch.LongTensor] = None, + protein_attention_mask: Optional[torch.Tensor] = None, + structure_feats: Optional[torch.FloatTensor] = None, + msa_feats: Optional[torch.FloatTensor] = None, + structure_batch_mask: Optional[torch.Tensor] = None, + msa_batch_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + protein_input_ids (torch.LongTensor): + The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. + protein_attention_mask (torch.Tensor): + The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. + structure_feats (torch.FloatTensor): + The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + msa_feats (torch.FloatTensor): + The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now. + structure_batch_mask (torch.Tensor): + The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. 
Should be paired with `structure_feats`. Dummy input for now.
+            msa_batch_mask (torch.Tensor):
+                The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        protein_feats = None
+        protein_batch_mask = None
+        # If provided, actually compute them
+        if protein_input_ids is not None and protein_attention_mask is not None:
+            protein_outputs = self.protein_encoder(
+                input_ids=protein_input_ids,
+                attention_mask=protein_attention_mask,
+            )
+            protein_feats = protein_outputs.sequence_compressor_output
+            protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                protein_kv_states=protein_feats,
+                structure_kv_states=structure_feats,
+                msa_kv_states=msa_feats,
+                protein_batch_mask=protein_batch_mask,
+                structure_batch_mask=structure_batch_mask,
+                msa_batch_mask=msa_batch_mask,
+                query_attn_mask=attention_mask,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+
+        output = BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+        return output
+
+
+class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = EvollaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        return self.model.set_input_embeddings(value)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,  # text input ids
+        attention_mask: Optional[torch.Tensor] = None,  # text attention mask
+        inputs_embeds: Optional[torch.FloatTensor] = None,  # text input embeddings
+        labels: Optional[torch.LongTensor] = None,
+        protein_input_ids: Optional[torch.LongTensor] = None,
+        protein_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs,
+    ):
+        r"""
+        protein_input_ids (torch.LongTensor):
+            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type
+            `torch.LongTensor`.
+        protein_attention_mask (torch.Tensor):
+            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
+
+        Example:
+
+        ```python
+        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
+        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
+        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")
+
+        >>> protein_information = {
+        ...     "aa_seq": "your amino acid sequence",
+        ...     "foldseek": "your foldseek sequence",
+        ... }
+        >>> question = "What is the function of this protein?"
+        >>> message = [
+        ...     {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
+        ...     {"role": "user", "content": question},
+        ... ]
+
+        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
+        >>> outputs = model.generate(**inputs)
+
+        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
+        ```"""
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            protein_input_ids=protein_input_ids,
+            protein_attention_mask=protein_attention_mask,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+        lm_outputs = CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+        return lm_outputs
+
+
+__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/evolla/processing_evolla.py b/venv/lib/python3.13/site-packages/transformers/models/evolla/processing_evolla.py
new file mode 100644
index 0000000000000000000000000000000000000000..a421287676112b73cdfa4adc67a9b1dd13ecc861
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/evolla/processing_evolla.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for EVOLLA.
+"""
+
+import os
+from typing import Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import (
+    ProcessorMixin,
+)
+from ..auto import AutoTokenizer
+
+
+PROTEIN_VALID_KEYS = ["aa_seq", "foldseek", "msa"]
+
+
+class EvollaProcessor(ProcessorMixin):
+    r"""
+    Constructs an EVOLLA processor which wraps a Llama tokenizer and a SaProt tokenizer (EsmTokenizer) into a single processor.
+
+    [`EvollaProcessor`] offers all the functionalities of [`EsmTokenizer`] and [`LlamaTokenizerFast`]. See the
+    docstring of [`~EvollaProcessor.__call__`] and [`~EvollaProcessor.decode`] for more information.
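+
+    A minimal usage sketch (the checkpoint name mirrors the model docstrings above; the protein
+    strings are placeholders):
+
+    ```python
+    >>> from transformers import EvollaProcessor
+    >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")
+    >>> protein = {"aa_seq": "your amino acid sequence", "foldseek": "your foldseek sequence"}
+    >>> messages = [{"role": "user", "content": "What is the function of this protein?"}]
+    >>> inputs = processor(proteins=[protein], messages_list=[messages], return_tensors="pt")
+    ```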
+
+    Args:
+        protein_tokenizer (`EsmTokenizer`):
+            An instance of [`EsmTokenizer`]. The protein tokenizer is a required input.
+        tokenizer (`LlamaTokenizerFast`, *optional*):
+            An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input even though the argument
+            defaults to `None`; a `ValueError` is raised if it is missing.
+        protein_max_length (`int`, *optional*, defaults to 1024):
+            The maximum length (in tokens) to which protein sequences are truncated.
+        text_max_length (`int`, *optional*, defaults to 512):
+            The maximum length (in tokens) to which text inputs are truncated.
+    """
+
+    attributes = ["protein_tokenizer", "tokenizer"]
+    valid_kwargs = ["sequence_max_length"]
+    # protein_tokenizer_class = "EsmTokenizer"
+    # tokenizer_class = "LlamaTokenizerFast"
+    protein_tokenizer_class = "AutoTokenizer"
+    tokenizer_class = "AutoTokenizer"
+    protein_tokenizer_dir_name = "protein_tokenizer"
+    # tokenizer_dir_name = "text_tokenizer"
+
+    def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs):
+        if protein_tokenizer is None:
+            raise ValueError("You need to specify a `protein_tokenizer`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        super().__init__(protein_tokenizer, tokenizer)
+
+        self.tokenizer.pad_token = "<|reserved_special_token_0|>"
+        self.protein_max_length = protein_max_length
+        self.text_max_length = text_max_length
+
+    def process_proteins(self, proteins, protein_max_length=1024):
+        sa_sequences = []
+        for protein in proteins:
+            aa_seq = protein.get("aa_seq")
+            foldseek = protein.get("foldseek")
+            # interleave amino acids (upper case) with foldseek structure tokens (lower case)
+            sa_sequence = "".join([s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)])
+            sa_sequences.append(sa_sequence)
+
+        sa_tokens = self.protein_tokenizer.batch_encode_plus(
+            sa_sequences, return_tensors="pt", truncation=True, max_length=protein_max_length, padding=True
+        )
+        return sa_tokens
+
+    def process_text(
+        self,
+        texts,
+        text_max_length: int = 512,
+    ):
+        prompts = []
+        for messages in texts:
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            prompts.append(prompt)
+
+        prompt_inputs = self.tokenizer(
+            prompts,
+            add_special_tokens=False,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=text_max_length,
+        )
+        return prompt_inputs
+
+    def __call__(
+        self,
+        proteins: Optional[Union[list[dict], dict]] = None,
+        messages_list: Optional[Union[list[list[dict]], list[dict]]] = None,
+        protein_max_length: Optional[int] = None,
+        text_max_length: Optional[int] = None,
+        **kwargs,
+    ):
+        r"""This method takes batched or non-batched proteins and messages_list and converts them into a format that
+        can be used by the model.
+
+        Args:
+            proteins (`Union[List[dict], dict]`):
+                A list of dictionaries or a single dictionary containing the following keys:
+                    - `"aa_seq"` (`str`) -- The amino acid sequence of the protein.
+                    - `"foldseek"` (`str`) -- The foldseek string of the protein.
+            messages_list (`Union[List[List[dict]], List[dict]]`):
+                A list of lists of dictionaries or a list of dictionaries containing the following keys:
+                    - `"role"` (`str`) -- The role of the message.
+                    - `"content"` (`str`) -- The content of the message.
+            protein_max_length (`int`, *optional*, defaults to 1024):
+                The maximum length (in tokens) to which protein sequences are truncated.
+            text_max_length (`int`, *optional*, defaults to 512):
+                The maximum length (in tokens) to which text inputs are truncated.
+
+        Return:
+            a dict with the following keys:
+                - `protein_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the protein sequence.
+                - `protein_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the protein sequence.
+                - `input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the text sequence.
+                - `attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the text sequence.
+        """
+        # proteins and messages_list should be provided
+        if proteins is None or messages_list is None:
+            raise ValueError("You need to specify `messages_list` and `proteins`.")
+
+        protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length
+        text_max_length = text_max_length if text_max_length is not None else self.text_max_length
+
+        # proteins should be List[dict]
+        if isinstance(proteins, dict):
+            proteins = [proteins]
+        # messages_list should be List[List[dict]]
+        if isinstance(messages_list, (list, tuple)) and not isinstance(messages_list[0], (list, tuple)):
+            messages_list = [messages_list]
+        # Check if batched proteins are in the correct format
+        if isinstance(proteins, (list, tuple)) and not all(isinstance(p, dict) for p in proteins):
+            raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.")
+        if isinstance(proteins, (list, tuple)) and not all(
+            all(k in PROTEIN_VALID_KEYS for k in p.keys()) for p in proteins
+        ):
+            raise ValueError(
+                "Each protein should be a dictionary with keys drawn from: "
+                f"{', '.join(PROTEIN_VALID_KEYS)}. "
+                f"But got: {proteins}"
+            )
+        # Check if batched messages_list is in the correct format
+        if isinstance(messages_list, (list, tuple)):
+            for messages in messages_list:
+                if not isinstance(messages, (list, tuple)):
+                    raise ValueError(f"Each messages in messages_list should be a list instead of {type(messages)}.")
+                if not all(isinstance(m, dict) for m in messages):
+                    raise ValueError(
+                        "Each message in messages_list should be a list of dictionaries, but not all elements are dictionaries."
+                    )
+                if any(len(m.keys()) != 2 for m in messages) or any(
+                    set(m.keys()) != {"role", "content"} for m in messages
+                ):
+                    raise ValueError(
+                        "Each message in messages_list should be a list of dictionaries with two keys: 'role' and 'content'. "
+                        f"But got: {messages}"
+                    )
+        else:
+            raise ValueError(
+                f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}."
+            )
+        sa_tokens = self.process_proteins(proteins, protein_max_length)
+
+        text_tokens = self.process_text(messages_list, text_max_length)
+
+        return BatchFeature(
+            data={
+                "protein_input_ids": sa_tokens["input_ids"],
+                "protein_attention_mask": sa_tokens["attention_mask"],
+                "input_ids": text_tokens["input_ids"],
+                "attention_mask": text_tokens["attention_mask"],
+            }
+        )
+
+    def batch_decode(self, *args, **kwargs):
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def protein_batch_decode(self, *args, **kwargs):
+        return self.protein_tokenizer.batch_decode(*args, **kwargs)
+
+    def protein_decode(self, *args, **kwargs):
+        return self.protein_tokenizer.decode(*args, **kwargs)
+
+    # overwrite to save the protein tokenizer in a separate folder
+    # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221)
+    def save_pretrained(self, save_directory, **kwargs):
+        # only save the protein tokenizer in its sub-directory
+        self.protein_tokenizer.save_pretrained(os.path.join(save_directory, self.protein_tokenizer_dir_name))
+
+        # we modify the attributes so that only the text tokenizer is saved in the main folder
+        protein_tokenizer_present = "protein_tokenizer" in self.attributes
+        # find the correct position of it in the attributes list
+        protein_tokenizer_index = self.attributes.index("protein_tokenizer") if protein_tokenizer_present else None
+        if protein_tokenizer_present and protein_tokenizer_index is not None:
+            self.attributes.remove("protein_tokenizer")
+
+        outputs = super().save_pretrained(save_directory, **kwargs)
+
+        if protein_tokenizer_present and protein_tokenizer_index is not None:
+            self.attributes.insert(protein_tokenizer_index, "protein_tokenizer")
+
+        return outputs
+
+    # overwrite to load the protein tokenizer from a separate folder
+    # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221)
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        # if `return_unused_kwargs` is set, a tuple is returned where the second element is the unused kwargs
+        if isinstance(processor, tuple):
+            processor = processor[0]
+        protein_tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path, subfolder=cls.protein_tokenizer_dir_name
+        )
+
+        processor.protein_tokenizer = protein_tokenizer
+
+        return processor
+
+
+__all__ = ["EvollaProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9749c5e1e982cd1ac49db2d560238e4d45907c62
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_falcon_h1 import *
+    from .modeling_falcon_h1 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/configuration_falcon_h1.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/configuration_falcon_h1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9aaaf3405f066d90731a3e84fd41cb3b352876
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/configuration_falcon_h1.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FalconH1 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FalconH1Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
+    FalconH1 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults yields a configuration similar to
+    [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf).
+    FalconH1 is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU, developed by the
+    Technology Innovation Institute (TII).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 128000):
+            Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`FalconH1Model`]
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if
+            the model has an output word embedding layer.
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
+            significantly.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        max_position_embeddings (`int`, *optional*, defaults to 8192):
+            Max cached sequence length for the model.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        mamba_d_ssm (`int`, *optional*, defaults to 1024):
+            The dimension of the SSM state space latents.
+        mamba_n_heads (`int`, *optional*, defaults to 128):
+            The number of mamba heads used in the v2 implementation.
+        mamba_d_head (`int`, *optional*, defaults to `"auto"`):
+            Head embedding dimension size.
+        mamba_n_groups (`int`, *optional*, defaults to 1):
+            The number of the mamba groups used in the v2 implementation.
+        mamba_d_state (`int`, *optional*, defaults to 256):
+            The dimension of the mamba state space latents.
+        mamba_d_conv (`int`, *optional*, defaults to 4):
+            The size of the mamba convolution kernel.
+        mamba_expand (`int`, *optional*, defaults to 2):
+            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.
+        mamba_chunk_size (`int`, *optional*, defaults to 256):
+            The chunks in which to break the sequence when doing prefill/training.
+        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
+        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"])
+            of the mamba mixer block.
+        mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
+            Whether to use RMSNorm before the gate in the Mamba block.
+        mamba_rms_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use RMSNorm instead of LayerNorm in the Mamba block.
+        projectors_bias (`bool`, *optional*, defaults to `False`):
+            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"])
+            of the attention block.
+        rope_theta (`float`, *optional*, defaults to 100000.0):
+            The theta value used for the RoPE embeddings.
+        rope_scaling (`float`, *optional*):
+            The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.
+        lm_head_multiplier (`float`, *optional*, defaults to 1.0):
+            The multiplier for the LM head. This is used to scale the output of the LM head.
+        embedding_multiplier (`float`, *optional*, defaults to 1.0):
+            The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
+        mlp_multipliers (`list[float]`, *optional*):
+            The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
+            the multiplier of the gate layer, the second value is the multiplier of the down_proj layer.
+        key_multiplier (`float`, *optional*):
+            The multiplier for the key layer. This is used to scale the output of the key layer.
+        attention_out_multiplier (`float`, *optional*):
+            The multiplier for the attention output layer. This is used to scale the output of the attention output
+            layer.
+        attention_in_multiplier (`float`, *optional*):
+            The multiplier for the attention input layer. This is used to scale the output of the attention input
+            layer.
+        ssm_multipliers (`list[float]`, *optional*):
+            The multipliers for the SSM layers. This is used to scale the output of the SSM layers.
+        ssm_in_multiplier (`float`, *optional*):
+            The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.
+        ssm_out_multiplier (`float`, *optional*):
+            The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.
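+
+    Example (a minimal instantiation sketch; the attribute value shown is just the documented default):
+
+    ```python
+    >>> from transformers import FalconH1Config
+
+    >>> # Initializing a FalconH1 configuration with default values
+    >>> configuration = FalconH1Config()
+    >>> configuration.mamba_n_heads
+    128
+    ```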
+ """ + + model_type = "falcon_h1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=8192, + attention_dropout=0.0, + mamba_d_ssm=1024, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_norm_before_gate=True, + mamba_rms_norm=False, + projectors_bias=False, + rope_theta=100000.0, + rope_scaling=None, + lm_head_multiplier=1.0, + embedding_multiplier=1.0, + mlp_multipliers=None, + key_multiplier=None, + attention_out_multiplier=None, + attention_in_multiplier=None, + ssm_multipliers=None, + ssm_in_multiplier=None, + ssm_out_multiplier=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.rope_theta = rope_theta + self.rope_scaling = None + self.rope_scaling = rope_scaling + self.projectors_bias = projectors_bias + mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") + + self.mamba_d_ssm = mamba_d_ssm + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + self.mamba_norm_before_gate = mamba_norm_before_gate + self.mamba_rms_norm = mamba_rms_norm + + self.lm_head_multiplier = lm_head_multiplier + self.embedding_multiplier = embedding_multiplier + + if mlp_multipliers is not None: + self.mlp_multipliers = mlp_multipliers + else: + self.mlp_multipliers = [1.0, 1.0] + + if attention_out_multiplier is not None: + self.attention_out_multiplier = attention_out_multiplier + else: + self.attention_out_multiplier = 1.0 + + if attention_in_multiplier is not None: + self.attention_in_multiplier = attention_in_multiplier + else: + self.attention_in_multiplier = 1.0 + + if key_multiplier is not None: + self.key_multiplier = key_multiplier + else: + self.key_multiplier = 
1.0 + + if ssm_multipliers is not None: + self.ssm_multipliers = ssm_multipliers + else: + self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0] + + if ssm_in_multiplier is not None: + self.ssm_in_multiplier = ssm_in_multiplier + else: + self.ssm_in_multiplier = 1.0 + + if ssm_out_multiplier is not None: + self.ssm_out_multiplier = ssm_out_multiplier + else: + self.ssm_out_multiplier = 1.0 + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return ["attention" for i in range(self.num_hidden_layers)] + + +__all__ = ["FalconH1Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modeling_falcon_h1.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modeling_falcon_h1.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8b13ef21d00e400fa19d4e7cea7ea0ab89e4a3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -0,0 +1,1619 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_h1/modular_falcon_h1.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_h1.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Technology Innovation Institute and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
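The SSM sizing logic in the constructor above reduces to a few lines of arithmetic; a minimal sketch follows, assuming the default values shown in the signature (`mamba_d_ssm=1024`, `mamba_expand=2`, `hidden_size=4096`, `mamba_n_heads=128`).

```python
# Sketch of how FalconH1Config resolves mamba_d_head="auto" (values are the defaults above).
mamba_d_ssm, mamba_expand, hidden_size, mamba_n_heads = 1024, 2, 4096, 128

# mamba_d_ssm takes precedence over mamba_expand * hidden_size whenever it is set.
mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm  # 1024
mamba_d_head = mamba_intermediate // mamba_n_heads  # "auto" -> 8

# The constructor raises if the heads do not tile the SSM width exactly.
assert mamba_intermediate % mamba_n_heads == 0
assert mamba_d_head * mamba_n_heads == mamba_intermediate
```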
+ +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from .configuration_falcon_h1 import FalconH1Config + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +logger = logging.get_logger(__name__) + + +class FalconHybridMambaAttentionDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors, with the following expected shapes. + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
+ """ + + is_compileable = False + + def __init__( + self, + config: FalconH1Config, + batch_size: int, + dtype: torch.dtype = torch.float16, + devices: Optional[list[str]] = None, + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.has_previous_state = False + self.conv_kernel_size = config.mamba_d_conv + + self.intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + self.conv_kernel_size, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + config.mamba_d_state, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + self.transformer_layers.append(i) + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the cache + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def update_conv_state( + self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_position: torch.LongTensor, + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + if len(cache_position) > 1: + conv_state[:, :, :] = new_conv_state.to(conv_state.device) + else: + conv_state[:, :, -1] = new_conv_state[:, :, -1].to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class FalconH1RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: FalconH1Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class FalconH1Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.key_multiplier = config.key_multiplier + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = 
None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.key_multiplier + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class FalconH1RMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.n_groups = n_groups + self.norm_before_gate = norm_before_gate + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + + if not self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + if len(hidden_states.shape) == 3: + batch_size, seq_len, dim = hidden_states.shape + else: + batch_size, dim = hidden_states.shape + seq_len = 1 + hidden_states = hidden_states.to(torch.float32) + + hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states + hidden_states = hidden_states.view(batch_size, seq_len, dim) + + if seq_len == 1: + hidden_states = hidden_states.squeeze(1) + + if self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class FalconH1Mixer(nn.Module): + """ + FalconH1Mixer is identical to classic Mamba2 mixer classes but differs on two different things + - Users can pass custom intermediate_size through `config.mamba_d_ssm` + - The use of gated RMS normalization layer is optional + """ + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = ( + int(config.mamba_expand * self.hidden_size) if config.mamba_d_ssm is None else config.mamba_d_ssm + ) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + self.groups_time_state_size = config.mamba_n_groups * self.ssm_state_size + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + 
in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.mamba_rms_norm = config.mamba_rms_norm + + if self.mamba_rms_norm: + self.norm = FalconH1RMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon, + n_groups=self.n_groups, + norm_before_gate=config.mamba_norm_before_gate, + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + + self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for FalconH1 will be used when running the model on a GPU") + + self.zxbcdt_multipliers = config.ssm_multipliers + self.ssm_in_multiplier = config.ssm_in_multiplier + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + # Add Multipliers + hidden_states = hidden_states * self.ssm_in_multiplier + projected_states = self.in_proj(hidden_states) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + d_mlp = (projected_states.squeeze(1).shape[-1] - d_to_remove) // 2 + + z0, x0, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=gate.view(batch_size, self.num_heads, self.head_dim) if not self.mamba_rms_norm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + + if self.mamba_rms_norm: + hidden_states = self.norm(hidden_states, gate) + + if d_mlp > 0: + hidden_states = torch.cat([F.silu(z0) * x0, hidden_states], dim=-1) + + # 4. Final linear projection + out = self.out_proj(hidden_states[:, None, ...]) + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. 
Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.mamba_rms_norm else None, + rmsnorm_eps=self.norm.variance_epsilon if self.mamba_rms_norm else None, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + if attention_mask is not None: + projected_states = projected_states * attention_mask[..., None] + _, gate, hidden_states_B_C, dt = projected_states.split( + [ + 2 * d_mlp, + self.intermediate_size, + self.conv_dim, + self.num_heads, + ], + dim=-1, + ) + + if cache_params is not None: + conv_states = F.pad( + hidden_states_B_C.permute(0, 2, 1), + (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + + time_step = nn.functional.softplus(dt + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # This is a hack to make sure multi-GPU inference works with HF accelerate + # see: https://github.com/Dao-AILab/flash-attention/issues/523 for more details + with torch.cuda.device(hidden_states.device): + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + if self.mamba_rms_norm: + out = self.norm(scan_output, gate) + else: + out = scan_output * torch.nn.functional.silu(gate) + out = self.out_proj(out) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = 
input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + # Add Multipliers + input_states = input_states * self.ssm_in_multiplier + projected_states = self.in_proj(input_states) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + gate, hidden_states_B_C, dt = projected_states.split([ + self.intermediate_size, self.conv_dim, self.num_heads + ], dim=-1) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + if self.mamba_rms_norm: + scan_output = self.norm(y, gate) + else: + scan_output = y * torch.nn.functional.silu(gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconH1MLP(nn.Module): + def __init__(self, config: FalconH1Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + + def forward(self, x): + y = self.up_proj(x) * self.act_fn(self.gate_proj(x) * self.gate_multiplier) + y = self.down_proj(y) * self.down_multiplier + return y + + +@use_kernel_forward_from_hub("RMSNorm") +class FalconH1RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + FalconH1RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class FalconH1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: 
FalconH1Config, layer_idx: int): + super().__init__() + self.feed_forward = FalconH1MLP(config) + + head_dim = config.hidden_size // config.num_attention_heads + self.channels_attn = config.num_attention_heads * head_dim + 2 * config.num_key_value_heads * head_dim + + self.mamba = FalconH1Mixer(config=config, layer_idx=layer_idx) + + self.self_attn = FalconH1Attention(config, layer_idx) + + self.attention_in_multiplier = config.attention_in_multiplier + self.ssm_out_multiplier = config.ssm_out_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + mamba_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + mamba_hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=mamba_attention_mask, + ) + mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier + + attention_hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states * self.attention_in_multiplier, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + + hidden_states = mamba_hidden_states + attention_hidden_states + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class FalconH1PreTrainedModel(PreTrainedModel): + config: FalconH1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FalconH1DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + param.data.fill_(1.0) + elif "bias" in name: + param.data.zero_() + else: + try: + param.data.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") + + +def compute_mup_vector(config): + """ + Computes the MuP vector based on model configuration. + + FalconH1 applies different MuP multiplier for each dimension of the hidden states. + The MuP vector is partitioned into chunks, and each chunk is multiplied with its + corresponding projected dimension. 
+ + Args: + config: FalconH1Config object + + Returns: + torch.Tensor: The computed MuP vector + """ + # We'll need some values from the config to compute the vector dimensions + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + groups_time_state_size = config.mamba_n_groups * config.mamba_d_state + num_heads = config.mamba_n_heads + zxbcdt_multipliers = config.ssm_multipliers + + vector_shape = 2 * intermediate_size + 2 * groups_time_state_size + num_heads + mup_vector = torch.ones(1, 1, vector_shape) + + # Apply multipliers to different sections of the vector + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size : 2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size : 2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[ + :, :, 2 * intermediate_size + groups_time_state_size : 2 * intermediate_size + 2 * groups_time_state_size + ] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size :] *= zxbcdt_multipliers[4] + + return mup_vector + + +@auto_docstring +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class FalconH1Model(FalconH1PreTrainedModel): + def __init__(self, config: FalconH1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(FalconH1DecoderLayer(config, layer_idx=i)) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = FalconH1RotaryEmbedding(config=config) + + self.embedding_multiplier = config.embedding_multiplier + self.lm_head_multiplier = config.lm_head_multiplier + + self.gradient_checkpointing = False + # Compute the MuP vector once and register it for all layers + mup_vector = compute_mup_vector(config) + for layer in self.layers: + layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is 
incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embedding_multiplier + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + mamba_attention_mask=mamba_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: FalconHybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = FalconH1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconH1ForCausalLM + + >>> model = FalconH1ForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.model.lm_head_multiplier + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
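+        #   (exception 3 shows up below as `cache_position[-1] >= input_ids.shape[1]`)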
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = FalconHybridMambaAttentionDynamicCache( + self.config, + input_ids.shape[0], + self.dtype, + devices=[ + self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) + ], + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +__all__ = ["FalconH1Model", "FalconH1ForCausalLM", "FalconH1PreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modular_falcon_h1.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modular_falcon_h1.py new file mode 100644 index 0000000000000000000000000000000000000000..fe716dded4b35dad154084d3ca7afc632bf8e996 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_h1/modular_falcon_h1.py @@ -0,0 +1,1383 @@ +# coding=utf-8 +# Copyright 2025 Technology Innovation Institute and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
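+
+# This "modular" file is the hand-written source from which the auto-generated
+# `modeling_falcon_h1.py` is produced (per the transformers modular convention);
+# it reuses Llama, Mamba2 and Jamba building blocks and overrides only the
+# FalconH1-specific pieces (multipliers, hybrid cache, SSM mixer).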
+"""PyTorch FalconH1 model.""" + +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.models.mamba2.modeling_mamba2 import ( + MambaRMSNormGated, + pad_tensor_by_size, + reshape_into_chunks, + segment_sum, +) + +from ...cache_utils import Cache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from .configuration_falcon_h1 import FalconH1Config + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +logger = logging.get_logger(__name__) + + +class FalconHybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
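+
+    Example (an illustrative sketch; `config` is assumed to be a ready
+    `FalconH1Config`, and the resulting shapes depend on its values):
+
+    ```python
+    >>> cache = FalconHybridMambaAttentionDynamicCache(
+    ...     config, batch_size=2, dtype=torch.float32, devices=["cpu"] * config.num_hidden_layers
+    ... )
+    >>> cache.conv_states[0].shape  # (2, intermediate_size + 2 * n_groups * d_state, d_conv)
+    >>> cache.ssm_states[0].shape  # (2, mamba_n_heads, mamba_d_head, mamba_d_state)
+    ```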
+ """ + + def __init__( + self, + config: FalconH1Config, + batch_size: int, + dtype: torch.dtype = torch.float16, + devices: Optional[list[str]] = None, + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.has_previous_state = False + self.conv_kernel_size = config.mamba_d_conv + + self.intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + self.conv_kernel_size, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + config.mamba_d_state, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + self.transformer_layers.append(i) + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
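+
+        Example (illustrative shapes): if a layer already holds 10 cached tokens,
+        passing key/value states of shape `(batch, num_kv_heads, 1, head_dim)`
+        returns tensors of shape `(batch, num_kv_heads, 11, head_dim)`, since new
+        states are concatenated along the sequence axis (`dim=-2`).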
+ """ + # Update the cache + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def update_conv_state( + self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_position: torch.LongTensor, + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + if len(cache_position) > 1: + conv_state[:, :, :] = new_conv_state.to(conv_state.device) + else: + conv_state[:, :, -1] = new_conv_state[:, :, -1].to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class FalconH1RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class FalconH1Attention(LlamaAttention): + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__(config, layer_idx) + self.key_multiplier = config.key_multiplier + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.key_multiplier + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class FalconH1RMSNormGated(MambaRMSNormGated): + def __init__(self, hidden_size, 
eps=1e-6, n_groups=1, norm_before_gate=True): + super().__init__(hidden_size=hidden_size, eps=eps) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.n_groups = n_groups + self.norm_before_gate = norm_before_gate + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + + if not self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + if len(hidden_states.shape) == 3: + batch_size, seq_len, dim = hidden_states.shape + else: + batch_size, dim = hidden_states.shape + seq_len = 1 + hidden_states = hidden_states.to(torch.float32) + + hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states + hidden_states = hidden_states.view(batch_size, seq_len, dim) + + if seq_len == 1: + hidden_states = hidden_states.squeeze(1) + + if self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class FalconH1Mixer(nn.Module): + """ + FalconH1Mixer is identical to classic Mamba2 mixer classes but differs on two different things + - Users can pass custom intermediate_size through `config.mamba_d_ssm` + - The use of gated RMS normalization layer is optional + """ + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = ( + int(config.mamba_expand * self.hidden_size) if config.mamba_d_ssm is None else config.mamba_d_ssm + ) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + self.groups_time_state_size = config.mamba_n_groups * self.ssm_state_size + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input 
dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.mamba_rms_norm = config.mamba_rms_norm + + if self.mamba_rms_norm: + self.norm = FalconH1RMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon, + n_groups=self.n_groups, + norm_before_gate=config.mamba_norm_before_gate, + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + + self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for FalconH1 will be used when running the model on a GPU") + + self.zxbcdt_multipliers = config.ssm_multipliers + self.ssm_in_multiplier = config.ssm_in_multiplier + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + # Add Multipliers + hidden_states = hidden_states * self.ssm_in_multiplier + projected_states = self.in_proj(hidden_states) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + d_mlp = (projected_states.squeeze(1).shape[-1] - d_to_remove) // 2 + + z0, x0, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. 
SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=gate.view(batch_size, self.num_heads, self.head_dim) if not self.mamba_rms_norm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + + if self.mamba_rms_norm: + hidden_states = self.norm(hidden_states, gate) + + if d_mlp > 0: + hidden_states = torch.cat([F.silu(z0) * x0, hidden_states], dim=-1) + + # 4. Final linear projection + out = self.out_proj(hidden_states[:, None, ...]) + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.mamba_rms_norm else None, + rmsnorm_eps=self.norm.variance_epsilon if self.mamba_rms_norm else None, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + if attention_mask is not None: + projected_states = projected_states * attention_mask[..., None] + _, gate, hidden_states_B_C, dt = projected_states.split( + [ + 2 * d_mlp, + self.intermediate_size, + self.conv_dim, + self.num_heads, + ], + dim=-1, + ) + + if cache_params is not None: + conv_states = F.pad( + hidden_states_B_C.permute(0, 2, 1), + (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + + time_step = nn.functional.softplus(dt + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if attention_mask is not None and attention_mask.shape[1] > 1 and 
attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # This is a hack to make sure multi-GPU inference works with HF accelerate + # see: https://github.com/Dao-AILab/flash-attention/issues/523 for more details + with torch.cuda.device(hidden_states.device): + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + if self.mamba_rms_norm: + out = self.norm(scan_output, gate) + else: + out = scan_output * torch.nn.functional.silu(gate) + out = self.out_proj(out) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + # Add Multipliers + input_states = input_states * self.ssm_in_multiplier + projected_states = self.in_proj(input_states) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + gate, hidden_states_B_C, dt = projected_states.split([ + self.intermediate_size, self.conv_dim, self.num_heads + ], dim=-1) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. 
Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
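+            # dt enters as [bsz, seq_len, num_heads]; the slice above keeps the
+            # single new token -> [bsz, 1, num_heads], and the expand below
+            # broadcasts it to [bsz, num_heads, head_dim] for per-head
+            # discretization.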
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
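+            # End of the single-token (cached decoding) path: `y` is now
+            # [bsz, 1, intermediate_size], matching the chunked prefill branch below.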
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + if self.mamba_rms_norm: + scan_output = self.norm(y, gate) + else: + scan_output = y * torch.nn.functional.silu(gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconH1MLP(LlamaMLP): + def __init__(self, config: FalconH1Config): + super().__init__(config) + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + + def forward(self, x): + y = self.up_proj(x) * self.act_fn(self.gate_proj(x) * self.gate_multiplier) + y = self.down_proj(y) * self.down_multiplier + return y + + +class FalconH1RMSNorm(LlamaRMSNorm): + pass + + +class FalconH1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.feed_forward = FalconH1MLP(config) + + head_dim = config.hidden_size // config.num_attention_heads + self.channels_attn = config.num_attention_heads * head_dim + 2 * config.num_key_value_heads * head_dim + + self.mamba = FalconH1Mixer(config=config, layer_idx=layer_idx) + + self.self_attn = FalconH1Attention(config, layer_idx) + + self.attention_in_multiplier = config.attention_in_multiplier + self.ssm_out_multiplier = config.ssm_out_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + mamba_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[FalconHybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + mamba_hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=mamba_attention_mask, + ) + mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier + + attention_hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states * self.attention_in_multiplier, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + + hidden_states = mamba_hidden_states + attention_hidden_states + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class FalconH1PreTrainedModel(PreTrainedModel): + config: FalconH1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FalconH1DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + 
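+                # (catches `input_layernorm`, `pre_ff_layernorm` and
+                # `final_layernorm` weights by name)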
param.data.fill_(1.0) + elif "bias" in name: + param.data.zero_() + else: + try: + param.data.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") + + +def compute_mup_vector(config): + """ + Computes the MuP vector based on model configuration. + + FalconH1 applies different MuP multiplier for each dimension of the hidden states. + The MuP vector is partitioned into chunks, and each chunk is multiplied with its + corresponding projected dimension. + + Args: + config: FalconH1Config object + + Returns: + torch.Tensor: The computed MuP vector + """ + # We'll need some values from the config to compute the vector dimensions + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + groups_time_state_size = config.mamba_n_groups * config.mamba_d_state + num_heads = config.mamba_n_heads + zxbcdt_multipliers = config.ssm_multipliers + + vector_shape = 2 * intermediate_size + 2 * groups_time_state_size + num_heads + mup_vector = torch.ones(1, 1, vector_shape) + + # Apply multipliers to different sections of the vector + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size : 2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size : 2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[ + :, :, 2 * intermediate_size + groups_time_state_size : 2 * intermediate_size + 2 * groups_time_state_size + ] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size :] *= zxbcdt_multipliers[4] + + return mup_vector + + +@auto_docstring +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class FalconH1Model(FalconH1PreTrainedModel): + def __init__(self, config: FalconH1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(FalconH1DecoderLayer(config, layer_idx=i)) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = FalconH1RotaryEmbedding(config=config) + + self.embedding_multiplier = config.embedding_multiplier + self.lm_head_multiplier = config.lm_head_multiplier + + self.gradient_checkpointing = False + # Compute the MuP vector once and register it for all layers + mup_vector = compute_mup_vector(config) + for layer in self.layers: + layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embedding_multiplier + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + mamba_attention_mask=mamba_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. 
Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: FalconHybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. 
+ cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class FalconH1ForCausalLM(LlamaForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconH1ForCausalLM + + >>> model = FalconH1ForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.model.lm_head_multiplier + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = FalconHybridMambaAttentionDynamicCache( + self.config, + input_ids.shape[0], + self.dtype, + devices=[ + self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) + ], + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +__all__ = ["FalconH1Model", "FalconH1ForCausalLM", "FalconH1PreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..202147c938465dd7dfcb7e79ecbeeb93ce632dbf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_falcon_mamba import * + from .modeling_falcon_mamba import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..7630ebd6343ac968303fc0c31f2742bb352b4f8a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -0,0 +1,170 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from ...configuration_utils import PretrainedConfig + + +class FalconMambaConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the FALCON_MAMBA + [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconMambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. 
+        eos_token_id (`int`, *optional*, defaults to 0):
+            The id of the end of sentence token in the vocabulary.
+        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
+        use_conv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the convolution layer of the mixer block.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`. If set to `False`, residuals will keep the same `dtype` as the rest of the model.
+        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`.
+        time_step_scale (`float`, *optional*, defaults to 1.0):
+            Scale used to scale `dt_proj.bias`.
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+            Init scheme used for `dt_proj.weight`. Should be one of `["random", "uniform"]`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+            This argument corresponds to `use_mambapy` in MambaConfig. It determines the fallback strategy during
+            training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py
+            implementation is used; if `False`, the naive and slower implementation is used. Consider switching to
+            the naive version if memory is limited.
+        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
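+
+    Unless an explicit `intermediate_size` is passed through `kwargs`, it is derived as `expand * hidden_size`.
+    A quick sanity check of that bookkeeping with the defaults (`2 * 768 = 1536`):
+
+    ```python
+    >>> from transformers import FalconMambaConfig
+
+    >>> FalconMambaConfig().intermediate_size
+    1536
+    ```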
+ + + Example: + + ```python + >>> from transformers import FalconMambaConfig, FalconMambaModel + + >>> # Initializing a FalconMamba configuration + >>> configuration = FalconMambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FalconMambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "falcon_mamba" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + use_falcon_mambapy=False, + mixer_rms_eps=1e-6, + **kwargs, + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + # This is needed since mamba overrides the intermediate_size attribute + self.intermediate_size = ( + int(expand * self.hidden_size) + if kwargs.get("intermediate_size") is None + else kwargs.get("intermediate_size") + ) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.use_falcon_mambapy = use_falcon_mambapy + self.mixer_rms_eps = mixer_rms_eps + + +__all__ = ["FalconMambaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..3cdf6da7bda376a34dd00545d96c44a99fc0e660 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -0,0 +1,937 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.import_utils import (
+    is_causal_conv1d_available,
+    is_kernels_available,
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
+from .configuration_falcon_mamba import FalconMambaConfig
+
+
+if is_mambapy_available():
+    from mambapy.pscan import pscan
+else:
+    pscan = None
+
+if is_mamba_ssm_available():
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+
+    from ...kernels.falcon_mamba import mamba_inner_fn
+else:
+    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+
+logger = logging.get_logger(__name__)
+
+
+class FalconMambaCache:
+    """
+    Cache for the falcon_mamba model, which does not have an attention mechanism or key-value states.
+
+    Arguments:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
+            a smaller batch size is used.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default `dtype` to use when initializing the layer.
+        device (`torch.device` or `str`, *optional*):
+            The device on which the cache should be initialized. Should be the same as the layer.
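+
+    Each entry of `conv_states` has shape `[max_batch_size, intermediate_size, conv_kernel]` and each entry of
+    `ssm_states` has shape `[max_batch_size, intermediate_size, ssm_state_size]`, with one entry per hidden layer.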
+ + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache + + >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") + + >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device) # sequence length + >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True) + >>> outputs.cache_params + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. 
FalconMamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx].zero_() + self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +def _lazy_load_causal_conv1d(): + global _causal_conv1d_cache + if _causal_conv1d_cache is not None: + return _causal_conv1d_cache + + if is_kernels_available(): + from kernels import get_kernel + + _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d") + _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn) + elif is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + + _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn) + else: + _causal_conv1d_cache = (None, None) + return _causal_conv1d_cache + + +_causal_conv1d_cache = None + + +def rms_forward(hidden_states, variance_epsilon=1e-6): + """ + Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will + leverage this in order to multiply the final result with the RMSNorm weight + + Args: + hidden_states (`torch.Tensor`): + Hidden states to normalize + variance_epsilon (`float`): + The eps value to add in the square root scaling factor + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) + + +class FalconMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see FalconMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between FalconMamba and the linear time invariant S4, + and is why FalconMamba is called **selective** state spaces) + """ + + def __init__(self, config: FalconMambaConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.intermediate_size, + padding=config.conv_kernel - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.use_falcon_mambapy = config.use_falcon_mambapy + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + self.warn_slow_implementation() + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + + def warn_slow_implementation(self): + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if not is_fast_path_available: + if self.use_falcon_mambapy: + if is_mambapy_available(): + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and" + " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d" + ) + else: + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." 
+ ) + else: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and" + " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, + ) + + else: + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module + # at the price of a small overhead. 
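+            # Calling the quantized `nn.Linear` and subtracting its bias afterwards recovers the bias-free
+            # `weight @ x` product; the bias is instead passed to the selective-scan kernel as `delta_bias`.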
+ if hasattr(self.config, "_pre_quantization_dtype"): + discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2) + else: + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + # fmt: off + def slow_forward(self, + input_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + hidden_states = self.act( + self.conv1d(hidden_states)[..., :seq_len] + ) # [batch, intermediate_size, seq_len] + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + conv_state = conv_state.to(self.conv1d.weight.device) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = ( + self.act(hidden_states).to(dtype).unsqueeze(-1) + ) # [batch, intermediate_size, 1] : decoding + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. 
Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( + 1, 2 + ) # [batch, intermediate_size, seq_len] + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp( + A[None, :, None, :] * discrete_time_step[:, :, :, None] + ) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = ( + discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + ) # [batch, intermediate_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + if self.use_falcon_mambapy and self.training and cache_params is None: + hs = pscan( + discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) + ) # [batch, seq_len, intermediate_size, ssm_state_size] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = ( + discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + ) # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul( + ssm_state.to(dtype), C[:, i, :].unsqueeze(-1) + ) # [batch, intermediate_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = scan_output * self.act(gate) + + if cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconMambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + FalconMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return self.weight.to(hidden_states.device) * rms_forward( + hidden_states, variance_epsilon=self.variance_epsilon + ) + + def extra_repr(self): + return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" + + +class FalconMambaBlock(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = FalconMambaMixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class FalconMambaPreTrainedModel(PreTrainedModel): + config: FalconMambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["FalconMambaBlock", "FalconMambaMixer"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + std = self.config.initializer_range + if isinstance(module, FalconMambaMixer): + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded + A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(module.intermediate_size, -1).contiguous() + module.A_log.copy_(torch.log(A)) + module.D.data.fill_(1.0) + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + module.dt_proj.bias.copy_(inv_dt) + module.dt_proj.bias._no_reinit = True + + nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5)) + if module.conv1d.bias is not None: + if not getattr(module.conv1d.bias, "_no_reinit", False): + nn.init.zeros_(module.conv1d.bias) + nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5)) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + p = module.out_proj.weight + p /= math.sqrt(self.config.num_hidden_layers) + + if isinstance(module, nn.Linear): + if not getattr(module.weight, "_no_reinit", False): + nn.init.normal_(module.weight, std=std) + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, FalconMambaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=std) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for the FALCON_MAMBA model outputs. + """ +) +class FalconMambaOutput(ModelOutput): + r""" + cache_params (`FalconMambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[FalconMambaCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for causal language model (or autoregressive) outputs. + """ +) +class FalconMambaCausalLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    cache_params (`FalconMambaCache`):
+        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+        avoid providing the old `input_ids`.
+
+        Includes both the state space model state matrices after the selective scan and the convolutional states.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[FalconMambaCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class FalconMambaModel(FalconMambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
+        )
+
+        self.gradient_checkpointing = False
+        self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embeddings = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[FalconMambaCache] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ) -> Union[tuple, FalconMambaOutput]:
+        r"""
+        cache_params (`FalconMambaCache`, *optional*):
+            If passed along, the model uses the previous state in all the blocks (which will give the output for the
+            `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
+        use_cache (`bool`, *optional*):
+            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
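+
+        Example (a minimal sketch, reusing the checkpoint name from the other examples in this file):
+
+        ```python
+        >>> from transformers import AutoTokenizer, FalconMambaModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
+        >>> model = FalconMambaModel.from_pretrained("tiiuae/falcon-mamba-7b")
+
+        >>> inputs = tokenizer("My name is FalconMamba", return_tensors="pt")
+        >>> outputs = model(**inputs, use_cache=False)
+        >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
+        ```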
+ """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = FalconMambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return FalconMambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """ +) +class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = FalconMambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: dict[str, Any], num_new_tokens: int = 1, **kwargs + ) -> dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + if ( + model_kwargs.get("use_cache", True) + and "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs, + ): + # Overwritten -- uses `cache_params` as opposed to `past_key_values` + model_inputs = {"input_ids": input_ids.contiguous()} + if use_cache and cache_params is None: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds} + max_batch_size = inputs_embeds.size(0) + else: + max_batch_size = input_ids.size(0) + cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) + + if use_cache and cache_position[0] > 0: + model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous() + attention_mask = None + + if not use_cache and inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds} + + model_inputs.update( + { + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + "attention_mask": attention_mask, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[FalconMambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[tuple, FalconMambaCausalLMOutput]: + r""" + cache_params (`FalconMambaCache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + falcon_mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = falcon_mamba_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + falcon_mamba_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FalconMambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=falcon_mamba_outputs.cache_params, + hidden_states=falcon_mamba_outputs.hidden_states, + ) + + +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..6df2be3a2652cf47100b82aa69be3c2554ba4161 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -0,0 +1,582 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch FALCONMAMBA model.""" + +from typing import Optional + +import torch +from torch import nn + +from ...utils import auto_docstring, logging +from ...utils.import_utils import ( + is_mamba_ssm_available, + is_mambapy_available, +) +from ..mamba.configuration_mamba import MambaConfig +from ..mamba.modeling_mamba import ( + MambaBlock, + MambaCache, + MambaCausalLMOutput, + MambaForCausalLM, + MambaMixer, + MambaModel, + MambaOutput, + MambaPreTrainedModel, + MambaRMSNorm, + _lazy_load_causal_conv1d, +) + + +logger = logging.get_logger(__name__) + +if is_mambapy_available(): + from mambapy.pscan import pscan +else: + pscan = None + +if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + + from ...kernels.falcon_mamba import mamba_inner_fn +else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +_causal_conv1d_cache = None + + +class FalconMambaConfig(MambaConfig): + """ + This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the FALCON_MAMBA + [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconMambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. 
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`. If set to `False`, residuals will keep the same `dtype` as the rest of the model.
+        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`.
+        time_step_scale (`float`, *optional*, defaults to 1.0):
+            Scale used to scale `dt_proj.bias`.
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+            Init scheme used for `dt_proj.weight`. Should be one of `["random", "uniform"]`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+            This argument corresponds to `use_mambapy` in MambaConfig. It determines the fallback strategy during
+            training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py
+            implementation is used; if `False`, the naive and slower implementation is used. Consider switching to
+            the naive version if memory is limited.
+        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
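+
+    Unless an explicit `intermediate_size` is passed through `kwargs`, it is derived as `expand * hidden_size`.
+    A quick sanity check of that bookkeeping with the defaults (`2 * 768 = 1536`):
+
+    ```python
+    >>> from transformers import FalconMambaConfig
+
+    >>> FalconMambaConfig().intermediate_size
+    1536
+    ```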
+ + + Example: + + ```python + >>> from transformers import FalconMambaConfig, FalconMambaModel + + >>> # Initializing a FalconMamba configuration + >>> configuration = FalconMambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FalconMambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + use_falcon_mambapy=False, + mixer_rms_eps=1e-6, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + state_size=state_size, + num_hidden_layers=num_hidden_layers, + layer_norm_epsilon=layer_norm_epsilon, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + expand=expand, + conv_kernel=conv_kernel, + use_bias=use_bias, + use_conv_bias=use_conv_bias, + hidden_act=hidden_act, + initializer_range=initializer_range, + residual_in_fp32=residual_in_fp32, + time_step_rank=time_step_rank, + time_step_scale=time_step_scale, + time_step_min=time_step_min, + time_step_max=time_step_max, + time_step_init_scheme=time_step_init_scheme, + time_step_floor=time_step_floor, + rescale_prenorm_residual=rescale_prenorm_residual, + use_cache=use_cache, + use_falcon_mambapy=use_falcon_mambapy, + **kwargs, + ) + self.mixer_rms_eps = mixer_rms_eps + # This is needed since mamba overrides the intermediate_size attribute + self.intermediate_size = ( + int(expand * self.hidden_size) + if kwargs.get("intermediate_size") is None + else kwargs.get("intermediate_size") + ) + + +class FalconMambaCache(MambaCache): + """ + Cache for falcon_mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if + a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
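+
+    Each entry of `conv_states` has shape `[max_batch_size, intermediate_size, conv_kernel]` and each entry of
+    `ssm_states` has shape `[max_batch_size, intermediate_size, ssm_state_size]`, with one entry per hidden layer.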
+ + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache + + >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") + + >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device) # sequence length + >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True) + >>> outputs.cache_params + ``` + """ + + pass + + +def rms_forward(hidden_states, variance_epsilon=1e-6): + """ + Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will + leverage this in order to multiply the final result with the RMSNorm weight + + Args: + hidden_states (`torch.Tensor`): + Hidden states to normalize + variance_epsilon (`float`): + The eps value to add in the square root scaling factor + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) + + +class FalconMambaMixer(MambaMixer): + def warn_slow_implementation(self): + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if not is_fast_path_available: + if self.use_falcon_mambapy: + if is_mambapy_available(): + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and" + " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d" + ) + else: + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." + ) + else: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and" + " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." 
+ ) + + def __init__(self, config: FalconMambaConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, + ) + + else: + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module + # at the price of a small overhead. 
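+ # Calling the (possibly quantized) `dt_proj` module and subtracting its bias afterwards leaves only the weight applied, + # matching the plain `dt_proj.weight @ time_step` branch; the bias itself is still passed separately to the selective-scan kernels as `time_proj_bias` below.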
+ if hasattr(self.config, "_pre_quantization_dtype"): + discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2) + else: + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + hidden_states = self.act( + self.conv1d(hidden_states)[..., :seq_len] + ) # [batch, intermediate_size, seq_len] + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + conv_state = conv_state.to(self.conv1d.weight.device) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = ( + self.act(hidden_states).to(dtype).unsqueeze(-1) + ) # [batch, intermediate_size, 1] : decoding + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. 
Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( + 1, 2 + ) # [batch, intermediate_size, seq_len] + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp( + A[None, :, None, :] * discrete_time_step[:, :, :, None] + ) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = ( + discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + ) # [batch, intermediate_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + if self.use_falcon_mambapy and self.training and cache_params is None: + hs = pscan( + discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) + ) # [batch, seq_len, intermediate_size, ssm_state_size] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = ( + discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + ) # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul( + ssm_state.to(dtype), C[:, i, :].unsqueeze(-1) + ) # [batch, intermediate_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = scan_output * self.act(gate) + + if cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + + def forward( + self, + hidden_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d() + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconMambaRMSNorm(MambaRMSNorm): + def forward(self, hidden_states): + return self.weight.to(hidden_states.device) * rms_forward( + hidden_states, variance_epsilon=self.variance_epsilon + ) + + +class FalconMambaBlock(MambaBlock): + pass + + +@auto_docstring +class FalconMambaPreTrainedModel(MambaPreTrainedModel): + pass + + +class FalconMambaOutput(MambaOutput): + pass + + +class FalconMambaCausalLMOutput(MambaCausalLMOutput): + pass + + +class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel): + def __init__(self, config): + FalconMambaPreTrainedModel.__init__(self, config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + + self.gradient_checkpointing = False + self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + raise AttributeError("Not needed for FalconMamba") + + +class FalconMambaForCausalLM(MambaForCausalLM): + pass + + +__all__ = [ + "FalconMambaForCausalLM", + "FalconMambaModel", + "FalconMambaPreTrainedModel", + "FalconMambaCache", + "FalconMambaConfig", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44d1ec7236310774ed6b1379683c144d7f93ecce --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_fastspeech2_conformer import * + from .modeling_fastspeech2_conformer import * + from .tokenization_fastspeech2_conformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..89d65a261c64fbabe493aa37677aaacb6f226a3b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FastSpeech2Conformer model configuration""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class FastSpeech2ConformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FastSpeech2ConformerModel`]. It is used to + instantiate a FastSpeech2Conformer model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the + FastSpeech2Conformer [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 384): + The dimensionality of the hidden layers. + vocab_size (`int`, *optional*, defaults to 78): + The size of the vocabulary. + num_mel_bins (`int`, *optional*, defaults to 80): + The number of mel filters used in the filter bank. + encoder_num_attention_heads (`int`, *optional*, defaults to 2): + The number of attention heads in the encoder. + encoder_layers (`int`, *optional*, defaults to 4): + The number of layers in the encoder. + encoder_linear_units (`int`, *optional*, defaults to 1536): + The number of units in the linear layer of the encoder. + decoder_layers (`int`, *optional*, defaults to 4): + The number of layers in the decoder. + decoder_num_attention_heads (`int`, *optional*, defaults to 2): + The number of attention heads in the decoder. + decoder_linear_units (`int`, *optional*, defaults to 1536): + The number of units in the linear layer of the decoder. 
+ speech_decoder_postnet_layers (`int`, *optional*, defaults to 5): + The number of layers in the post-net of the speech decoder. + speech_decoder_postnet_units (`int`, *optional*, defaults to 256): + The number of units in the post-net layers of the speech decoder. + speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5): + The kernel size in the post-net of the speech decoder. + positionwise_conv_kernel_size (`int`, *optional*, defaults to 3): + The size of the convolution kernel used in the position-wise layer. + encoder_normalize_before (`bool`, *optional*, defaults to `False`): + Specifies whether to normalize before encoder layers. + decoder_normalize_before (`bool`, *optional*, defaults to `False`): + Specifies whether to normalize before decoder layers. + encoder_concat_after (`bool`, *optional*, defaults to `False`): + Specifies whether to concatenate after encoder layers. + decoder_concat_after (`bool`, *optional*, defaults to `False`): + Specifies whether to concatenate after decoder layers. + reduction_factor (`int`, *optional*, defaults to 1): + The factor by which the speech frame rate is reduced. + speaking_speed (`float`, *optional*, defaults to 1.0): + The speed of the speech produced. + use_macaron_style_in_conformer (`bool`, *optional*, defaults to `True`): + Specifies whether to use macaron style in the conformer. + use_cnn_in_conformer (`bool`, *optional*, defaults to `True`): + Specifies whether to use convolutional neural networks in the conformer. + encoder_kernel_size (`int`, *optional*, defaults to 7): + The kernel size used in the encoder. + decoder_kernel_size (`int`, *optional*, defaults to 31): + The kernel size used in the decoder. + duration_predictor_layers (`int`, *optional*, defaults to 2): + The number of layers in the duration predictor. + duration_predictor_channels (`int`, *optional*, defaults to 256): + The number of channels in the duration predictor. + duration_predictor_kernel_size (`int`, *optional*, defaults to 3): + The kernel size used in the duration predictor. + energy_predictor_layers (`int`, *optional*, defaults to 2): + The number of layers in the energy predictor. + energy_predictor_channels (`int`, *optional*, defaults to 256): + The number of channels in the energy predictor. + energy_predictor_kernel_size (`int`, *optional*, defaults to 3): + The kernel size used in the energy predictor. + energy_predictor_dropout (`float`, *optional*, defaults to 0.5): + The dropout rate in the energy predictor. + energy_embed_kernel_size (`int`, *optional*, defaults to 1): + The kernel size used in the energy embed layer. + energy_embed_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate in the energy embed layer. + stop_gradient_from_energy_predictor (`bool`, *optional*, defaults to `False`): + Specifies whether to stop gradients from the energy predictor. + pitch_predictor_layers (`int`, *optional*, defaults to 5): + The number of layers in the pitch predictor. + pitch_predictor_channels (`int`, *optional*, defaults to 256): + The number of channels in the pitch predictor. + pitch_predictor_kernel_size (`int`, *optional*, defaults to 5): + The kernel size used in the pitch predictor. + pitch_predictor_dropout (`float`, *optional*, defaults to 0.5): + The dropout rate in the pitch predictor. + pitch_embed_kernel_size (`int`, *optional*, defaults to 1): + The kernel size used in the pitch embed layer. + pitch_embed_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate in the pitch embed layer. 
+ stop_gradient_from_pitch_predictor (`bool`, *optional*, defaults to `True`): + Specifies whether to stop gradients from the pitch predictor. + encoder_dropout_rate (`float`, *optional*, defaults to 0.2): + The dropout rate in the encoder. + encoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2): + The positional dropout rate in the encoder. + encoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2): + The attention dropout rate in the encoder. + decoder_dropout_rate (`float`, *optional*, defaults to 0.2): + The dropout rate in the decoder. + decoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2): + The positional dropout rate in the decoder. + decoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2): + The attention dropout rate in the decoder. + duration_predictor_dropout_rate (`float`, *optional*, defaults to 0.2): + The dropout rate in the duration predictor. + speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5): + The dropout rate in the speech decoder postnet. + max_source_positions (`int`, *optional*, defaults to 5000): + if `"relative"` position embeddings are used, defines the maximum source input positions. + use_masking (`bool`, *optional*, defaults to `True`): + Specifies whether to use masking in the model. + use_weighted_masking (`bool`, *optional*, defaults to `False`): + Specifies whether to use weighted masking in the model. + num_speakers (`int`, *optional*): + Number of speakers. If set to > 1, assume that the speaker ids will be provided as the input and use + speaker id embedding layer. + num_languages (`int`, *optional*): + Number of languages. If set to > 1, assume that the language ids will be provided as the input and use the + language id embedding layer. + speaker_embed_dim (`int`, *optional*): + Speaker embedding dimension. If set to > 0, assume that speaker_embedding will be provided as the input. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Specifies whether the model is an encoder-decoder. 
+ + Example: + + ```python + >>> from transformers import FastSpeech2ConformerModel, FastSpeech2ConformerConfig + + >>> # Initializing a FastSpeech2Conformer style configuration + >>> configuration = FastSpeech2ConformerConfig() + + >>> # Initializing a model from the FastSpeech2Conformer style configuration + >>> model = FastSpeech2ConformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "fastspeech2_conformer" + base_config_key = "model_config" + attribute_map = {"num_hidden_layers": "encoder_layers", "num_attention_heads": "encoder_num_attention_heads"} + + def __init__( + self, + hidden_size=384, + vocab_size=78, + num_mel_bins=80, + encoder_num_attention_heads=2, + encoder_layers=4, + encoder_linear_units=1536, + decoder_layers=4, + decoder_num_attention_heads=2, + decoder_linear_units=1536, + speech_decoder_postnet_layers=5, + speech_decoder_postnet_units=256, + speech_decoder_postnet_kernel=5, + positionwise_conv_kernel_size=3, + encoder_normalize_before=False, + decoder_normalize_before=False, + encoder_concat_after=False, + decoder_concat_after=False, + reduction_factor=1, + speaking_speed=1.0, + use_macaron_style_in_conformer=True, + use_cnn_in_conformer=True, + encoder_kernel_size=7, + decoder_kernel_size=31, + duration_predictor_layers=2, + duration_predictor_channels=256, + duration_predictor_kernel_size=3, + energy_predictor_layers=2, + energy_predictor_channels=256, + energy_predictor_kernel_size=3, + energy_predictor_dropout=0.5, + energy_embed_kernel_size=1, + energy_embed_dropout=0.0, + stop_gradient_from_energy_predictor=False, + pitch_predictor_layers=5, + pitch_predictor_channels=256, + pitch_predictor_kernel_size=5, + pitch_predictor_dropout=0.5, + pitch_embed_kernel_size=1, + pitch_embed_dropout=0.0, + stop_gradient_from_pitch_predictor=True, + encoder_dropout_rate=0.2, + encoder_positional_dropout_rate=0.2, + encoder_attention_dropout_rate=0.2, + decoder_dropout_rate=0.2, + decoder_positional_dropout_rate=0.2, + decoder_attention_dropout_rate=0.2, + duration_predictor_dropout_rate=0.2, + speech_decoder_postnet_dropout=0.5, + max_source_positions=5000, + use_masking=True, + use_weighted_masking=False, + num_speakers=None, + num_languages=None, + speaker_embed_dim=None, + is_encoder_decoder=True, + **kwargs, + ): + if positionwise_conv_kernel_size % 2 == 0: + raise ValueError( + f"positionwise_conv_kernel_size must be odd, but got {positionwise_conv_kernel_size} instead." + ) + if encoder_kernel_size % 2 == 0: + raise ValueError(f"encoder_kernel_size must be odd, but got {encoder_kernel_size} instead.") + if decoder_kernel_size % 2 == 0: + raise ValueError(f"decoder_kernel_size must be odd, but got {decoder_kernel_size} instead.") + if duration_predictor_kernel_size % 2 == 0: + raise ValueError( + f"duration_predictor_kernel_size must be odd, but got {duration_predictor_kernel_size} instead." + ) + if energy_predictor_kernel_size % 2 == 0: + raise ValueError( + f"energy_predictor_kernel_size must be odd, but got {energy_predictor_kernel_size} instead." + ) + if energy_embed_kernel_size % 2 == 0: + raise ValueError(f"energy_embed_kernel_size must be odd, but got {energy_embed_kernel_size} instead.") + if pitch_predictor_kernel_size % 2 == 0: + raise ValueError( + f"pitch_predictor_kernel_size must be odd, but got {pitch_predictor_kernel_size} instead." 
+ ) + if pitch_embed_kernel_size % 2 == 0: + raise ValueError(f"pitch_embed_kernel_size must be odd, but got {pitch_embed_kernel_size} instead.") + if hidden_size % encoder_num_attention_heads != 0: + raise ValueError("The hidden_size must be evenly divisible by encoder_num_attention_heads.") + if hidden_size % decoder_num_attention_heads != 0: + raise ValueError("The hidden_size must be evenly divisible by decoder_num_attention_heads.") + if use_masking and use_weighted_masking: + raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.") + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_mel_bins = num_mel_bins + self.encoder_config = { + "num_attention_heads": encoder_num_attention_heads, + "layers": encoder_layers, + "kernel_size": encoder_kernel_size, + "attention_dropout_rate": encoder_attention_dropout_rate, + "dropout_rate": encoder_dropout_rate, + "positional_dropout_rate": encoder_positional_dropout_rate, + "linear_units": encoder_linear_units, + "normalize_before": encoder_normalize_before, + "concat_after": encoder_concat_after, + } + self.decoder_config = { + "num_attention_heads": decoder_num_attention_heads, + "layers": decoder_layers, + "kernel_size": decoder_kernel_size, + "attention_dropout_rate": decoder_attention_dropout_rate, + "dropout_rate": decoder_dropout_rate, + "positional_dropout_rate": decoder_positional_dropout_rate, + "linear_units": decoder_linear_units, + "normalize_before": decoder_normalize_before, + "concat_after": decoder_concat_after, + } + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_layers = encoder_layers + self.duration_predictor_channels = duration_predictor_channels + self.duration_predictor_kernel_size = duration_predictor_kernel_size + self.duration_predictor_layers = duration_predictor_layers + self.energy_embed_dropout = energy_embed_dropout + self.energy_embed_kernel_size = energy_embed_kernel_size + self.energy_predictor_channels = energy_predictor_channels + self.energy_predictor_dropout = energy_predictor_dropout + self.energy_predictor_kernel_size = energy_predictor_kernel_size + self.energy_predictor_layers = energy_predictor_layers + self.pitch_embed_dropout = pitch_embed_dropout + self.pitch_embed_kernel_size = pitch_embed_kernel_size + self.pitch_predictor_channels = pitch_predictor_channels + self.pitch_predictor_dropout = pitch_predictor_dropout + self.pitch_predictor_kernel_size = pitch_predictor_kernel_size + self.pitch_predictor_layers = pitch_predictor_layers + self.positionwise_conv_kernel_size = positionwise_conv_kernel_size + self.speech_decoder_postnet_units = speech_decoder_postnet_units + self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout + self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel + self.speech_decoder_postnet_layers = speech_decoder_postnet_layers + self.reduction_factor = reduction_factor + self.speaking_speed = speaking_speed + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.max_source_positions = max_source_positions + self.use_cnn_in_conformer = use_cnn_in_conformer + self.use_macaron_style_in_conformer = use_macaron_style_in_conformer + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + self.num_speakers = num_speakers + self.num_languages = num_languages + self.speaker_embed_dim = speaker_embed_dim + self.duration_predictor_dropout_rate 
= duration_predictor_dropout_rate + self.is_encoder_decoder = is_encoder_decoder + + super().__init__( + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class FastSpeech2ConformerHifiGanConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FastSpeech2ConformerHifiGanModel`]. It is used to + instantiate a FastSpeech2Conformer HiFi-GAN vocoder model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + FastSpeech2Conformer + [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + model_in_dim (`int`, *optional*, defaults to 80): + The number of frequency bins in the input log-mel spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the upsampling network. + upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 2, 2]`): + A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The + length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. + upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[16, 16, 4, 4]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The + length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of + *upsample_rates*. + resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field + fusion (MRF) module. + resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + multi-receptive field fusion (MRF) module. + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation. + normalize_before (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance. 
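+ + The overall upsampling factor of the vocoder (waveform samples generated per input spectrogram frame) is the product of `upsample_rates`, i.e. 8 * 8 * 2 * 2 = 256 with the defaults above.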
+ + Example: + + ```python + >>> from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig + + >>> # Initializing a FastSpeech2ConformerHifiGan configuration + >>> configuration = FastSpeech2ConformerHifiGanConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FastSpeech2ConformerHifiGan(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "hifigan" + base_config_key = "vocoder_config" + + def __init__( + self, + model_in_dim=80, + upsample_initial_channel=512, + upsample_rates=[8, 8, 2, 2], + upsample_kernel_sizes=[16, 16, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + initializer_range=0.01, + leaky_relu_slope=0.1, + normalize_before=True, + **kwargs, + ): + self.model_in_dim = model_in_dim + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.initializer_range = initializer_range + self.leaky_relu_slope = leaky_relu_slope + self.normalize_before = normalize_before + super().__init__(**kwargs) + + +class FastSpeech2ConformerWithHifiGanConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`FastSpeech2ConformerWithHifiGan`]. It is used to + instantiate a `FastSpeech2ConformerWithHifiGan` model according to the specified sub-model configurations, + defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the + FastSpeech2ConformerModel [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) and + FastSpeech2ConformerHifiGan + [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architectures. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + model_config (`dict`, *optional*): + Configuration dictionary of the text-to-speech model (see [`FastSpeech2ConformerConfig`]). + vocoder_config (`dict`, *optional*): + Configuration dictionary of the vocoder model (see [`FastSpeech2ConformerHifiGanConfig`]). + + Example: + + ```python + >>> from transformers import ( + ... FastSpeech2ConformerConfig, + ... FastSpeech2ConformerHifiGanConfig, + ... FastSpeech2ConformerWithHifiGanConfig, + ... FastSpeech2ConformerWithHifiGan, + ... ) + + >>> # Initializing FastSpeech2ConformerWithHifiGan sub-modules configurations.
+ >>> model_config = FastSpeech2ConformerConfig() + >>> vocoder_config = FastSpeech2ConformerHifiGanConfig() + + >>> # Initializing a FastSpeech2ConformerWithHifiGan module style configuration + >>> configuration = FastSpeech2ConformerWithHifiGanConfig(model_config.to_dict(), vocoder_config.to_dict()) + + >>> # Initializing a model (with random weights) + >>> model = FastSpeech2ConformerWithHifiGan(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "fastspeech2_conformer_with_hifigan" + sub_configs = {"model_config": FastSpeech2ConformerConfig, "vocoder_config": FastSpeech2ConformerHifiGanConfig} + + def __init__( + self, + model_config: Optional[dict] = None, + vocoder_config: Optional[dict] = None, + **kwargs, + ): + if model_config is None: + model_config = {} + logger.info("model_config is None. initializing the model with default values.") + + if vocoder_config is None: + vocoder_config = {} + logger.info("vocoder_config is None. initializing the coarse model with default values.") + + self.model_config = FastSpeech2ConformerConfig(**model_config) + self.vocoder_config = FastSpeech2ConformerHifiGanConfig(**vocoder_config) + + super().__init__(**kwargs) + + +__all__ = ["FastSpeech2ConformerConfig", "FastSpeech2ConformerHifiGanConfig", "FastSpeech2ConformerWithHifiGanConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..004a1c36f59cc7942a7a132012bdccb40a4a38de --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for FastSpeech2Conformer.""" + +import json +import os +from typing import Optional + +import regex + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging, requires_backends + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"} + + +class FastSpeech2ConformerTokenizer(PreTrainedTokenizer): + """ + Construct a FastSpeech2Conformer tokenizer. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The begin of sequence token. Note that for FastSpeech2, it is the same as the `eos_token`. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. Note that for FastSpeech2, it is the same as the `bos_token`. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. 
A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + should_strip_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to strip the spaces from the list of tokens. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + should_strip_spaces=False, + **kwargs, + ): + requires_backends(self, "g2p_en") + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + + import g2p_en + + self.g2p = g2p_en.G2p() + + self.decoder = {v: k for k, v in self.encoder.items()} + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + should_strip_spaces=should_strip_spaces, + **kwargs, + ) + + self.should_strip_spaces = should_strip_spaces + + @property + def vocab_size(self): + return len(self.decoder) + + def get_vocab(self): + "Returns vocab as a dict" + return dict(self.encoder, **self.added_tokens_encoder) + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + # expand symbols + text = regex.sub(";", ",", text) + text = regex.sub(":", ",", text) + text = regex.sub("-", " ", text) + text = regex.sub("&", "and", text) + + # strip unnecessary symbols + text = regex.sub(r"[\(\)\[\]\<\>\"]+", "", text) + + # strip whitespaces + text = regex.sub(r"\s+", " ", text) + + text = text.upper() + + return text, kwargs + + def _tokenize(self, text): + """Returns a tokenized string.""" + # phonemize + tokens = self.g2p(text) + + if self.should_strip_spaces: + tokens = list(filter(lambda s: s != " ", tokens)) + + tokens.append(self.eos_token) + + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index, self.unk_token) + + # Override since phonemes cannot be converted back to strings + def decode(self, token_ids, **kwargs): + logger.warning( + "Phonemes cannot be reliably converted to a string due to the one-many mapping, converting to tokens instead." + ) + return self.convert_ids_to_tokens(token_ids) + + # Override since phonemes cannot be converted back to strings + def convert_tokens_to_string(self, tokens, **kwargs): + logger.warning( + "Phonemes cannot be reliably converted to a string due to the one-many mapping, returning the tokens." + ) + return tokens + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.get_vocab(), ensure_ascii=False)) + + return (vocab_file,) + + def __getstate__(self): + state = self.__dict__.copy() + state["g2p"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + try: + import g2p_en + + self.g2p = g2p_en.G2p() + except ImportError: + raise ImportError( + "You need to install g2p-en to use FastSpeech2ConformerTokenizer. " + "See https://pypi.org/project/g2p-en/ for installation." + ) + + +__all__ = ["FastSpeech2ConformerTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flaubert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/flaubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e981d9cbcb1e456c206b3bec252df1598e23575a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flaubert/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_flaubert import * + from .modeling_flaubert import * + from .modeling_tf_flaubert import * + from .tokenization_flaubert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/flaubert/configuration_flaubert.py b/venv/lib/python3.13/site-packages/transformers/models/flaubert/configuration_flaubert.py new file mode 100644 index 0000000000000000000000000000000000000000..071a74fe69b420954ff6ce7154b227c0e0d7e4dd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flaubert/configuration_flaubert.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flaubert configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class FlaubertConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`FlaubertModel`] or a [`TFFlaubertModel`]. It is + used to instantiate a FlauBERT model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the FlauBERT + [flaubert/flaubert_base_uncased](https://huggingface.co/flaubert/flaubert_base_uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + pre_norm (`bool`, *optional*, defaults to `False`): + Whether to apply the layer normalization before or after the feed forward layer following the attention in + each layer (Vaswani et al., Tensor2Tensor for Neural Machine Translation. 2018) + layerdrop (`float`, *optional*, defaults to 0.0): + Probability to drop layers during training (Fan et al., Reducing Transformer Depth on Demand with + Structured Dropout. ICLR 2020) + vocab_size (`int`, *optional*, defaults to 30145): + Vocabulary size of the FlauBERT model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`FlaubertModel`] or [`TFFlaubertModel`]. + emb_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention mechanism + gelu_activation (`bool`, *optional*, defaults to `True`): + Whether or not to use a *gelu* activation instead of *relu*. + sinusoidal_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to use sinusoidal positional embeddings instead of absolute positional embeddings. + causal (`bool`, *optional*, defaults to `False`): + Whether or not the model should behave in a causal manner. Causal models use a triangular attention mask in + order to only attend to the left-side context instead if a bidirectional context. + asm (`bool`, *optional*, defaults to `False`): + Whether or not to use an adaptive log softmax projection layer instead of a linear layer for the prediction + layer. + n_langs (`int`, *optional*, defaults to 1): + The number of languages the model handles. Set to 1 for monolingual models. + use_lang_emb (`bool`, *optional*, defaults to `True`) + Whether to use language embeddings. Some models use additional language embeddings, see [the multilingual + models page](http://huggingface.co/transformers/multilingual.html#xlm-language-embeddings) for information + on how to use them. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). 
+ embed_init_std (`float`, *optional*, defaults to 2048^-0.5): + The standard deviation of the truncated_normal_initializer for initializing the embedding matrices. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices except the + embedding matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + bos_index (`int`, *optional*, defaults to 0): + The index of the beginning of sentence token in the vocabulary. + eos_index (`int`, *optional*, defaults to 1): + The index of the end of sentence token in the vocabulary. + pad_index (`int`, *optional*, defaults to 2): + The index of the padding token in the vocabulary. + unk_index (`int`, *optional*, defaults to 3): + The index of the unknown token in the vocabulary. + mask_index (`int`, *optional*, defaults to 5): + The index of the masking token in the vocabulary. + is_encoder (`bool`, *optional*, defaults to `True`): + Whether or not the initialized model should be a transformer encoder or decoder as seen in Vaswani et al. + summary_type (`str`, *optional*, defaults to `"first"`): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Used in the sequence classification and multiple choice models. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Used in the sequence classification and multiple choice models. + + The dropout ratio to be used after the projection and activation. + start_n_top (`int`, *optional*, defaults to 5): + Used in the SQuAD evaluation script. + end_n_top (`int`, *optional*, defaults to 5): + Used in the SQuAD evaluation script. + mask_token_id (`int`, *optional*, defaults to 0): + Model agnostic parameter to identify masked tokens when generating text in an MLM context. + lang_id (`int`, *optional*, defaults to 0): + The ID of the language used by the model. This parameter is used when generating text in a given language.
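+ + Example (a minimal usage sketch in the style of the other configuration classes above; `FlaubertModel` is the model class documented in this configuration and assumed importable from `transformers`): + + ```python + >>> from transformers import FlaubertConfig, FlaubertModel + + >>> # Initializing a FlauBERT configuration with default values + >>> configuration = FlaubertConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FlaubertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```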
+ """ + + model_type = "flaubert" + attribute_map = { + "hidden_size": "emb_dim", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + "n_words": "vocab_size", # For backward compatibility + } + + def __init__( + self, + pre_norm=False, + layerdrop=0.0, + vocab_size=30145, + emb_dim=2048, + n_layers=12, + n_heads=16, + dropout=0.1, + attention_dropout=0.1, + gelu_activation=True, + sinusoidal_embeddings=False, + causal=False, + asm=False, + n_langs=1, + use_lang_emb=True, + max_position_embeddings=512, + embed_init_std=2048**-0.5, + layer_norm_eps=1e-12, + init_std=0.02, + bos_index=0, + eos_index=1, + pad_index=2, + unk_index=3, + mask_index=5, + is_encoder=True, + summary_type="first", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + start_n_top=5, + end_n_top=5, + mask_token_id=0, + lang_id=0, + pad_token_id=2, + bos_token_id=0, + **kwargs, + ): + """Constructs FlaubertConfig.""" + self.pre_norm = pre_norm + self.layerdrop = layerdrop + self.vocab_size = vocab_size + self.emb_dim = emb_dim + self.n_layers = n_layers + self.n_heads = n_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.gelu_activation = gelu_activation + self.sinusoidal_embeddings = sinusoidal_embeddings + self.causal = causal + self.asm = asm + self.n_langs = n_langs + self.use_lang_emb = use_lang_emb + self.layer_norm_eps = layer_norm_eps + self.bos_index = bos_index + self.eos_index = eos_index + self.pad_index = pad_index + self.unk_index = unk_index + self.mask_index = mask_index + self.is_encoder = is_encoder + self.max_position_embeddings = max_position_embeddings + self.embed_init_std = embed_init_std + self.init_std = init_std + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_proj_to_labels = summary_proj_to_labels + self.summary_first_dropout = summary_first_dropout + self.start_n_top = start_n_top + self.end_n_top = end_n_top + self.mask_token_id = mask_token_id + self.lang_id = lang_id + + if "n_words" in kwargs: + self.n_words = kwargs["n_words"] + + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) + + +class FlaubertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["FlaubertConfig", "FlaubertOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_flaubert.py b/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_flaubert.py new file mode 100644 index 0000000000000000000000000000000000000000..1dadc6f5377b4007c78f44d81d7a39acf2dedbdb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_flaubert.py @@ -0,0 +1,1700 @@ +# coding=utf-8 +# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Flaubert model, based on XLM.""" + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import gelu, get_activation +from ...cache_utils import DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging +from .configuration_flaubert import FlaubertConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings +def create_sinusoidal_embeddings(n_pos, dim, out): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out.requires_grad = False + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + + +# Copied from transformers.models.xlm.modeling_xlm.get_masks +def get_masks(slen, lengths, causal, padding_mask=None): + """ + Generate hidden states mask, and optionally an attention mask. 
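+ + Returns a tuple `(mask, attn_mask)`: `mask` has shape `(bs, slen)` and marks the non-padding positions, while `attn_mask` is equal to `mask` in the bidirectional case or a `(bs, slen, slen)` lower-triangular mask when `causal` is `True`.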
+ """ + alen = torch.arange(slen, dtype=torch.long, device=lengths.device) + if padding_mask is not None: + mask = padding_mask + else: + assert lengths.max().item() <= slen + mask = alen < lengths[:, None] + + # attention mask is the same as mask, or triangular inferior attention (causal) + bs = lengths.size(0) + if causal: + attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None] + else: + attn_mask = mask + + # sanity check + assert mask.size() == (bs, slen) + assert causal is False or attn_mask.size() == (bs, slen, slen) + + return mask, attn_mask + + +# Copied from transformers.models.xlm.modeling_xlm.MultiHeadAttention +class MultiHeadAttention(nn.Module): + def __init__(self, n_heads, dim, config, layer_idx: int = 0): + super().__init__() + self.layer_id = layer_idx + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.dropout = config.attention_dropout + assert self.dim % self.n_heads == 0 + + self.q_lin = nn.Linear(dim, dim) + self.k_lin = nn.Linear(dim, dim) + self.v_lin = nn.Linear(dim, dim) + self.out_lin = nn.Linear(dim, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + attention_head_size = self.dim // self.n_heads + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads) + # Prune linear layers + self.q_lin = prune_linear_layer(self.q_lin, index) + self.k_lin = prune_linear_layer(self.k_lin, index) + self.v_lin = prune_linear_layer(self.v_lin, index) + self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.dim = attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + input, + mask, + kv=None, + cache=None, + head_mask=None, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if kv is None) or attention over source sentence (provided by kv). 
+ """ + # Input is (bs, qlen, dim) + # Mask is (bs, klen) (non-causal) or (bs, klen, klen) + bs, qlen, dim = input.size() + is_cross_attention = kv is not None + mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1) + + q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + if cache is not None: + if isinstance(cache, EncoderDecoderCache): + is_updated = cache.is_updated.get(self.layer_id) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = cache.cross_attention_cache + else: + curr_past_key_value = cache.self_attention_cache + else: + curr_past_key_value = cache + + current_states = kv if is_cross_attention else input + if is_cross_attention and cache is not None and is_updated: + # reuse k,v, cross_attentions + k = curr_past_key_value.key_cache[self.layer_id] + v = curr_past_key_value.value_cache[self.layer_id] + else: + k = self.k_lin(current_states) + v = self.v_lin(current_states) + k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + + if cache is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + cache.is_updated[self.layer_id] = True + + q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim) + scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) + mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) + scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen) + + weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) + weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim) + context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim) + + outputs = (self.out_lin(context),) + if output_attentions: + outputs = outputs + (weights,) + return outputs + + +# Copied from transformers.models.xlm.modeling_xlm.TransformerFFN +class TransformerFFN(nn.Module): + def __init__(self, in_dim, dim_hidden, out_dim, config): + super().__init__() + self.dropout = config.dropout + self.lin1 = nn.Linear(in_dim, dim_hidden) + self.lin2 = nn.Linear(dim_hidden, out_dim) + self.act = gelu if config.gelu_activation else nn.functional.relu + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + def forward(self, input): + return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) + + def ff_chunk(self, input): + x = self.lin1(input) + x = self.act(x) + x = self.lin2(x) + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + return x + + +@auto_docstring( + custom_intro=""" + The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top. 
+ """ +) +# Copied from transformers.models.xlm.modeling_xlm.XLMPredLayer with XLM->Flaubert +class FlaubertPredLayer(nn.Module): + """ + Prediction layer (cross_entropy or adaptive_softmax). + """ + + def __init__(self, config): + super().__init__() + self.asm = config.asm + self.n_words = config.n_words + self.pad_index = config.pad_index + dim = config.emb_dim + + if config.asm is False: + self.proj = nn.Linear(dim, config.n_words, bias=True) + else: + self.proj = nn.AdaptiveLogSoftmaxWithLoss( + in_features=dim, + n_classes=config.n_words, + cutoffs=config.asm_cutoffs, + div_value=config.asm_div_value, + head_bias=True, # default is False + ) + + def forward(self, x, y=None): + """Compute the loss, and optionally the scores.""" + outputs = () + if self.asm is False: + scores = self.proj(x) + outputs = (scores,) + outputs + if y is not None: + loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean") + outputs = (loss,) + outputs + else: + scores = self.proj.log_prob(x) + outputs = (scores,) + outputs + if y is not None: + _, loss = self.proj(x, y) + outputs = (loss,) + outputs + + return outputs + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`]. + """ +) +# Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert +class FlaubertSquadHeadOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert +class FlaubertPoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. 
+ + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert +class FlaubertPoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. 
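+
+        A minimal shape sketch (illustrative sizes): given `start_positions`, the gathered start hidden
+        state is broadcast over the sequence and scored against every position:
+
+        ```python
+        >>> import torch
+        >>> from transformers import FlaubertConfig
+        >>> from transformers.models.flaubert.modeling_flaubert import FlaubertPoolerEndLogits
+        >>> config = FlaubertConfig(emb_dim=32)
+        >>> pooler = FlaubertPoolerEndLogits(config)
+        >>> pooler(torch.randn(2, 10, 32), start_positions=torch.tensor([0, 3])).shape
+        torch.Size([2, 10])
+        ```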
+ """ + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if p_mask.dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert +class FlaubertPoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert start_states is not None or start_positions is not None, ( + "One of start_states, start_positions should be not None" + ) + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert +class FlaubertSQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. 
+ + Args: + config ([`FlaubertConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = FlaubertPoolerStartLogits(config) + self.end_logits = FlaubertPoolerEndLogits(config) + self.answer_class = FlaubertPoolerAnswerClass(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[FlaubertSquadHeadOutput, tuple[torch.FloatTensor]]: + r""" + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. 
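+
+        A minimal shape sketch (random hidden states, illustrative sizes): without `start_positions`
+        and `end_positions`, the head runs its beam-search branch and returns the top start/end
+        candidates plus `cls_logits`:
+
+        ```python
+        >>> import torch
+        >>> from transformers import FlaubertConfig
+        >>> from transformers.models.flaubert.modeling_flaubert import FlaubertSQuADHead
+        >>> config = FlaubertConfig(emb_dim=64)
+        >>> head = FlaubertSQuADHead(config)
+        >>> out = head(torch.randn(2, 12, 64), return_dict=True)
+        >>> out.start_top_log_probs.shape, out.end_top_log_probs.shape
+        (torch.Size([2, 5]), torch.Size([2, 25]))
+        ```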
+ """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return FlaubertSquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert +class FlaubertSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`FlaubertConfig`]): + The config used by the model. 
Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: FlaubertConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
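+
+        A minimal shape sketch (illustrative sizes; the projection is disabled here so the summary keeps
+        the hidden size): with `summary_type="first"`, the first token's hidden state is returned per
+        sequence:
+
+        ```python
+        >>> import torch
+        >>> from transformers import FlaubertConfig
+        >>> from transformers.models.flaubert.modeling_flaubert import FlaubertSequenceSummary
+        >>> config = FlaubertConfig(emb_dim=32, summary_type="first", summary_use_proj=False)
+        >>> summary = FlaubertSequenceSummary(config)
+        >>> summary(torch.randn(2, 9, 32)).shape
+        torch.Size([2, 32])
+        ```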
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +@auto_docstring +# Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert +class FlaubertPreTrainedModel(PreTrainedModel): + config: FlaubertConfig + load_tf_weights = None + base_model_prefix = "transformer" + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + @property + def dummy_inputs(self): + inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) + attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + if self.config.use_lang_emb and self.config.n_langs > 1: + langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + else: + langs_list = None + return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list} + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Embedding): + if self.config is not None and self.config.embed_init_std is not None: + nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, nn.Linear): + if self.config is not None and self.config.init_std is not None: + nn.init.normal_(module.weight, mean=0, std=self.config.init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings: + create_sinusoidal_embeddings( + self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight + ) + + +@auto_docstring +class FlaubertModel(FlaubertPreTrainedModel): + def __init__(self, config): # , dico, is_encoder, with_output): + super().__init__(config) + + # encoder / decoder, output layer + self.is_encoder = config.is_encoder + self.is_decoder = not config.is_encoder + if self.is_decoder: + raise NotImplementedError("Currently Flaubert can only be used as an encoder") + # self.with_output = with_output + self.causal = config.causal + + # dictionary / languages + self.n_langs = config.n_langs + self.use_lang_emb = config.use_lang_emb + self.n_words = config.n_words + self.eos_index = config.eos_index + self.pad_index = config.pad_index + # self.dico = dico + # self.id2lang = config.id2lang + # self.lang2id = config.lang2id + # assert len(self.dico) == self.n_words + # assert len(self.id2lang) == len(self.lang2id) == self.n_langs + + # model parameters + self.dim = config.emb_dim # 512 by default + 
self.hidden_dim = self.dim * 4 # 2048 by default + self.n_heads = config.n_heads # 8 by default + self.n_layers = config.n_layers + self.dropout = config.dropout + self.attention_dropout = config.attention_dropout + assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads" + + # embeddings + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim) + if config.n_langs > 1 and config.use_lang_emb: + self.lang_embeddings = nn.Embedding(self.n_langs, self.dim) + self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index) + self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) + + # transformer layers + self.attentions = nn.ModuleList() + self.layer_norm1 = nn.ModuleList() + self.ffns = nn.ModuleList() + self.layer_norm2 = nn.ModuleList() + # if self.is_decoder: + # self.layer_norm15 = nn.ModuleList() + # self.encoder_attn = nn.ModuleList() + + for i in range(self.n_layers): + self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config, layer_idx=i)) + self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) + # if self.is_decoder: + # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) + # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) + self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config)) + self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) + + if hasattr(config, "pruned_heads"): + pruned_heads = config.pruned_heads.copy().items() + config.pruned_heads = {} + for layer, heads in pruned_heads: + if self.attentions[int(layer)].n_heads == config.n_heads: + self.prune_heads({int(layer): list(map(int, heads))}) + + # Initialize weights and apply final processing + self.post_init() + + self.layerdrop = getattr(config, "layerdrop", 0.0) + self.pre_norm = getattr(config, "pre_norm", False) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings + def get_input_embeddings(self): + return self.embeddings + + # Copied from transformers.models.xlm.modeling_xlm.XLMModel.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + # Copied from transformers.models.xlm.modeling_xlm.XLMModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.attentions[layer].prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + lengths: Optional[torch.LongTensor] = None, + cache: Optional[dict[str, torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`: + cache (`dict[str, torch.FloatTensor]`, *optional*): + Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the + attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + decoding. The dictionary object will be modified in-place during the forward pass to add newly computed + hidden-states. 
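+
+        Example (a short sketch; `flaubert/flaubert_base_cased` is one of the published Flaubert
+        checkpoints, any other would work the same way):
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaubertModel
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
+        >>> model = FlaubertModel.from_pretrained("flaubert/flaubert_base_cased")
+
+        >>> inputs = tokenizer("Bonjour, mon chien est mignon", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     last_hidden_state = model(**inputs).last_hidden_state
+        >>> last_hidden_state.shape[0]  # batch size
+        1
+        ```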
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # removed: src_enc=None, src_len=None + if input_ids is not None: + bs, slen = input_ids.size() + else: + bs, slen = inputs_embeds.size()[:-1] + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if cache is None: + cache = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if isinstance(cache, tuple): + cache = EncoderDecoderCache.from_legacy_cache(cache) + + if lengths is None: + if input_ids is not None: + lengths = (input_ids != self.pad_index).sum(dim=1).long() + else: + lengths = torch.tensor([slen] * bs, device=device) + # mask = input_ids != self.pad_index + + # check inputs + assert lengths.size(0) == bs + assert lengths.max().item() <= slen + # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 + # assert (src_enc is None) == (src_len is None) + # if src_enc is not None: + # assert self.is_decoder + # assert src_enc.size(0) == bs + + # generate masks + mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) + # if self.is_decoder and src_enc is not None: + # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] + + # Setting the position-ids to the registered buffer in constructor, it helps + # when tracing the model without passing position-ids, solves + # issues similar to issue #5664 + if position_ids is None: + if hasattr(self, "position_ids"): + position_ids = self.position_ids[:, :slen] + position_ids = position_ids.expand((bs, slen)) + else: + position_ids = torch.arange(slen, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand((bs, slen)) + else: + assert position_ids.size() == (bs, slen) # (slen, bs) + # position_ids = position_ids.transpose(0, 1) + + # langs + if langs is not None: + assert langs.size() == (bs, slen) # (slen, bs) + # langs = langs.transpose(0, 1) + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.n_layers) + + # do not recompute cached elements + if cache is not None and input_ids is not None: + _slen = slen - cache.get_seq_length() + input_ids = input_ids[:, -_slen:] + position_ids = position_ids[:, -_slen:] + if langs is not None: + langs = langs[:, -_slen:] + mask = mask[:, -_slen:] + attn_mask = attn_mask[:, -_slen:] + + # embeddings + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds) + if langs is not None and self.use_lang_emb and self.config.n_langs > 1: + tensor = tensor + self.lang_embeddings(langs) + if token_type_ids is not None: + tensor = tensor + self.embeddings(token_type_ids) + tensor = self.layer_norm_emb(tensor) + tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training) + tensor *= mask.unsqueeze(-1).to(tensor.dtype) + + # transformer layers + hidden_states = () if output_hidden_states else None + attentions = () if output_attentions else None + for i in range(self.n_layers): + # LayerDrop + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if output_hidden_states: + hidden_states = 
hidden_states + (tensor,) + + # self attention + if not self.pre_norm: + attn_outputs = self.attentions[i]( + tensor, + attn_mask, + cache=cache, + head_mask=head_mask[i], + output_attentions=output_attentions, + cache_position=cache_position, + ) + attn = attn_outputs[0] + if output_attentions: + attentions = attentions + (attn_outputs[1],) + attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) + tensor = tensor + attn + tensor = self.layer_norm1[i](tensor) + else: + tensor_normalized = self.layer_norm1[i](tensor) + attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i]) + attn = attn_outputs[0] + if output_attentions: + attentions = attentions + (attn_outputs[1],) + attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) + tensor = tensor + attn + + # FFN + if not self.pre_norm: + tensor = tensor + self.ffns[i](tensor) + tensor = self.layer_norm2[i](tensor) + else: + tensor_normalized = self.layer_norm2[i](tensor) + tensor = tensor + self.ffns[i](tensor_normalized) + + tensor *= mask.unsqueeze(-1).to(tensor.dtype) + + # Add last hidden state + if output_hidden_states: + hidden_states = hidden_states + (tensor,) + + if not return_dict: + return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) + + return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) + + +@auto_docstring( + custom_intro=""" + The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """ +) +class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["pred_layer.proj.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = FlaubertModel(config) + self.pred_layer = FlaubertPredLayer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.pred_layer.proj + + def set_output_embeddings(self, new_embeddings): + self.pred_layer.proj = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + # Overwritten -- uses a language id + + mask_token_id = self.config.mask_token_id + lang_id = self.config.lang_id + + effective_batch_size = input_ids.shape[0] + mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device) + input_ids = torch.cat([input_ids, mask_token], dim=1) + if lang_id is not None: + langs = torch.full_like(input_ids, lang_id) + else: + langs = None + return {"input_ids": input_ids, "langs": langs} + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. 
Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`: + cache (`dict[str, torch.FloatTensor]`, *optional*): + Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the + attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + decoding. The dictionary object will be modified in-place during the forward pass to add newly computed + hidden-states. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output = transformer_outputs[0] + outputs = self.pred_layer(output, labels) # (loss, logits) or (logits,) depending on if labels are provided. + + if not return_dict: + return outputs + transformer_outputs[1:] + + return MaskedLMOutput( + loss=outputs[0] if labels is not None else None, + logits=outputs[0] if labels is None else outputs[1], + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) + e.g. for GLUE tasks. 
+ """ +) +# Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class FlaubertForSequenceClassification(FlaubertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.transformer = FlaubertModel(config) + self.sequence_summary = FlaubertSequenceSummary(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`. + cache (`dict[str, torch.FloatTensor]`, *optional*): + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential + decoding. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output = transformer_outputs[0] + logits = self.sequence_summary(output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.xlm.modeling_xlm.XLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class FlaubertForTokenClassification(FlaubertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = FlaubertModel(config) + self.dropout = nn.Dropout(config.dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). 
+ + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`. + cache (`dict[str, torch.FloatTensor]`, *optional*): + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential + decoding. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """ +) +# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.transformer = FlaubertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. 
Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`. + cache (`dict[str, torch.FloatTensor]`, *optional*): + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential + decoding. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = transformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of question answering models using a `SquadHead`. + """ +) +# Copied from transformer.models.xlm.modeling_xlm.XLMForQuestionAnsweringOutput with XLM->Flaubert +class FlaubertForQuestionAnsweringOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. 
+ start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring +# Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class FlaubertForQuestionAnswering(FlaubertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.transformer = FlaubertModel(config) + self.qa_outputs = FlaubertSQuADHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + is_impossible: Optional[torch.Tensor] = None, + cls_index: Optional[torch.Tensor] = None, + p_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, FlaubertForQuestionAnsweringOutput]: + r""" + langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). 
More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`. + cache (`dict[str, torch.FloatTensor]`, *optional*): + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential + decoding. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels whether a question has an answer or no answer (SQuAD 2.0) + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the classification token to use as input for computing plausibility of the + answer. + p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be + masked. 0.0 mean token is not masked. + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048") + >>> model = FlaubertForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048") + + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze( + ... 0 + ... 
) # Batch size 1 + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + + >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output = transformer_outputs[0] + + outputs = self.qa_outputs( + output, + start_positions=start_positions, + end_positions=end_positions, + cls_index=cls_index, + is_impossible=is_impossible, + p_mask=p_mask, + return_dict=return_dict, + ) + + if not return_dict: + return outputs + transformer_outputs[1:] + + return FlaubertForQuestionAnsweringOutput( + loss=outputs.loss, + start_top_log_probs=outputs.start_top_log_probs, + start_top_index=outputs.start_top_index, + end_top_log_probs=outputs.end_top_log_probs, + end_top_index=outputs.end_top_index, + cls_logits=outputs.cls_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class FlaubertForMultipleChoice(FlaubertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.transformer = FlaubertModel(config) + self.sequence_summary = FlaubertSequenceSummary(config) + self.logits_proj = nn.Linear(config.num_labels, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + langs: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + cache: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + langs (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). 
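(Illustrative sketch, not part of the library file: the snippet below shows how the `lang2id` mapping described above can be used to build a `langs` tensor. The checkpoint name, the example sentence and the `"fr"` key are assumptions; monolingual FlauBERT checkpoints have `n_langs == 1` and can simply omit `langs`.)

```python
import torch
from transformers import AutoTokenizer, FlaubertModel

# Assumed checkpoint; any Flaubert/XLM-style checkpoint works the same way.
tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = FlaubertModel.from_pretrained("flaubert/flaubert_base_cased")

inputs = tokenizer("Le chat dort sur le canapé.", return_tensors="pt")

langs = None
if model.config.use_lang_emb and model.config.n_langs > 1:
    lang_id = model.config.lang2id["fr"]                    # language name -> language id
    langs = torch.full_like(inputs["input_ids"], lang_id)   # one language id per token

outputs = model(**inputs, langs=langs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```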
+ + See usage examples detailed in the [multilingual documentation](../multilingual). + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in + `[0, ..., input_ids.size(-1)]`. + cache (`dict[str, torch.FloatTensor]`, *optional*): + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential + decoding. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + langs = langs.view(-1, langs.size(-1)) if langs is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + if lengths is not None: + logger.warning( + "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the " + "attention mask instead." 
+ ) + lengths = None + + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + output = transformer_outputs[0] + logits = self.sequence_summary(output) + logits = self.logits_proj(logits) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "FlaubertForMultipleChoice", + "FlaubertForQuestionAnswering", + "FlaubertForQuestionAnsweringSimple", + "FlaubertForSequenceClassification", + "FlaubertForTokenClassification", + "FlaubertModel", + "FlaubertWithLMHeadModel", + "FlaubertPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_tf_flaubert.py b/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_tf_flaubert.py new file mode 100644 index 0000000000000000000000000000000000000000..88b7ae9f0c9ddd96e761b3739a035c55a217dec8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flaubert/modeling_tf_flaubert.py @@ -0,0 +1,1343 @@ +# coding=utf-8 +# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TF 2.0 Flaubert model. 
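(Illustrative sketch for the PyTorch multiple-choice head defined above, not part of the file. The checkpoint name, prompt and choices are assumptions, and the choice-scoring head is randomly initialized unless a fine-tuned checkpoint is loaded; the point is the `(batch_size, num_choices, sequence_length)` layout that the forward pass flattens internally.)

```python
import torch
from transformers import AutoTokenizer, FlaubertForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = FlaubertForMultipleChoice.from_pretrained("flaubert/flaubert_base_cased")

prompt = "Le chat s'est endormi"
choices = ["sur le canapé.", "dans la piscine."]

# Encode the prompt against each choice, padded to a common length.
encoding = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
# Add the batch dimension: (batch_size=1, num_choices, sequence_length).
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}

labels = torch.tensor([0])  # index of the "correct" choice in this toy example
outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape (1, num_choices)
```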
+""" + +from __future__ import annotations + +import itertools +import random +import warnings +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFSharedEmbeddings, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + MULTIPLE_CHOICE_DUMMY_INPUTS, + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_flaubert import FlaubertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased" +_CONFIG_FOR_DOC = "FlaubertConfig" + + +FLAUBERT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +FLAUBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - `1` for tokens that are **not masked**, + - `0` for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + langs (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are + languages ids which can be obtained from the language names by using two conversion mappings provided in + the configuration of the model (only provided for multilingual models). More precisely, the *language name + to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the + *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). + + See usage examples detailed in the [multilingual documentation](../multilingual). + token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - `0` corresponds to a *sentence A* token, + - `1` corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + lengths (`tf.Tensor` or `Numpy array` of shape `(batch_size,)`, *optional*): + Length of each sentence that can be used to avoid performing attention on padding token indices. You can + also use *attention_mask* for the same result (see above), kept here for compatibility Indices selected in + `[0, ..., input_ids.size(-1)]`: + cache (`dict[str, tf.Tensor]`, *optional*): + Dictionary string to `tf.FloatTensor` that contains precomputed hidden states (key and values in the + attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + decoding. + + The dictionary object will be modified in-place during the forward pass to add newly computed + hidden-states. + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - `1` indicates the head is **not masked**, + - `0` indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +def get_masks(slen, lengths, causal, padding_mask=None): + """ + Generate hidden states mask, and optionally an attention mask. + """ + bs = shape_list(lengths)[0] + if padding_mask is not None: + mask = padding_mask + else: + # assert lengths.max().item() <= slen + alen = tf.range(slen, dtype=lengths.dtype) + mask = alen < tf.expand_dims(lengths, axis=1) + + # attention mask is the same as mask, or triangular inferior attention (causal) + if causal: + attn_mask = tf.less_equal( + tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1)) + ) + else: + attn_mask = mask + + # sanity check + # assert shape_list(mask) == [bs, slen] + tf.debugging.assert_equal(shape_list(mask), [bs, slen]) + if causal: + tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen]) + + return mask, attn_mask + + +class TFFlaubertPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
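(Illustrative sketch of the `get_masks` helper defined above, not part of the file; the lengths are arbitrary. With `causal=False` the padding mask doubles as the attention mask; with `causal=True` a lower-triangular mask of shape `(batch, slen, slen)` is built instead, so each position attends only to itself and earlier positions.)

```python
import tensorflow as tf
from transformers.models.flaubert.modeling_tf_flaubert import get_masks

lengths = tf.constant([3, 5])  # two sentences of length 3 and 5, padded to slen=5

mask, attn_mask = get_masks(slen=5, lengths=lengths, causal=False)
print(mask.numpy())       # [[ True  True  True False False] [ True  True  True  True  True]]
print(attn_mask.numpy())  # identical to mask in the non-causal case

mask, attn_mask = get_masks(slen=5, lengths=lengths, causal=True)
print(attn_mask.shape)    # (2, 5, 5): per-position lower-triangular mask
```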
+ """ + + config_class = FlaubertConfig + base_model_prefix = "transformer" + + @property + def dummy_inputs(self): + # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed + inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32) + attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32) + if self.config.use_lang_emb and self.config.n_langs > 1: + return { + "input_ids": inputs_list, + "attention_mask": attns_list, + "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32), + } + else: + return {"input_ids": inputs_list, "attention_mask": attns_list} + + +@add_start_docstrings( + "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.", + FLAUBERT_START_DOCSTRING, +) +class TFFlaubertModel(TFFlaubertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFFlaubertMainLayer(config, name="transformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFBaseModelOutput: + outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert +class TFFlaubertMultiHeadAttention(keras.layers.Layer): + NEW_ID = itertools.count() + + def __init__(self, n_heads, dim, config, **kwargs): + super().__init__(**kwargs) + self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID) + self.dim = dim + self.n_heads = n_heads + self.output_attentions = config.output_attentions + assert self.dim % self.n_heads == 0 + + self.q_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin") + self.k_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin") + self.v_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin") + self.out_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin") + self.dropout = keras.layers.Dropout(config.attention_dropout) + 
self.pruned_heads = set() + self.dim = dim + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False): + """ + Self-attention (if kv is None) or attention over source sentence (provided by kv). + """ + # Input is (bs, qlen, dim) + # Mask is (bs, klen) (non-causal) or (bs, klen, klen) + bs, qlen, dim = shape_list(input) + + if kv is None: + klen = qlen if cache is None else cache["slen"] + qlen + else: + klen = shape_list(kv)[1] + + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + dim_per_head = self.dim // self.n_heads + mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) + + def shape(x): + """projection""" + return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3)) + + def unshape(x): + """compute context""" + return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) + + q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) + + if kv is None: + k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) + elif cache is None or self.layer_id not in cache: + k = v = kv + k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + + if cache is not None: + if self.layer_id in cache: + if kv is None: + k_, v_ = cache[self.layer_id] + k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) + v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) + else: + k, v = cache[self.layer_id] + + cache[self.layer_id] = (k, v) + + f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype) + q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head) + k = tf.cast(k, dtype=q.dtype) + scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) + mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) + # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) + mask = tf.cast(mask, dtype=scores.dtype) + scores = scores - 1e30 * (1.0 - mask) + weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) + weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) + context = unshape(context) # (bs, qlen, dim) + outputs = (self.out_lin(context),) + + if output_attentions: + outputs = outputs + (weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "q_lin", None) is not None: + with tf.name_scope(self.q_lin.name): + self.q_lin.build([None, None, self.dim]) + if getattr(self, "k_lin", None) is not None: + with tf.name_scope(self.k_lin.name): + self.k_lin.build([None, None, self.dim]) + if getattr(self, "v_lin", None) is not None: + with tf.name_scope(self.v_lin.name): + self.v_lin.build([None, None, self.dim]) + if getattr(self, "out_lin", None) is not None: + with tf.name_scope(self.out_lin.name): + self.out_lin.build([None, None, self.dim]) + + +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMTransformerFFN +class TFFlaubertTransformerFFN(keras.layers.Layer): + def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs): + super().__init__(**kwargs) + + self.lin1 = 
keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1") + self.lin2 = keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2") + self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu") + self.dropout = keras.layers.Dropout(config.dropout) + self.in_dim = in_dim + self.dim_hidden = dim_hidden + + def call(self, input, training=False): + x = self.lin1(input) + x = self.act(x) + x = self.lin2(x) + x = self.dropout(x, training=training) + + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "lin1", None) is not None: + with tf.name_scope(self.lin1.name): + self.lin1.build([None, None, self.in_dim]) + if getattr(self, "lin2", None) is not None: + with tf.name_scope(self.lin2.name): + self.lin2.build([None, None, self.dim_hidden]) + + +@keras_serializable +class TFFlaubertMainLayer(keras.layers.Layer): + config_class = FlaubertConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.n_heads = config.n_heads + self.n_langs = config.n_langs + self.dim = config.emb_dim + self.hidden_dim = self.dim * 4 + self.n_words = config.n_words + self.pad_index = config.pad_index + self.causal = config.causal + self.n_layers = config.n_layers + self.use_lang_emb = config.use_lang_emb + self.layerdrop = getattr(config, "layerdrop", 0.0) + self.pre_norm = getattr(config, "pre_norm", False) + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.max_position_embeddings = config.max_position_embeddings + self.embed_init_std = config.embed_init_std + self.dropout = keras.layers.Dropout(config.dropout) + self.embeddings = TFSharedEmbeddings( + self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings" + ) + self.layer_norm_emb = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb") + self.attentions = [] + self.layer_norm1 = [] + self.ffns = [] + self.layer_norm2 = [] + + for i in range(self.n_layers): + self.attentions.append( + TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name=f"attentions_._{i}") + ) + self.layer_norm1.append( + keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm1_._{i}") + ) + # if self.is_decoder: + # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) + # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) + self.ffns.append( + TFFlaubertTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name=f"ffns_._{i}") + ) + self.layer_norm2.append( + keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}") + ) + + def build(self, input_shape=None): + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + if self.n_langs > 1 and self.use_lang_emb: + with tf.name_scope("lang_embeddings"): + self.lang_embeddings = self.add_weight( + name="embeddings", + shape=[self.n_langs, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + 
self.embeddings.build(None) + if getattr(self, "layer_norm_emb", None) is not None: + with tf.name_scope(self.layer_norm_emb.name): + self.layer_norm_emb.build([None, None, self.dim]) + for layer in self.attentions: + with tf.name_scope(layer.name): + layer.build(None) + for layer in self.layer_norm1: + with tf.name_scope(layer.name): + layer.build([None, None, self.dim]) + for layer in self.ffns: + with tf.name_scope(layer.name): + layer.build(None) + for layer in self.layer_norm2: + with tf.name_scope(layer.name): + layer.build([None, None, self.dim]) + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + @unpack_inputs + def call( + self, + input_ids: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFBaseModelOutput: + # removed: src_enc=None, src_len=None + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + bs, slen = shape_list(input_ids) + elif inputs_embeds is not None: + bs, slen = shape_list(inputs_embeds)[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if lengths is None: + if input_ids is not None: + lengths = tf.reduce_sum( + tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1 + ) + else: + lengths = tf.convert_to_tensor([slen] * bs) + # mask = input_ids != self.pad_index + + # check inputs + # assert shape_list(lengths)[0] == bs + ( + tf.debugging.assert_equal(shape_list(lengths)[0], bs), + f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched", + ) + # assert lengths.max().item() <= slen + # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 + # assert (src_enc is None) == (src_len is None) + # if src_enc is not None: + # assert self.is_decoder + # assert src_enc.size(0) == bs + + # generate masks + mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) + # if self.is_decoder and src_enc is not None: + # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] + + # position_ids + if position_ids is None: + position_ids = tf.expand_dims(tf.range(slen), axis=0) + position_ids = tf.tile(position_ids, (bs, 1)) + + # assert shape_list(position_ids) == [bs, slen] # (slen, bs) + ( + tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]), + f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched", + ) + # position_ids = position_ids.transpose(0, 1) + + # langs + if langs is not None: + # assert shape_list(langs) == [bs, slen] # (slen, bs) + ( + tf.debugging.assert_equal(shape_list(langs), [bs, slen]), + f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched", + ) + # langs = langs.transpose(0, 1) + + # Prepare head mask if 
needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.n_layers + + # do not recompute cached elements + if cache is not None and input_ids is not None: + _slen = slen - cache["slen"] + input_ids = input_ids[:, -_slen:] + position_ids = position_ids[:, -_slen:] + if langs is not None: + langs = langs[:, -_slen:] + mask = mask[:, -_slen:] + attn_mask = attn_mask[:, -_slen:] + + # embeddings + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size) + inputs_embeds = self.embeddings(input_ids) + + tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids) + + if langs is not None and self.use_lang_emb: + tensor = tensor + tf.gather(self.lang_embeddings, langs) + if token_type_ids is not None: + tensor = tensor + self.embeddings(token_type_ids) + + tensor = self.layer_norm_emb(tensor) + tensor = self.dropout(tensor, training=training) + mask = tf.cast(mask, dtype=tensor.dtype) + tensor = tensor * tf.expand_dims(mask, axis=-1) + + # hidden_states and attentions cannot be None in graph mode. + hidden_states = () if output_hidden_states else None + attentions = () if output_attentions else None + + # transformer layers + for i in range(self.n_layers): + # LayerDrop + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + if output_hidden_states: + hidden_states = hidden_states + (tensor,) + + # self attention + if not self.pre_norm: + attn_outputs = self.attentions[i]( + tensor, + attn_mask, + None, + cache, + head_mask[i], + output_attentions, + training=training, + ) + attn = attn_outputs[0] + + if output_attentions: + attentions = attentions + (attn_outputs[1],) + + attn = self.dropout(attn, training=training) + tensor = tensor + attn + tensor = self.layer_norm1[i](tensor) + else: + tensor_normalized = self.layer_norm1[i](tensor) + attn_outputs = self.attentions[i]( + tensor_normalized, + attn_mask, + None, + cache, + head_mask[i], + output_attentions, + training=training, + ) + attn = attn_outputs[0] + + if output_attentions: + attentions = attentions + (attn_outputs[1],) + + attn = self.dropout(attn, training=training) + tensor = tensor + attn + + # encoder attention (for decoder only) + # if self.is_decoder and src_enc is not None: + # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) + # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) + # tensor = tensor + attn + # tensor = self.layer_norm15[i](tensor) + + # FFN + if not self.pre_norm: + tensor = tensor + self.ffns[i](tensor) + tensor = self.layer_norm2[i](tensor) + else: + tensor_normalized = self.layer_norm2[i](tensor) + tensor = tensor + self.ffns[i](tensor_normalized) + + tensor = tensor * tf.expand_dims(mask, axis=-1) + + # Add last hidden state + if output_hidden_states: + hidden_states = hidden_states + (tensor,) + + # update cache length + if cache is not None: + cache["slen"] += tensor.size(1) + + # move back sequence length to dimension 0 + # tensor = tensor.transpose(0, 1) + + if not return_dict: + return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) + + return TFBaseModelOutput(last_hidden_state=tensor, 
hidden_states=hidden_states, attentions=attentions) + + +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMPredLayer +class TFFlaubertPredLayer(keras.layers.Layer): + """ + Prediction layer (cross_entropy or adaptive_softmax). + """ + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.asm = config.asm + self.n_words = config.n_words + self.pad_index = config.pad_index + + if config.asm is False: + self.input_embeddings = input_embeddings + else: + raise NotImplementedError + # self.proj = nn.AdaptiveLogSoftmaxWithLoss( + # in_features=dim, + # n_classes=config.n_words, + # cutoffs=config.asm_cutoffs, + # div_value=config.asm_div_value, + # head_bias=True, # default is False + # ) + + def build(self, input_shape): + # The output weights are the same as the input embeddings, but there is an output-only bias for each token. + self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias") + + super().build(input_shape) + + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.input_embeddings(hidden_states, mode="linear") + hidden_states = hidden_states + self.bias + + return hidden_states + + +@dataclass +class TFFlaubertWithLMHeadModelOutput(ModelOutput): + """ + Base class for [`TFFlaubertWithLMHeadModel`] outputs. + + Args: + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@add_start_docstrings( + """ + The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + FLAUBERT_START_DOCSTRING, +) +class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFFlaubertMainLayer(config, name="transformer") + self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj") + # Flaubert does not have past caching features + self.supports_xla_generation = False + + def get_lm_head(self): + return self.pred_layer + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.pred_layer.name + + def prepare_inputs_for_generation(self, inputs, **kwargs): + mask_token_id = self.config.mask_token_id + lang_id = self.config.lang_id + + effective_batch_size = inputs.shape[0] + mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id + inputs = tf.concat([inputs, mask_token], axis=1) + + if lang_id is not None: + langs = tf.ones_like(inputs) * lang_id + else: + langs = None + return {"input_ids": inputs, "langs": langs} + + @unpack_inputs + @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFFlaubertWithLMHeadModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFFlaubertWithLMHeadModelOutput: + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + output = transformer_outputs[0] + outputs = self.pred_layer(output) + + if not return_dict: + return (outputs,) + transformer_outputs[1:] + + return TFFlaubertWithLMHeadModelOutput( + logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "pred_layer", None) is not None: + with tf.name_scope(self.pred_layer.name): + self.pred_layer.build(None) + + +@add_start_docstrings( + """ + Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) + e.g. for GLUE tasks. 
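(Illustrative sketch for `TFFlaubertWithLMHeadModel` defined above, not part of the file. The checkpoint name and sentence are assumptions, and TF weights are assumed to be available for the checkpoint; the prediction layer reuses the shared input embeddings, so only the per-token output bias is an extra parameter.)

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFFlaubertWithLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertWithLMHeadModel.from_pretrained("flaubert/flaubert_base_cased")

text = f"Le chat {tokenizer.mask_token} sur le canapé."
inputs = tokenizer(text, return_tensors="tf")
logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)

mask_pos = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0])
top5 = tf.math.top_k(logits[0, mask_pos], k=5).indices
print(tokenizer.decode(top5.numpy()))  # candidate tokens for the masked position
```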
+ """, + FLAUBERT_START_DOCSTRING, +) +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.transformer = TFFlaubertMainLayer(config, name="transformer") + self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") + + @unpack_inputs + @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + output = transformer_outputs[0] + + logits = self.sequence_summary(output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "sequence_summary", None) is not None: + with tf.name_scope(self.sequence_summary.name): + self.sequence_summary.build(None) + + +@add_start_docstrings( + """ + Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
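(Illustrative sketch for `TFFlaubertForSequenceClassification` defined above, not part of the file. The checkpoint name, sentence and `num_labels=2` are assumptions; the classification head is freshly initialized unless a fine-tuned checkpoint is loaded.)

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFFlaubertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertForSequenceClassification.from_pretrained("flaubert/flaubert_base_cased", num_labels=2)

inputs = tokenizer("Ce film est formidable.", return_tensors="tf")
outputs = model(**inputs, labels=tf.constant([1]))
print(outputs.loss, outputs.logits.shape)  # per-example loss, logits of shape (1, 2)
```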
+ """, + FLAUBERT_START_DOCSTRING, +) +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFFlaubertMainLayer(config, name="transformer") + self.qa_outputs = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = transformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
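(Illustrative sketch for `TFFlaubertForQuestionAnsweringSimple` defined above, not part of the file. The checkpoint name and the question/context pair are assumptions, and the span head is untrained here; in practice a checkpoint fine-tuned on a French QA dataset would be loaded instead.)

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFFlaubertForQuestionAnsweringSimple

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertForQuestionAnsweringSimple.from_pretrained("flaubert/flaubert_base_cased")

question, context = "Où dort le chat ?", "Le chat dort sur le canapé."
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

# Greedy span extraction from the start/end logits.
start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1].numpy()))
```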
+ """, + FLAUBERT_START_DOCSTRING, +) +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.transformer = TFFlaubertMainLayer(config, name="transformer") + self.dropout = keras.layers.Dropout(config.dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = transformer_outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
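(Illustrative sketch for `TFFlaubertForTokenClassification` defined above, not part of the file. The checkpoint name, sentence and `num_labels=5` are assumptions; the token-classification head is untrained here.)

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFFlaubertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = TFFlaubertForTokenClassification.from_pretrained("flaubert/flaubert_base_cased", num_labels=5)

inputs = tokenizer("Victor Hugo est né à Besançon.", return_tensors="tf")
logits = model(**inputs).logits            # (batch_size, sequence_length, num_labels)
print(tf.argmax(logits, axis=-1).numpy())  # one predicted label id per token
```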
+ """, + FLAUBERT_START_DOCSTRING, +) +# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert +class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.transformer = TFFlaubertMainLayer(config, name="transformer") + self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") + self.logits_proj = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" + ) + self.config = config + + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + tf.Tensor with dummy inputs + """ + # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed + if self.config.use_lang_emb and self.config.n_langs > 1: + return { + "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), + "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), + } + else: + return { + "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), + } + + @unpack_inputs + @add_start_docstrings_to_model_forward( + FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + langs: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + lengths: np.ndarray | tf.Tensor | None = None, + cache: dict[str, tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + + if lengths is not None: + logger.warning( + "The `lengths` parameter cannot be used with the Flaubert multiple choice models. 
Please use the " + "attention mask instead.", + ) + lengths = None + + transformer_outputs = self.transformer( + flat_input_ids, + flat_attention_mask, + flat_langs, + flat_token_type_ids, + flat_position_ids, + lengths, + cache, + head_mask, + flat_inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + output = transformer_outputs[0] + logits = self.sequence_summary(output) + logits = self.logits_proj(logits) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "sequence_summary", None) is not None: + with tf.name_scope(self.sequence_summary.name): + self.sequence_summary.build(None) + if getattr(self, "logits_proj", None) is not None: + with tf.name_scope(self.logits_proj.name): + self.logits_proj.build([None, None, self.config.num_labels]) + + +__all__ = [ + "TFFlaubertForMultipleChoice", + "TFFlaubertForQuestionAnsweringSimple", + "TFFlaubertForSequenceClassification", + "TFFlaubertForTokenClassification", + "TFFlaubertModel", + "TFFlaubertPreTrainedModel", + "TFFlaubertWithLMHeadModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flaubert/tokenization_flaubert.py b/venv/lib/python3.13/site-packages/transformers/models/flaubert/tokenization_flaubert.py new file mode 100644 index 0000000000000000000000000000000000000000..dee653450ebacc02bcfee8142c5d08b54358c1a9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flaubert/tokenization_flaubert.py @@ -0,0 +1,538 @@ +# coding=utf-8 +# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Flaubert.""" + +import json +import os +import re +import unicodedata +from typing import Optional + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +def convert_to_unicode(text): + """ + Converts `text` to Unicode (if it's not already), assuming UTF-8 input. 
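+
+    A quick sketch of the intended behavior (illustrative):
+
+    ```python
+    convert_to_unicode(b"caf\xc3\xa9")  # -> "café"
+    convert_to_unicode("café")  # -> "café"
+    ```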
+    """
+
+    def ensure_text(s, encoding="utf-8", errors="strict"):
+        if isinstance(s, bytes):
+            return s.decode(encoding, errors)
+        elif isinstance(s, str):
+            return s
+        else:
+            raise TypeError(f"not expecting type '{type(s)}'")
+
+    return ensure_text(text, encoding="utf-8", errors="ignore")
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+    strings)
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
+def replace_unicode_punct(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+    """
+    text = text.replace("，", ",")
+    text = re.sub(r"。\s*", ". ", text)
+    text = text.replace("、", ",")
+    text = text.replace("”", '"')
+    text = text.replace("“", '"')
+    text = text.replace("∶", ":")
+    text = text.replace("：", ":")
+    text = text.replace("？", "?")
+    text = text.replace("《", '"')
+    text = text.replace("》", '"')
+    text = text.replace("）", ")")
+    text = text.replace("！", "!")
+    text = text.replace("（", "(")
+    text = text.replace("；", ";")
+    text = text.replace("１", "1")
+    text = text.replace("」", '"')
+    text = text.replace("「", '"')
+    text = text.replace("０", "0")
+    text = text.replace("３", "3")
+    text = text.replace("２", "2")
+    text = text.replace("５", "5")
+    text = text.replace("６", "6")
+    text = text.replace("９", "9")
+    text = text.replace("７", "7")
+    text = text.replace("８", "8")
+    text = text.replace("４", "4")
+    text = re.sub(r"．\s*", ". ", text)
+    text = text.replace("～", "~")
+    text = text.replace("’", "'")
+    text = text.replace("…", "...")
+    text = text.replace("━", "-")
+    text = text.replace("〈", "<")
+    text = text.replace("〉", ">")
+    text = text.replace("【", "[")
+    text = text.replace("】", "]")
+    text = text.replace("％", "%")
+    return text
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
+def remove_non_printing_char(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+    """
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat.startswith("C"):
+            continue
+        output.append(char)
+    return "".join(output)
+
+
+class FlaubertTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Flaubert tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:
+
+    - Moses preprocessing and tokenization.
+    - Normalizing all input text.
+    - The arguments `special_tokens` and the function `set_special_tokens` can be used to add additional symbols
+      (like "__classify__") to a vocabulary.
+    - The argument `do_lowercase` controls lower casing (automatically set for pretrained vocabularies).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Vocabulary file.
+        merges_file (`str`):
+            Merges file.
+        do_lowercase (`bool`, *optional*, defaults to `False`):
+            Controls lower casing.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        cls_token (`str`, *optional*, defaults to `"</s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        mask_token (`str`, *optional*, defaults to `"<special1>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        additional_special_tokens (`List[str]`, *optional*, defaults to `['<special0>', '<special1>', '<special2>', '<special3>', '<special4>', '<special5>', '<special6>', '<special7>', '<special8>', '<special9>']`):
+            List of additional special tokens.
+        lang2id (`Dict[str, int]`, *optional*):
+            Dictionary mapping languages string identifiers to their IDs.
+        id2lang (`Dict[int, str]`, *optional*):
+            Dictionary mapping language IDs to their string identifiers.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        do_lowercase=False,
+        unk_token="<unk>",
+        bos_token="<s>",
+        sep_token="</s>",
+        pad_token="<pad>",
+        cls_token="</s>",
+        mask_token="<special1>",
+        additional_special_tokens=[
+            "<special0>",
+            "<special1>",
+            "<special2>",
+            "<special3>",
+            "<special4>",
+            "<special5>",
+            "<special6>",
+            "<special7>",
+            "<special8>",
+            "<special9>",
+        ],
+        lang2id=None,
+        id2lang=None,
+        **kwargs,
+    ):
+        do_lowercase_and_remove_accent = kwargs.pop("do_lowercase_and_remove_accent", None)
+        if do_lowercase_and_remove_accent is not None:
+            logger.warning(
+                "`do_lowercase_and_remove_accent` is passed as a keyword argument, but this won't do anything."
+                " `FlaubertTokenizer` will always set it to `False`."
+            )
+        # always `False`
+        self.do_lowercase_and_remove_accent = False
+
+        self.do_lowercase = do_lowercase
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use FlaubertTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+ ) + + self.sm = sacremoses + + # cache of sm.MosesPunctNormalizer instance + self.cache_moses_punct_normalizer = {} + # cache of sm.MosesTokenizer instance + self.cache_moses_tokenizer = {} + self.lang_with_custom_tokenizer = {"zh", "th", "ja"} + self.lang2id = lang2id + self.id2lang = id2lang + if lang2id is not None and id2lang is not None: + assert len(lang2id) == len(id2lang) + + self.ja_word_tokenizer = None + self.zh_word_tokenizer = None + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[:-1] + merges = [tuple(merge.split()[:2]) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + + super().__init__( + do_lowercase=do_lowercase, + unk_token=unk_token, + bos_token=bos_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + additional_special_tokens=additional_special_tokens, + lang2id=lang2id, + id2lang=id2lang, + **kwargs, + ) + + @property + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case + def do_lower_case(self): + return self.do_lowercase_and_remove_accent + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm + def moses_punct_norm(self, text, lang): + if lang not in self.cache_moses_punct_normalizer: + punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang) + self.cache_moses_punct_normalizer[lang] = punct_normalizer + else: + punct_normalizer = self.cache_moses_punct_normalizer[lang] + return punct_normalizer.normalize(text) + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize + def moses_tokenize(self, text, lang): + if lang not in self.cache_moses_tokenizer: + moses_tokenizer = self.sm.MosesTokenizer(lang=lang) + self.cache_moses_tokenizer[lang] = moses_tokenizer + else: + moses_tokenizer = self.cache_moses_tokenizer[lang] + return moses_tokenizer.tokenize(text, return_str=False, escape=False) + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline + def moses_pipeline(self, text, lang): + text = replace_unicode_punct(text) + text = self.moses_punct_norm(text, lang) + text = remove_non_printing_char(text) + return text + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize + def ja_tokenize(self, text): + if self.ja_word_tokenizer is None: + try: + import Mykytea + + self.ja_word_tokenizer = Mykytea.Mykytea( + f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin" + ) + except (AttributeError, ImportError): + logger.error( + "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper" + " (https://github.com/chezou/Mykytea-python) with the following steps" + ) + logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") + logger.error("2. autoreconf -i") + logger.error("3. ./configure --prefix=$HOME/local") + logger.error("4. make && make install") + logger.error("5. 
pip install kytea")
+                raise
+        return list(self.ja_word_tokenizer.getWS(text))
+
+    @property
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
+    def bpe(self, token):
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        if token in self.cache:
+            return self.cache[token]
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        if word == "\n  </w>":
+            word = "\n</w>"
+        self.cache[token] = word
+        return word
+
+    def preprocess_text(self, text):
+        text = text.replace("``", '"').replace("''", '"')
+        text = convert_to_unicode(text)
+        text = unicodedata.normalize("NFC", text)
+
+        if self.do_lowercase:
+            text = text.lower()
+
+        return text
+
+    def _tokenize(self, text, bypass_tokenizer=False):
+        """
+        Tokenize a string given language code using Moses.
+
+        Details of tokenization:
+
+        - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
+        - Install with `pip install sacremoses`
+
+        Args:
+            - bypass_tokenizer (bool): Allow users to preprocess and tokenize the sentences externally
+              (default = False). If True, we only apply BPE.
+
+        Returns:
+            List of tokens.
+        """
+        lang = "fr"
+        if lang and self.lang2id and lang not in self.lang2id:
+            logger.error(
+                "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
+                " the loaded pretrained model."
+            )
+
+        if bypass_tokenizer:
+            text = text.split()
+        else:
+            text = self.preprocess_text(text)
+            text = self.moses_pipeline(text, lang=lang)
+            text = self.moses_tokenize(text, lang=lang)
+
+        split_tokens = []
+        for token in text:
+            if token:
+                split_tokens.extend(list(self.bpe(token).split(" ")))
+
+        return split_tokens
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index, self.unk_token)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        out_string = "".join(tokens).replace("</w>", " ").strip()
+        return out_string
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. An XLM sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+
+        """
+        bos = [self.bos_token_id]
+        sep = [self.sep_token_id]
+
+        if token_ids_1 is None:
+            return bos + token_ids_0 + sep
+        return bos + token_ids_0 + sep + token_ids_1 + sep
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
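+
+        As a toy sketch (the IDs are illustrative, not real vocabulary entries), a single sequence is wrapped as
+        `<s> ... </s>`, so only the first and last positions are flagged:
+
+        ```python
+        tokenizer.get_special_tokens_mask([5, 6, 7])  # -> [1, 0, 0, 0, 1]
+        ```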
+ """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sm"] = None + return state + + # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use XLMTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation." + ) + + self.sm = sacremoses + + +__all__ = ["FlaubertTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/flava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..292593cb4a201e35a9fd571baec639d9b940e76c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_flava import * + from .feature_extraction_flava import * + from .image_processing_flava import * + from .image_processing_flava_fast import * + from .modeling_flava import * + from .processing_flava import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/configuration_flava.py b/venv/lib/python3.13/site-packages/transformers/models/flava/configuration_flava.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bcb920e47acfacadd7a066e773ca2a3d42e2dc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/configuration_flava.py @@ -0,0 +1,697 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FLAVA model configurations""" + +from typing import Any, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class FlavaImageConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an + FLAVA model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        mask_token (`bool`, *optional*, defaults to `True`):
+            Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA.
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked
+            Image Modeling) loss for FLAVA.
+
+    Example:
+
+    ```python
+    >>> from transformers import FlavaImageConfig, FlavaImageModel
+
+    >>> # Initializing a FlavaImageConfig with the facebook/flava-full style configuration
+    >>> configuration = FlavaImageConfig()
+
+    >>> # Initializing a FlavaImageModel model (with random weights) from the facebook/flava-full style configuration
+    >>> model = FlavaImageModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "flava_image_model"
+    base_config_key = "image_config"
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_hidden_layers: int = 12,
+        num_attention_heads: int = 12,
+        intermediate_size: int = 3072,
+        hidden_act: str = "gelu",
+        hidden_dropout_prob: float = 0.0,
+        attention_probs_dropout_prob: float = 0.0,
+        initializer_range: float = 0.02,
+        layer_norm_eps: float = 1e-12,
+        image_size: int = 224,
+        patch_size: int = 16,
+        num_channels: int = 3,
+        qkv_bias: bool = True,
+        mask_token: bool = True,
+        vocab_size: int = 8192,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.mask_token = mask_token
+        self.vocab_size = vocab_size
+
+
+class FlavaTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
+    FLAVA model according to the specified arguments, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+    [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`FlavaTextModel`].
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though
+            the text encoder allows a `token_type_ids` value of 2, only 1 is used for text-only pretraining and
+            fine-tuning, similar to RoBERTa.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048). For VL tasks, the max_length passed to the model is 77.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
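+
+    As a minimal sketch of overriding these defaults (illustrative sizes, not the pretrained ones):
+
+    ```python
+    from transformers import FlavaTextConfig
+
+    small_text_config = FlavaTextConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=4)
+    ```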
+ + Example: + + ```python + >>> from transformers import FlavaTextConfig, FlavaTextModel + + >>> # Initializing a FlavaTextModel with style configuration + >>> configuration = FlavaTextConfig() + + >>> # Initializing a FlavaTextModel model (with random weights) from the style configuration + >>> model = FlavaTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "flava_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size: int = 30522, + type_vocab_size: int = 2, + max_position_embeddings: int = 512, + position_embedding_type: str = "absolute", + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + pad_token_id: int = 0, + qkv_bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.type_vocab_size = type_vocab_size + self.max_position_embeddings = max_position_embeddings + self.position_embedding_type = position_embedding_type + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.pad_token_id = pad_token_id + + +class FlavaMultimodalConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate + an FLAVA model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        use_cls_token (`bool`, *optional*, defaults to `True`):
+            Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel
+
+    >>> # Initializing a FlavaMultimodalConfig with the facebook/flava-full style configuration
+    >>> configuration = FlavaMultimodalConfig()
+
+    >>> # Initializing a FlavaMultimodalModel model (with random weights) from the facebook/flava-full style configuration
+    >>> model = FlavaMultimodalModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "flava_multimodal_model"
+    base_config_key = "multimodal_config"
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_hidden_layers: int = 6,
+        num_attention_heads: int = 12,
+        intermediate_size: int = 3072,
+        hidden_act: str = "gelu",
+        hidden_dropout_prob: float = 0.0,
+        attention_probs_dropout_prob: float = 0.0,
+        initializer_range: float = 0.02,
+        layer_norm_eps: float = 1e-12,
+        qkv_bias: bool = True,
+        use_cls_token: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.qkv_bias = qkv_bias
+        self.use_cls_token = use_cls_token
+
+
+class FlavaImageCodebookConfig(PretrainedConfig):
+    r"""
+    [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
+    is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+    [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_groups (`int`, *optional*, defaults to 4):
+            Number of groups to be created. As of now, this parameter does not affect the model and is only used for
+            some internal calculations and estimates.
+        input_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the image to be passed.
+        num_blocks_per_group (`int`, *optional*, defaults to 2):
+            Number of conv-based blocks per group.
+        hidden_size (`int`, *optional*, defaults to 256):
+            Size of hidden dim for the blocks.
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Size of the output vocabulary for the codebook.
+        freeze (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the weights of the model.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
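+
+    One practical constraint worth spelling out: the codebook's `vocab_size` is the space of ids the MIM loss of
+    [`FlavaImageModel`] predicts over, so the two settings should normally agree (both default to 8192). A sketch:
+
+    ```python
+    from transformers import FlavaImageCodebookConfig, FlavaImageConfig
+
+    codebook_config = FlavaImageCodebookConfig(vocab_size=8192)
+    image_config = FlavaImageConfig(vocab_size=8192, mask_token=True)
+    ```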
+
+    Example:
+
+    ```python
+    >>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook
+
+    >>> # Initializing a FlavaImageCodebookConfig with the facebook/flava-image-codebook style configuration
+    >>> configuration = FlavaImageCodebookConfig()
+
+    >>> # Initializing a FlavaImageCodebook model (with random weights) from the facebook/flava-image-codebook style configuration
+    >>> model = FlavaImageCodebook(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "flava_image_codebook"
+    base_config_key = "image_codebook_config"
+
+    def __init__(
+        self,
+        num_groups: int = 4,
+        input_channels: int = 3,
+        num_blocks_per_group: int = 2,
+        hidden_size: int = 256,
+        vocab_size: int = 8192,
+        freeze: bool = True,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_groups = num_groups
+        self.input_channels = input_channels
+        self.num_blocks_per_group = num_blocks_per_group
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.freeze = freeze
+        self.initializer_range = initializer_range
+
+
+class FlavaConfig(PretrainedConfig):
+    r"""
+    [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
+    instantiate an FLAVA model according to the specified arguments, defining the text model, image model, image
+    codebook and multimodal model configs. Instantiating a configuration with the defaults will yield a similar
+    configuration to that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FlavaTextConfig`].
+        image_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FlavaImageConfig`].
+        multimodal_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
+        image_codebook_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FlavaImageCodebookConfig`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        projection_dim (`int`, *optional*, defaults to 768):
+            Dimensionality of text and image projection layers.
+        init_codebook (`bool`, *optional*, defaults to `True`):
+            Whether to initialize the image codebook as part of the model.
+        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The initial value of the *logit_scale* parameter. Default is used as per the original FLAVA/CLIP
+            implementation.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        ce_ignore_index (`int`, *optional*, defaults to -100):
+            Cross entropy index to ignore.
+        mim_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MIM (Masked Image Modeling) unimodal loss.
+        mlm_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MLM (Masked Language Modeling) unimodal loss.
+        global_contrastive_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to global contrastive cross-alignment loss.
+        itm_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to image-text matching multimodal loss.
+        mmm_image_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MMM loss's image part.
+        mmm_text_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MMM loss's text part.
+        global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
+            Whether to use global backpropagation through all workers in the contrastive loss.
+        skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
+            Whether to skip running the unmasked multimodal encoder, whose outputs are not used by the FLAVA losses.
+        return_loss (`bool`, *optional*, defaults to `True`):
+            Whether to return the loss or not.
+
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining
+
+    >>> # Initializing a FlavaConfig with the facebook/flava-full style configuration
+    >>> configuration = FlavaConfig()
+
+    >>> # Initializing a FlavaModel and FlavaForPreTraining model (with random weights) from the facebook/flava-full style configuration
+    >>> model = FlavaModel(configuration)
+    >>> model_pre = FlavaForPreTraining(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    >>> configuration_pre = model_pre.config
+    ```
+    """
+
+    model_type = "flava"
+    sub_configs = {
+        "text_config": FlavaTextConfig,
+        "image_config": FlavaImageConfig,
+        "multimodal_config": FlavaMultimodalConfig,
+        "image_codebook_config": FlavaImageCodebookConfig,
+    }
+
+    def __init__(
+        self,
+        image_config: Optional[dict[str, Any]] = None,
+        text_config: Optional[dict[str, Any]] = None,
+        multimodal_config: Optional[dict[str, Any]] = None,
+        image_codebook_config: Optional[dict[str, Any]] = None,
+        hidden_size: int = 768,
+        layer_norm_eps: float = 1e-12,
+        projection_dim: int = 768,
+        init_codebook: bool = True,
+        logit_scale_init_value: float = 2.6592,
+        initializer_range: float = 0.02,
+        ce_ignore_index: int = -100,
+        mim_weight: float = 1.0,
+        mlm_weight: float = 1.0,
+        global_contrastive_weight: float = 1.0,
+        itm_weight: float = 1.0,
+        mmm_image_weight: float = 1.0,
+        mmm_text_weight: float = 1.0,
+        global_backprop_contrastive: bool = True,
+        skip_unmasked_multimodal_encoder: bool = True,
+        return_loss: bool = True,
+        **kwargs,
+    ):
+        # If `_config_dict`s exist, we use them for backward compatibility. We pop these attributes out before
+        # calling `super().__init__` to avoid them being saved (which causes a lot of confusion!).
+        text_config_dict = kwargs.pop("text_config_dict", None)
+        image_config_dict = kwargs.pop("image_config_dict", None)
+        multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
+        image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
+
+        super().__init__(**kwargs)
+
+        # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+        # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be the same
+        # in most cases, but we don't want to break anything regarding `_config_dict` that existed before commit
+        # `8827e1b2`.
+        if text_config_dict is not None:
+            if text_config is None:
+                text_config = {}
+
+            # This is the complete result when using `text_config_dict`.
+            _text_config_dict = FlavaTextConfig(**text_config_dict).to_dict()
+
+            # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
+            for key, value in _text_config_dict.items():
+                if key in text_config and value != text_config[key] and key != "transformers_version":
+                    # If specified in `text_config_dict`
+                    if key in text_config_dict:
+                        message = (
+                            f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. 
" + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The " + f'value `text_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if image_config_dict is not None: + if image_config is None: + image_config = {} + + # This is the complete result when using `image_config_dict`. + _image_config_dict = FlavaImageConfig(**image_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _image_config_dict: + _image_config_dict["id2label"] = { + str(key): value for key, value in _image_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_image_config_dict` and `image_config` but being different. + for key, value in _image_config_dict.items(): + if key in image_config and value != image_config[key] and key != "transformers_version": + # If specified in `image_config_dict` + if key in image_config_dict: + message = ( + f"`{key}` is found in both `image_config_dict` and `image_config` but with different " + f'values. The value `image_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`image_config_dict` is provided which will be used to initialize `FlavaImageConfig`. " + f'The value `image_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `image_config` with the ones in `_image_config_dict`. + image_config.update(_image_config_dict) + + if multimodal_config_dict is not None: + if multimodal_config is None: + multimodal_config = {} + + # This is the complete result when using `multimodal_config_dict`. + _multimodal_config_dict = FlavaMultimodalConfig(**multimodal_config_dict).to_dict() + + # Give a warning if the values exist in both `_multimodal_config_dict` and `multimodal_config` but being + # different. + for key, value in _multimodal_config_dict.items(): + if key in multimodal_config and value != multimodal_config[key] and key != "transformers_version": + # If specified in `multimodal_config_dict` + if key in multimodal_config_dict: + message = ( + f"`{key}` is found in both `multimodal_config_dict` and `multimodal_config` but with " + f'different values. The value `multimodal_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`multimodal_config_dict` is provided which will be used to initialize " + f'`FlavaMultimodalConfig`. The value `multimodal_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `multimodal_config` with the ones in `_multimodal_config_dict`. + multimodal_config.update(_multimodal_config_dict) + + if image_codebook_config_dict is not None: + if image_codebook_config is None: + image_codebook_config = {} + + # This is the complete result when using `image_codebook_config_dict`. + _image_codebook_config_dict = FlavaImageCodebookConfig(**image_codebook_config_dict).to_dict() + + # Give a warning if the values exist in both `_image_codebook_config_dict` and `image_codebook_config` but + # being different. 
+ for key, value in _image_codebook_config_dict.items(): + if ( + key in image_codebook_config + and value != image_codebook_config[key] + and key != "transformers_version" + ): + # If specified in `image_codebook_config_dict` + if key in image_codebook_config_dict: + message = ( + f"`{key}` is found in both `image_codebook_config_dict` and `image_codebook_config` but " + f'with different values. The value `image_codebook_config_dict["{key}"]` will be used ' + "instead." + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`image_codebook_config_dict` is provided which will be used to initialize " + f'`FlavaImageCodebookConfig`. The value `image_codebook_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `image_codebook_config` with the ones in `_image_codebook_config_dict`. + image_codebook_config.update(_image_codebook_config_dict) + + if image_config is None: + image_config = {} + logger.info("`image_config` is `None`. initializing the `FlavaImageConfig` with default values.") + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `FlavaTextConfig` with default values.") + + if multimodal_config is None: + multimodal_config = {} + logger.info("`multimodal_config` is `None`. initializing the `FlavaMultimodalConfig` with default values.") + + if image_codebook_config is None: + image_codebook_config = {} + logger.info( + "`image_codebook_config` is `None`. initializing the `FlavaImageCodebookConfig` with default values." + ) + + self.image_config = FlavaImageConfig(**image_config) + self.text_config = FlavaTextConfig(**text_config) + self.multimodal_config = FlavaMultimodalConfig(**multimodal_config) + self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config) + self.projection_dim = projection_dim + self.init_codebook = init_codebook + + self.hidden_size = hidden_size + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + self.ce_ignore_index = ce_ignore_index + self.mim_weight = mim_weight + self.mlm_weight = mlm_weight + self.global_contrastive_weight = global_contrastive_weight + self.itm_weight = itm_weight + self.mmm_image_weight = mmm_image_weight + self.mmm_text_weight = mmm_text_weight + self.global_backprop_contrastive = global_backprop_contrastive + self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder + self.return_loss = return_loss + + @classmethod + def from_configs( + cls, + image_config: FlavaImageConfig, + text_config: FlavaTextConfig, + multimodal_config: FlavaMultimodalConfig, + image_codebook_config: FlavaImageCodebookConfig, + **kwargs, + ): + r""" + Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model + configuration, flava multimodal model and flava codebook model configuration. 
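+
+        A sketch of typical usage (default sub-configs shown for brevity):
+
+        ```python
+        from transformers import (
+            FlavaConfig,
+            FlavaImageCodebookConfig,
+            FlavaImageConfig,
+            FlavaMultimodalConfig,
+            FlavaTextConfig,
+        )
+
+        config = FlavaConfig.from_configs(
+            image_config=FlavaImageConfig(),
+            text_config=FlavaTextConfig(),
+            multimodal_config=FlavaMultimodalConfig(),
+            image_codebook_config=FlavaImageCodebookConfig(),
+        )
+        ```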
+ + Returns: + [`FlavaConfig`]: An instance of a configuration object + """ + + return cls( + image_config=image_config.to_dict(), + text_config=text_config.to_dict(), + multimodal_config=multimodal_config.to_dict(), + image_codebook_config=image_codebook_config.to_dict(), + **kwargs, + ) + + +__all__ = ["FlavaConfig", "FlavaImageCodebookConfig", "FlavaImageConfig", "FlavaMultimodalConfig", "FlavaTextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/feature_extraction_flava.py b/venv/lib/python3.13/site-packages/transformers/models/flava/feature_extraction_flava.py new file mode 100644 index 0000000000000000000000000000000000000000..19bcccc889f546442af71229c998880fbbb2db31 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/feature_extraction_flava.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for FLAVA.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_flava import FlavaImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class FlavaFeatureExtractor(FlavaImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class FlavaFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use FlavaImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["FlavaFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava.py b/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4db246a8fa4195efd8d9840e351aae43af2a53 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava.py @@ -0,0 +1,706 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Flava.""" + +import math +import random +from collections.abc import Iterable +from functools import lru_cache +from typing import Any, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +# These values are taken from CLIP +FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN +FLAVA_IMAGE_STD = OPENAI_CLIP_STD +FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0] +FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0] +LOGIT_LAPLACE_EPS: float = 0.1 + + +# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py +class FlavaMaskingGenerator: + def __init__( + self, + input_size: Union[int, tuple[int, int]] = 14, + total_mask_patches: int = 75, + mask_group_max_patches: Optional[int] = None, + mask_group_min_patches: int = 16, + mask_group_min_aspect_ratio: Optional[float] = 0.3, + mask_group_max_aspect_ratio: Optional[float] = None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.total_mask_patches = total_mask_patches + + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches + + mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio + self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio)) + + def __repr__(self): + repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.mask_group_min_patches, + self.mask_group_max_patches, + self.total_mask_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _attempt in range(10): + target_area = random.uniform(self.mask_group_min_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + height = int(round(math.sqrt(target_area * aspect_ratio))) + width = int(round(math.sqrt(target_area / aspect_ratio))) + if width < self.width and height < self.height: + top = random.randint(0, self.height - height) + left = random.randint(0, self.width - width) + + num_masked = mask[top : top + height, left : left + width].sum() + # Overlap + if 0 < height * width - num_masked <= max_mask_patches: + for i in range(top, top + height): + for j in range(left, left + width): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self): + mask = np.zeros(shape=self.get_shape(), dtype=int) + mask_count = 0 + while mask_count < self.total_mask_patches: + max_mask_patches = self.total_mask_patches - mask_count + max_mask_patches = min(max_mask_patches, self.mask_group_max_patches) + + delta = self._mask(mask, 
max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask + + +@requires(backends=("vision",)) +class FlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a Flava image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in `preprocess`. + size (`dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in + `preprocess`. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`. + crop_size (`dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the + `crop_size` parameter in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in `preprocess`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in + `preprocess`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + return_image_mask (`bool`, *optional*, defaults to `False`): + Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`. + input_size_patches (`int`, *optional*, defaults to 14): + Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden + by the `input_size_patches` parameter in `preprocess`. + total_mask_patches (`int`, *optional*, defaults to 75): + Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in + `preprocess`. + mask_group_min_patches (`int`, *optional*, defaults to 16): + Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches` + parameter in `preprocess`. + mask_group_max_patches (`int`, *optional*): + Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches` + parameter in `preprocess`. + mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3): + Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter + in `preprocess`. 
+ mask_group_max_aspect_ratio (`float`, *optional*): + Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter + in `preprocess`. + codebook_do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input for codebook to a certain. Can be overridden by the `codebook_do_resize` + parameter in `preprocess`. `codebook_size`. + codebook_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in + `preprocess`. + codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`): + Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample` + parameter in `preprocess`. + codebook_do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input for codebook at the center. If the input size is smaller than + `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be + overridden by the `codebook_do_center_crop` parameter in `preprocess`. + codebook_crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired output size for codebook input when applying center-cropping. Can be overridden by the + `codebook_crop_size` parameter in `preprocess`. + codebook_do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be + overridden by the `codebook_do_rescale` parameter in `preprocess`. + codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Defines the scale factor to use if rescaling the codebook image. Can be overridden by the + `codebook_rescale_factor` parameter in `preprocess`. + codebook_do_map_pixels (`bool`, *optional*, defaults to `True`): + Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the + `codebook_do_map_pixels` parameter in `preprocess`. + codebook_do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can + be overridden by the `codebook_do_normalize` parameter in `preprocess`. + codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`): + The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden + by the `codebook_image_mean` parameter in `preprocess`. + codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can + be overridden by the `codebook_image_std` parameter in `preprocess`. 
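Example (a minimal sketch assuming the default configuration of this processor; the random test image and the shapes noted below are illustrative only):

```python
import numpy as np
from PIL import Image

from transformers import FlavaImageProcessor

# Any PIL image, numpy array or torch tensor works; this one is random.
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

processor = FlavaImageProcessor()
inputs = processor(
    images=image,
    return_image_mask=True,       # also return "bool_masked_pos"
    return_codebook_pixels=True,  # also return "codebook_pixel_values"
    return_tensors="pt",
)
# With the defaults: pixel_values -> (1, 3, 224, 224), codebook_pixel_values -> (1, 3, 112, 112),
# bool_masked_pos -> (1, 14, 14)
print({name: tuple(tensor.shape) for name, tensor in inputs.items()})
```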
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, Iterable[float]]] = None, + image_std: Optional[Union[float, Iterable[float]]] = None, + # Mask related params + return_image_mask: bool = False, + input_size_patches: int = 14, + total_mask_patches: int = 75, + mask_group_min_patches: int = 16, + mask_group_max_patches: Optional[int] = None, + mask_group_min_aspect_ratio: float = 0.3, + mask_group_max_aspect_ratio: Optional[float] = None, + # Codebook related params + return_codebook_pixels: bool = False, + codebook_do_resize: bool = True, + codebook_size: Optional[bool] = None, + codebook_resample: int = PILImageResampling.LANCZOS, + codebook_do_center_crop: bool = True, + codebook_crop_size: Optional[int] = None, + codebook_do_rescale: bool = True, + codebook_rescale_factor: Union[int, float] = 1 / 255, + codebook_do_map_pixels: bool = True, + codebook_do_normalize: bool = True, + codebook_image_mean: Optional[Union[float, Iterable[float]]] = None, + codebook_image_std: Optional[Union[float, Iterable[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112} + codebook_size = get_size_dict(codebook_size, param_name="codebook_size") + codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112} + codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN + self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD + + self.return_image_mask = return_image_mask + self.input_size_patches = input_size_patches + self.total_mask_patches = total_mask_patches + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = mask_group_max_patches + self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio + self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio + + self.return_codebook_pixels = return_codebook_pixels + self.codebook_do_resize = codebook_do_resize + self.codebook_size = codebook_size + self.codebook_resample = codebook_resample + self.codebook_do_center_crop = codebook_do_center_crop + self.codebook_crop_size = codebook_crop_size + self.codebook_do_rescale = codebook_do_rescale + self.codebook_rescale_factor = codebook_rescale_factor + self.codebook_do_map_pixels = codebook_do_map_pixels + self.codebook_do_normalize = codebook_do_normalize + self.codebook_image_mean = codebook_image_mean + self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN + self.codebook_image_std = 
codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)` + """ + image_processor_dict = image_processor_dict.copy() + if "codebook_size" in kwargs: + image_processor_dict["codebook_size"] = kwargs.pop("codebook_size") + if "codebook_crop_size" in kwargs: + image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size") + return super().from_dict(image_processor_dict, **kwargs) + + @lru_cache + def masking_generator( + self, + input_size_patches, + total_mask_patches, + mask_group_min_patches, + mask_group_max_patches, + mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio, + ) -> FlavaMaskingGenerator: + return FlavaMaskingGenerator( + input_size=input_size_patches, + total_mask_patches=total_mask_patches, + mask_group_min_patches=mask_group_min_patches, + mask_group_max_patches=mask_group_max_patches, + mask_group_min_aspect_ratio=mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, + ) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def map_pixels(self, image: np.ndarray) -> np.ndarray: + return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_map_pixels: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + if do_map_pixels: + image = self.map_pixels(image) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + # Mask related params + return_image_mask: Optional[bool] = None, + input_size_patches: Optional[int] = None, + total_mask_patches: Optional[int] = None, + mask_group_min_patches: Optional[int] = None, + mask_group_max_patches: Optional[int] = None, + mask_group_min_aspect_ratio: Optional[float] = None, + mask_group_max_aspect_ratio: Optional[float] = None, + # Codebook related params + return_codebook_pixels: Optional[bool] = None, + codebook_do_resize: Optional[bool] = None, + codebook_size: Optional[dict[str, int]] = None, + codebook_resample: Optional[int] = None, + codebook_do_center_crop: Optional[bool] = None, + codebook_crop_size: Optional[dict[str, int]] = None, + codebook_do_rescale: Optional[bool] = None, + codebook_rescale_factor: Optional[float] = None, + codebook_do_map_pixels: Optional[bool] = None, + codebook_do_normalize: Optional[bool] = None, + codebook_image_mean: Optional[Iterable[float]] = None, + codebook_image_std: Optional[Iterable[float]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. 
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`): + Whether to return the image mask. + input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`): + Size of the patches to extract from the image. + total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`): + Total number of patches to extract from the image. + mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`): + Minimum number of patches to extract from the image. + mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`): + Maximum number of patches to extract from the image. + mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`): + Minimum aspect ratio of the patches to extract from the image. + mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`): + Maximum aspect ratio of the patches to extract from the image. + return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`): + Whether to return the codebook pixels. + codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`): + Whether to resize the codebook pixels. + codebook_size (`dict[str, int]`, *optional*, defaults to `self.codebook_size`): + Size of the codebook pixels. + codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`): + Resampling filter to use if resizing the codebook pixels. This can be one of the enum + `PILImageResampling`, Only has an effect if `codebook_do_resize` is set to `True`. + codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`): + Whether to center crop the codebook pixels. + codebook_crop_size (`dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`): + Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set + to `True`. + codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`): + Whether to rescale the codebook pixels values between [0 - 1]. + codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`): + Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`. + codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`): + Whether to map the codebook pixels values. + codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`): + Whether to normalize the codebook pixels. + codebook_image_mean (`float` or `list[float]`, *optional*, defaults to `self.codebook_image_mean`): + Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`. + codebook_image_std (`float` or `list[float]`, *optional*, defaults to `self.codebook_image_std`): + Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is + set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. 
Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask + input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches + total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches + mask_group_min_patches = ( + mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches + ) + mask_group_max_patches = ( + mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches + ) + mask_group_min_aspect_ratio = ( + mask_group_min_aspect_ratio + if mask_group_min_aspect_ratio is not None + else self.mask_group_min_aspect_ratio + ) + mask_group_max_aspect_ratio = ( + mask_group_max_aspect_ratio + if mask_group_max_aspect_ratio is not None + else self.mask_group_max_aspect_ratio + ) + + return_codebook_pixels = ( + return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels + ) + codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize + codebook_size = codebook_size if codebook_size is not None else self.codebook_size + codebook_size = get_size_dict(codebook_size, param_name="codebook_size") + codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample + codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale + codebook_rescale_factor = ( + codebook_rescale_factor if codebook_rescale_factor is not None else 
self.codebook_rescale_factor + ) + codebook_do_center_crop = ( + codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop + ) + codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size + codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size") + codebook_do_map_pixels = ( + codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels + ) + codebook_do_normalize = ( + codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize + ) + codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean + codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + processed_images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_map_pixels=False, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + data = {"pixel_values": processed_images} + + if return_codebook_pixels: + codebook_images = [ + self._preprocess_image( + image=img, + do_resize=codebook_do_resize, + size=codebook_size, + resample=codebook_resample, + do_center_crop=codebook_do_center_crop, + crop_size=codebook_crop_size, + do_rescale=codebook_do_rescale, + rescale_factor=codebook_rescale_factor, + do_normalize=codebook_do_normalize, + image_mean=codebook_image_mean, + image_std=codebook_image_std, + do_map_pixels=codebook_do_map_pixels, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + data["codebook_pixel_values"] = codebook_images + + if return_image_mask: + mask_generator = self.masking_generator( + input_size_patches=input_size_patches, + total_mask_patches=total_mask_patches, + mask_group_min_patches=mask_group_min_patches, + mask_group_max_patches=mask_group_max_patches, + mask_group_min_aspect_ratio=mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, + ) + masks = [mask_generator() for _ in images] + data["bool_masked_pos"] = masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["FlavaImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava_fast.py b/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..732d25e71f697e083e784199c1db1782298951ac --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/image_processing_flava_fast.py @@ -0,0 +1,491 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Flava.""" + +import math +import random +from collections.abc import Iterable +from functools import lru_cache +from typing import Any, Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, + get_size_dict, +) +from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict, pil_torch_interpolation_mapping +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) +from .image_processing_flava import ( + FLAVA_CODEBOOK_MEAN, + FLAVA_CODEBOOK_STD, + FLAVA_IMAGE_MEAN, + FLAVA_IMAGE_STD, + LOGIT_LAPLACE_EPS, +) + + +class FlavaMaskingGenerator: + def __init__( + self, + input_size: Union[int, tuple[int, int]] = 14, + total_mask_patches: int = 75, + mask_group_max_patches: Optional[int] = None, + mask_group_min_patches: int = 16, + mask_group_min_aspect_ratio: Optional[float] = 0.3, + mask_group_max_aspect_ratio: Optional[float] = None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.total_mask_patches = total_mask_patches + + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches + + mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio + self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio)) + + def __repr__(self): + repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.mask_group_min_patches, + self.mask_group_max_patches, + self.total_mask_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _attempt in range(10): + target_area = random.uniform(self.mask_group_min_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + height = int(round(math.sqrt(target_area * aspect_ratio))) + width = int(round(math.sqrt(target_area / aspect_ratio))) + if width < self.width and height < self.height: + top = random.randint(0, self.height - height) + left = random.randint(0, self.width - width) + + num_masked = mask[top : top + height, left : left + width].sum() + # Overlap + if 0 < height * width - num_masked <= max_mask_patches: + zeros_pos = mask[top : top + height, left : left + width] == 0 + mask[top : top + height, left : left + width][zeros_pos] = 1 + delta += zeros_pos.sum() + + if delta > 0: + break + return delta + + def __call__(self): + mask = torch.zeros(self.get_shape(), dtype=torch.int) + mask_count = 0 + while mask_count < self.total_mask_patches: + max_mask_patches = self.total_mask_patches - 
mask_count + max_mask_patches = min(max_mask_patches, self.mask_group_max_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask + + +class FlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + Args: + return_image_mask (`bool`, *optional*, defaults to `False`): + Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`. + input_size_patches (`int`, *optional*, defaults to 14): + Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden + by the `input_size_patches` parameter in `preprocess`. + total_mask_patches (`int`, *optional*, defaults to 75): + Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in + `preprocess`. + mask_group_min_patches (`int`, *optional*, defaults to 16): + Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches` + parameter in `preprocess`. + mask_group_max_patches (`int`, *optional*): + Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches` + parameter in `preprocess`. + mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3): + Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter + in `preprocess`. + mask_group_max_aspect_ratio (`float`, *optional*): + Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter + in `preprocess`. + return_codebook_pixels (`bool`, *optional*, defaults to `False`): + Whether to return the codebook pixel values. + codebook_do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input for codebook to a certain. Can be overridden by the `codebook_do_resize` + parameter in `preprocess`. `codebook_size`. + codebook_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in + `preprocess`. + codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`): + Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample` + parameter in `preprocess`. + codebook_do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input for codebook at the center. If the input size is smaller than + `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be + overridden by the `codebook_do_center_crop` parameter in `preprocess`. + codebook_crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired output size for codebook input when applying center-cropping. Can be overridden by the + `codebook_crop_size` parameter in `preprocess`. + codebook_do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be + overridden by the `codebook_do_rescale` parameter in `preprocess`. + codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Defines the scale factor to use if rescaling the codebook image. Can be overridden by the + `codebook_rescale_factor` parameter in `preprocess`. 
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `True`): + Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the + `codebook_do_map_pixels` parameter in `preprocess`. + codebook_do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can + be overridden by the `codebook_do_normalize` parameter in `preprocess`. + codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`): + The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden + by the `codebook_image_mean` parameter in `preprocess`. + codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can + be overridden by the `codebook_image_std` parameter in `preprocess`. + """ + + # Mask related params + return_image_mask: Optional[bool] + input_size_patches: Optional[int] + total_mask_patches: Optional[int] + mask_group_min_patches: Optional[int] + mask_group_max_patches: Optional[int] + mask_group_min_aspect_ratio: Optional[float] + mask_group_max_aspect_ratio: Optional[float] + # Codebook related params + return_codebook_pixels: Optional[bool] + codebook_do_resize: Optional[bool] + codebook_size: Optional[bool] + codebook_resample: Optional[int] + codebook_do_center_crop: Optional[bool] + codebook_crop_size: Optional[int] + codebook_do_rescale: Optional[bool] + codebook_rescale_factor: Optional[Union[int, float]] + codebook_do_map_pixels: Optional[bool] + codebook_do_normalize: Optional[bool] + codebook_image_mean: Optional[Union[float, Iterable[float]]] + codebook_image_std: Optional[Union[float, Iterable[float]]] + + +@auto_docstring +class FlavaImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = FLAVA_IMAGE_MEAN + image_std = FLAVA_IMAGE_STD + size = {"height": 224, "width": 224} + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + + # Mask related params + return_image_mask = False + input_size_patches = 14 + total_mask_patches = 75 + mask_group_min_patches = 16 + mask_group_max_patches = None + mask_group_min_aspect_ratio = 0.3 + mask_group_max_aspect_ratio = None + # Codebook related params + return_codebook_pixels = False + codebook_do_resize = True + codebook_size = {"height": 112, "width": 112} + # LANCZOS resample does not support torch Tensor. 
Use BICUBIC as closest alternative + codebook_resample = PILImageResampling.BICUBIC + codebook_do_center_crop = True + codebook_crop_size = {"height": 112, "width": 112} + codebook_do_rescale = True + codebook_rescale_factor = 1 / 255 + codebook_do_map_pixels = True + codebook_do_normalize = True + codebook_image_mean = FLAVA_CODEBOOK_MEAN + codebook_image_std = FLAVA_CODEBOOK_STD + valid_kwargs = FlavaFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[FlavaFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)` + """ + image_processor_dict = image_processor_dict.copy() + if "codebook_size" in kwargs: + image_processor_dict["codebook_size"] = kwargs.pop("codebook_size") + if "codebook_crop_size" in kwargs: + image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size") + return super().from_dict(image_processor_dict, **kwargs) + + @lru_cache + def masking_generator( + self, + input_size_patches, + total_mask_patches, + mask_group_min_patches, + mask_group_max_patches, + mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio, + ) -> FlavaMaskingGenerator: + return FlavaMaskingGenerator( + input_size=input_size_patches, + total_mask_patches=total_mask_patches, + mask_group_min_patches=mask_group_min_patches, + mask_group_max_patches=mask_group_max_patches, + mask_group_min_aspect_ratio=mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, + ) + + def map_pixels(self, image: "torch.Tensor") -> "torch.Tensor": + return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + crop_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + codebook_size: Optional[SizeDict] = None, + codebook_crop_size: Optional[SizeDict] = None, + codebook_image_mean: Optional[Union[float, list[float]]] = None, + codebook_image_std: Optional[Union[float, list[float]]] = None, + codebook_resample: Optional[PILImageResampling] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. 
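For illustration (a hedged sketch using only helpers this module already imports): a plain size mapping is wrapped into a `SizeDict`, and a `PILImageResampling` value is translated to the torchvision interpolation mode used by the fast path.

```python
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import PILImageResampling, SizeDict, pil_torch_interpolation_mapping

# {"height": 112, "width": 112} -> SizeDict(height=112, width=112)
codebook_size = SizeDict(**get_size_dict({"height": 112, "width": 112}, param_name="codebook_size"))
# PILImageResampling.BICUBIC -> the corresponding torchvision InterpolationMode
interpolation = pil_torch_interpolation_mapping[PILImageResampling.BICUBIC]
print(codebook_size.height, codebook_size.width, interpolation)
```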
+ """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if crop_size is not None: + crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + if codebook_size is not None: + codebook_size = SizeDict(**get_size_dict(size=codebook_size, default_to_square=default_to_square)) + if codebook_crop_size is not None: + codebook_crop_size = SizeDict(**get_size_dict(codebook_crop_size, param_name="codebook_crop_size")) + if isinstance(codebook_image_mean, list): + codebook_image_mean = tuple(codebook_image_mean) + if isinstance(codebook_image_std, list): + codebook_image_std = tuple(codebook_image_std) + + kwargs["size"] = size + kwargs["crop_size"] = crop_size + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["codebook_size"] = codebook_size + kwargs["codebook_crop_size"] = codebook_crop_size + kwargs["codebook_image_mean"] = codebook_image_mean + kwargs["codebook_image_std"] = codebook_image_std + kwargs["data_format"] = data_format + kwargs["codebook_interpolation"] = ( + pil_torch_interpolation_mapping[codebook_resample] + if isinstance(codebook_resample, (PILImageResampling, int)) + else codebook_resample + ) + + # torch resize uses interpolation instead of resample + # Check if resample is an int before checking if it's an instance of PILImageResampling + # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. + # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. 
+ resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + return kwargs + + def _preprocess_image( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_map_pixels: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + ) -> "torch.Tensor": + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if do_map_pixels: + stacked_images = self.map_pixels(image=stacked_images) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return processed_images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + # Mask related params + return_image_mask: Optional[bool], + input_size_patches: Optional[int], + total_mask_patches: Optional[int], + mask_group_min_patches: Optional[int], + mask_group_max_patches: Optional[int], + mask_group_min_aspect_ratio: Optional[float], + mask_group_max_aspect_ratio: Optional[float], + # Codebook related params + return_codebook_pixels: Optional[bool], + codebook_do_resize: Optional[bool], + codebook_size: Optional[SizeDict], + codebook_interpolation: Optional["F.InterpolationMode"], + codebook_do_center_crop: Optional[bool], + codebook_crop_size: Optional[SizeDict], + codebook_do_rescale: Optional[bool], + codebook_rescale_factor: Optional[float], + codebook_do_map_pixels: Optional[bool], + codebook_do_normalize: Optional[bool], + codebook_image_mean: Optional[Union[float, list[float]]], + codebook_image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + processed_images = self._preprocess_image( + images=images, + 
do_resize=do_resize, + size=size, + interpolation=interpolation, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + do_map_pixels=False, + image_mean=image_mean, + image_std=image_std, + disable_grouping=disable_grouping, + return_tensors=return_tensors, + ) + data = { + "pixel_values": processed_images, + } + + if return_codebook_pixels: + codebook_processed_images = self._preprocess_image( + images=images, + do_resize=codebook_do_resize, + size=codebook_size, + interpolation=codebook_interpolation, + do_center_crop=codebook_do_center_crop, + crop_size=codebook_crop_size, + do_rescale=codebook_do_rescale, + rescale_factor=codebook_rescale_factor, + do_normalize=codebook_do_normalize, + do_map_pixels=codebook_do_map_pixels, + image_mean=codebook_image_mean, + image_std=codebook_image_std, + disable_grouping=disable_grouping, + return_tensors=return_tensors, + ) + data["codebook_pixel_values"] = codebook_processed_images + + if return_image_mask: + mask_generator = self.masking_generator( + input_size_patches=input_size_patches, + total_mask_patches=total_mask_patches, + mask_group_min_patches=mask_group_min_patches, + mask_group_max_patches=mask_group_max_patches, + mask_group_min_aspect_ratio=mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, + ) + masks = [mask_generator() for _ in range(len(images))] + masks = torch.stack(masks, dim=0) if return_tensors else masks + data["bool_masked_pos"] = masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["FlavaImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/modeling_flava.py b/venv/lib/python3.13/site-packages/transformers/models/flava/modeling_flava.py new file mode 100644 index 0000000000000000000000000000000000000000..266c3e96af5a263db2b49a31223f9652563ae213 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/modeling_flava.py @@ -0,0 +1,2030 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch FLAVA model.""" + +import collections +import math +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, filter_out_non_signature_kwargs, logging, torch_int +from .configuration_flava import ( + FlavaConfig, + FlavaImageCodebookConfig, + FlavaImageConfig, + FlavaMultimodalConfig, + FlavaTextConfig, +) + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook" + +LOGIT_SCALE_CLAMP_MIN = 0 +LOGIT_SCALE_CLAMP_MAX = 4.6052 + +FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig] + + +@dataclass +@auto_docstring( + custom_intro=""" + Output from FlavaModel containing embeddings and outputs from individual encoders. + + Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a + transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and + `text_projection` layers on `image_embeddings` and `text_embeddings` respectively. + """ +) +class FlavaModelOutput(ModelOutput): + r""" + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FlavaImageModel`]. + image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FlavaImageModel`]. + text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present): + The text embeddings which are basically the pooled output of [`FlavaTextModel`]. + text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): + The output of the [`FlavaTextModel`]. + multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`]. + multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The output of the [`FlavaMultimodalModel`]. 
+ """ + + image_embeddings: Optional[torch.FloatTensor] = None + image_output: Optional[BaseModelOutputWithPooling] = None + text_embeddings: Optional[torch.FloatTensor] = None + text_output: Optional[BaseModelOutputWithPooling] = None + multimodal_embeddings: Optional[torch.FloatTensor] = None + multimodal_output: Optional[BaseModelOutputWithPooling] = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class representing pretraining losses from FLAVA model + """ +) +class FlavaLosses(ModelOutput): + r""" + mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.): + Masked Image Modeling loss as used in BeIT calculated only for unimodal image data. + mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.): + Masked Language Modeling loss as used in BERT calculated only for unimodal text data. + itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.): + Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on + masked pairs in FLAVA. + global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.): + Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text + data. This is calculated on unmasked images and texts. + mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.): + Masked Multimodal Modeling loss's image component calculated on paired image-text data. + mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.): + Masked Multimodal Modeling loss's text component calculated on paired image-text data. + """ + + mim: Optional[torch.FloatTensor] = None + mlm: Optional[torch.FloatTensor] = None + itm: Optional[torch.FloatTensor] = None + global_contrastive: Optional[torch.FloatTensor] = None + mmm_image: Optional[torch.FloatTensor] = None + mmm_text: Optional[torch.FloatTensor] = None + + def all_none(self) -> bool: + all_none = True + for v in self.values(): + if v is not None: + all_none = False + break + return all_none + + +@dataclass +@auto_docstring( + custom_intro=""" + Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders. + + Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a + transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and + `text_projection` layers on `image_embeddings` and `text_embeddings` respectively. + """ +) +class FlavaForPreTrainingOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True): + Total loss calculated for this model. + loss_info (`FlavaLosses`): + Detailed info for FLAVA Pretraining losses. 
Check `FlavaLosses` class description for the information on the keys.
+        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
+        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+            The output of the [`FlavaImageModel`].
+        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
+        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+            The output of the [`FlavaTextModel`].
+        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
+        multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+            The output of the [`FlavaMultimodalModel`].
+        image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+            The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
+            to create masked images.
+        image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+            The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
+        text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
+            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
+        text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
+            The output of the [`FlavaTextModel`].
+        multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
+        multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+            The output of the [`FlavaMultimodalModel`].
+        mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
+            The logits for MIM unimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened output is
+            returned when `bool_masked_pos` has some of the patches masked.
+        mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
+            The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
+            the tokens masked.
+        itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+            The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
+        contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+            The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
+            `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
+            scores. This is calculated on unmasked images and texts.
+        contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+            The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
+            `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
+            texts.
+        mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
+            The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened
+            output is returned when `bool_masked_pos` has some of the patches masked.
+        mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
+            The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
+            some of the tokens masked.
+        """
+
+    loss: Optional[torch.FloatTensor] = None
+    loss_info: Optional[FlavaLosses] = None
+    image_embeddings: Optional[torch.FloatTensor] = None
+    image_output: Optional[BaseModelOutputWithPooling] = None
+    text_embeddings: Optional[torch.FloatTensor] = None
+    text_output: Optional[BaseModelOutputWithPooling] = None
+    multimodal_embeddings: Optional[torch.FloatTensor] = None
+    multimodal_output: Optional[BaseModelOutputWithPooling] = None
+    image_masked_embeddings: Optional[torch.FloatTensor] = None
+    image_masked_output: Optional[BaseModelOutputWithPooling] = None
+    text_masked_embeddings: Optional[torch.FloatTensor] = None
+    text_masked_output: Optional[BaseModelOutputWithPooling] = None
+    multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
+    multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
+    mim_logits: Optional[torch.FloatTensor] = None
+    mlm_logits: Optional[torch.FloatTensor] = None
+    itm_logits: Optional[torch.FloatTensor] = None
+    contrastive_logits_per_image: Optional[torch.FloatTensor] = None
+    contrastive_logits_per_text: Optional[torch.FloatTensor] = None
+    mmm_image_logits: Optional[torch.FloatTensor] = None
+    mmm_text_logits: Optional[torch.FloatTensor] = None
+
+    def to_tuple(self) -> tuple[Any]:
+        transformer_outputs = [
+            "text_output",
+            "image_output",
+            "multimodal_output",
+            "text_masked_output",
+            "image_masked_output",
+            "multimodal_masked_output",
+        ]
+        return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
+class FlavaImageEmbeddings(nn.Module):
+    """
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
+        super().__init__()
+
+        use_mask_token = use_mask_token or config.mask_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+        self.patch_embeddings = PatchEmbeddings(
+            image_size=config.image_size,
+            patch_size=config.patch_size,
+            num_channels=config.num_channels,
+            embed_dim=config.hidden_size,
+        )
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
+        self.config = config
+
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows the model to interpolate the pre-trained position encodings so that it can be used on
+        higher-resolution images. It is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        batch_size, seq_len, _ = embeddings.size()
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # flatten a B x H x W mask to B x HW
+            if bool_masked_pos.dim() == 3:
+                bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
+class PatchEmbeddings(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        image_size: int = 224,
+        patch_size: Union[int, tuple[int, int]] = 16,
+        num_channels: int = 3,
+        embed_dim: int = 768,
+    ):
+        super().__init__()
+        if not isinstance(image_size, collections.abc.Iterable):
+            image_size = (image_size, image_size)
+        if not isinstance(patch_size, collections.abc.Iterable):
+            patch_size = (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        if not interpolate_pos_encoding:
+            if height != self.image_size[0] or width != self.image_size[1]:
+                raise ValueError(
+                    f"Input image size ({height}*{width}) doesn't match model"
+                    f" ({self.image_size[0]}*{self.image_size[1]})."
+                )
+        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return x
+
+
+class FlavaTextEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        input_shape = input_ids.size()
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        # When token_type_ids is not passed, default to the all-zeros buffer registered in the constructor (the
+        # usual case when token type ids are auto-generated). The registered buffer also helps users who trace the
+        # model without passing token_type_ids; solves issue #5664.
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class FlavaSelfAttention(nn.Module):
+    def __init__(self, config: FlavaPossibleConfigs) -> None:
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
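+        # Illustrative numbers (assumed defaults, not enforced here): with hidden_size=768 and
+        # num_attention_heads=12, attention_probs has shape (batch, 12, seq_len, seq_len); the
+        # dropout below zeroes individual attention weights rather than hidden units, so a token
+        # can temporarily lose some of the positions it attends to during training.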
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class FlavaSelfOutput(nn.Module): + """ + The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other + models), due to the layernorm applied before each block. + """ + + def __init__(self, config: FlavaPossibleConfigs) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class FlavaAttention(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: + super().__init__() + self.attention = FlavaSelfAttention(config) + self.output = FlavaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + self_outputs = self.attention( + hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class FlavaIntermediate(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + # Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class FlavaOutput(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: + super().__init__() 
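+        # Down-projection half of the transformer feed-forward block: intermediate_size (typically
+        # 4 * hidden_size) back to hidden_size; the residual add happens in forward() below.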
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Copied from transformers.models.vit.modeling_vit.ViTOutput.forward + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class FlavaLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: FlavaPossibleConfigs) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = FlavaAttention(config) + self.intermediate = FlavaIntermediate(config) + self.output = FlavaOutput(config) + + # TODO: Check fp32 layer norm possibility + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class FlavaEncoder(nn.Module): + def __init__(self, config: FlavaConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, 
attentions=all_self_attentions + ) + + +class FlavaPooler(nn.Module): + def __init__(self, config: FlavaPossibleConfigs): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class FlavaPreTrainedModel(PreTrainedModel): + config: FlavaConfig + base_model_prefix = "flava" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, FlavaMaskedPredictionHead): + module.bias.data.zero_() + elif isinstance(module, FlavaImageEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + elif isinstance(module, FlavaMultimodalModel): + if module.use_cls_token: + module.cls_token.data.zero_() + elif isinstance(module, FlavaModel): + module.logit_scale.data.fill_(self.config.logit_scale_init_value) + + +@auto_docstring +class FlavaImageModel(FlavaPreTrainedModel): + config: FlavaImageConfig + # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints. + base_model_prefix = "flava.image_model" + main_input_name = "pixel_values" + + def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + + self.config = config + + self.embeddings = FlavaImageEmbeddings(config) + self.encoder = FlavaEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = FlavaPooler(config) if add_pooling_layer else None + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embeddings + + def set_input_embeddings(self, value: nn.Module): + self.embeddings.patch_embeddings = value + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring
+class FlavaTextModel(FlavaPreTrainedModel):
+    config: FlavaTextConfig
+    # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
+    base_model_prefix = "flava.text_model"
+
+    def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `True`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = FlavaTextEmbeddings(config)
+        self.encoder = FlavaEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pooler = FlavaPooler(config) if add_pooling_layer else None
+
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+        """
+        Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=input_ids.device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, input_ids.device + ) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class FlavaMultimodalModel(FlavaPreTrainedModel): + config: FlavaMultimodalConfig + # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints. 
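+    # Roughly: when loading such a checkpoint, keys like "flava.multimodal_model.encoder.layer.0..."
+    # are remapped onto "encoder.layer.0..." of this standalone module via the prefix below.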
+ base_model_prefix = "flava.multimodal_model" + main_input_name = "hidden_states" + + def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + self.use_cls_token = self.config.use_cls_token + if self.use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + self.encoder = FlavaEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = FlavaPooler(config) if add_pooling_layer else None + + self.post_init() + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`): + The concatenated hidden states of unimodal encoders. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length, _ = hidden_states.size() + + if self.use_cls_token: + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + hidden_states = torch.cat((cls_tokens, hidden_states), dim=1) + seq_length += 1 + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states.device + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class FlavaModel(FlavaPreTrainedModel): + config: FlavaConfig + + def __init__(self, config: FlavaConfig): + 
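# FLAVA wires three towers together: unimodal image and text encoders plus a multimodal
+        # encoder that runs on their projected hidden states (see forward() below).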
+        super().__init__(config)
+
+        if not isinstance(config.text_config, FlavaTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type FlavaTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.image_config, FlavaImageConfig):
+            raise TypeError(
+                "config.image_config is expected to be of type FlavaImageConfig but is of type"
+                f" {type(config.image_config)}."
+            )
+
+        if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
+            raise TypeError(
+                "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
+                + f"is of type {type(config.multimodal_config)}."
+            )
+
+        text_config = config.text_config
+        image_config = config.image_config
+        multimodal_config = config.multimodal_config
+
+        self.projection_dim = config.projection_dim
+        self.text_hidden_size = text_config.hidden_size
+        self.image_hidden_size = image_config.hidden_size
+        self.mm_hidden_size = multimodal_config.hidden_size
+
+        self.text_model = FlavaTextModel(text_config)
+        self.image_model = FlavaImageModel(image_config)
+        self.multimodal_model = FlavaMultimodalModel(multimodal_config)
+
+        self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
+        self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
+        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+
+        self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
+        self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_text_features(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+            [What are token type IDs?](../glossary#token-type-ids)
+
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`FlavaTextModel`].
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoProcessor, FlavaModel
+
+        >>> model = FlavaModel.from_pretrained("facebook/flava-full")
+        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
+
+        >>> inputs = processor(
+        ...     text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
+        ... )
+        >>> with torch.inference_mode():
+        ...     text_features = model.get_text_features(**inputs)
+        ```
+        """
+        text_outputs: BaseModelOutputWithPooling = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+        )
+        pooled_output = text_outputs.last_hidden_state
+        text_features = self.text_projection(pooled_output)
+
+        return text_features
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+            applying the projection layer to the pooled output of [`FlavaImageModel`].
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoProcessor, FlavaModel
+        >>> from transformers.image_utils import load_image
+
+        >>> model = FlavaModel.from_pretrained("facebook/flava-full")
+        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = load_image(url)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> with torch.inference_mode():
+        ...     image_features = model.get_image_features(**inputs)
+        ```
+        """
+        image_outputs: BaseModelOutputWithPooling = self.image_model(
+            pixel_values=pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+        pooled_output = image_outputs.last_hidden_state
+        image_features = self.image_projection(pooled_output)
+
+        return image_features
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        image_attention_mask: Optional[torch.Tensor] = None,
+        skip_multimodal_encoder: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: bool = True,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, FlavaModelOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        token_type_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+            [What are token type IDs?](../glossary#token-type-ids)
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ image_attention_mask (`torch.Tensor` of shape `(batch_size, image_num_patches)`, *optional*): + Mask to avoid performing attention on padding pixel values for image inputs. Mask values selected in `[0, 1]`: + - 1 for pixel values that are real (i.e., **not masked**), + - 0 for pixel values that are padding (i.e., **masked**). + skip_multimodal_encoder (*bool*, *optional*): + Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, FlavaModel + + >>> model = FlavaModel.from_pretrained("facebook/flava-full") + >>> processor = AutoProcessor.from_pretrained("facebook/flava-full") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True) + + >>> outputs = model(**inputs) + + >>> image_embeddings = outputs.image_embeddings + >>> text_embeddings = outputs.text_embeddings + >>> multimodal_embeddings = outputs.multimodal_embeddings + + >>> outputs.image_embeddings.shape + torch.Size([1, 197, 768]) + + >>> text_embeddings.shape + torch.Size([1, 7, 768]) + + >>> multimodal_embeddings.shape + torch.Size([1, 205, 768]) + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.return_dict + if not output_hidden_states: + raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`") + image_embeddings = None + image_states = None + image_mm_projection = None + image_output = None + if pixel_values is not None: + image_output = self.image_model( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings, image_states = image_output[0], image_output[2] + # Note that these states don't use final layernorm in the transformer model + image_mm_projection = self.image_to_mm_projection(image_states[-1]) + + text_embeddings = None + text_states = None + text_mm_projection = None + text_output = None + if input_ids is not None: + text_output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeddings, text_states = text_output[0], text_output[2] + # Note that these states don't use final layernorm in the transformer model + text_mm_projection = self.text_to_mm_projection(text_states[-1]) + + multimodal_embeddings = None + multimodal_output = None + if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder: + if attention_mask is not None: + batch_size, seq_len, _ = image_mm_projection.shape + if self.multimodal_model.use_cls_token: + seq_len += 1 + attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device) + attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1) + else: + attention_multimodal = None + multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1) + multimodal_output = self.multimodal_model( + multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict + ) + 
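# Index 0 is the last hidden state for both return types: a plain tuple when
+            # return_dict=False and a ModelOutput otherwise (integer indexing is supported).
+            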
multimodal_embeddings = multimodal_output[0] + + if not return_dict: + return ( + image_embeddings, + image_output, + text_embeddings, + text_output, + multimodal_embeddings, + multimodal_output, + ) + + return FlavaModelOutput( + image_embeddings=image_embeddings, + image_output=image_output, + text_embeddings=text_embeddings, + text_output=text_output, + multimodal_embeddings=multimodal_embeddings, + multimodal_output=multimodal_output, + ) + + +class FlavaImageCodebookResPath(nn.Module): + def __init__(self, in_size: int, out_size: int, **kwargs): + super().__init__() + hid_size = out_size // 4 + + path = OrderedDict() + path["relu_1"] = nn.ReLU() + path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1) + path["relu_2"] = nn.ReLU() + path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1) + path["relu_3"] = nn.ReLU() + path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1) + path["relu_4"] = nn.ReLU() + path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0) + + self.path = nn.Sequential(path) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.path(x) + + +class FlavaImageCodebookBlock(nn.Module): + def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs): + super().__init__() + + self.post_gain = 1 / (num_layers**2) + + if in_size != out_size: + self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0) + else: + self.id_path = nn.Identity() + + self.res_path = FlavaImageCodebookResPath(in_size, out_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +class FlavaImageCodebookLayerGroup(nn.Module): + def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True): + super().__init__() + blocks = OrderedDict() + for i in range(num_blocks): + if i == 0: + blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers) + else: + blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers) + + if use_pool: + blocks["pool"] = nn.MaxPool2d(kernel_size=2) + + self.group = nn.Sequential(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.group(x) + + +# Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42 +@auto_docstring( + custom_intro=""" + The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used + to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use + `get_codebook_indices` to get image tokens for an image. 
+ """ +) +class FlavaImageCodebook(FlavaPreTrainedModel): + base_model_prefix = "" + config: FlavaImageCodebookConfig + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + + def __init__( + self, + config: FlavaImageCodebookConfig, + **kwargs: Any, + ): + super().__init__(config) + + self.config = config + self.num_groups = config.num_groups + self.input_channels = config.input_channels + self.num_blocks_per_group = config.num_blocks_per_group + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + + num_layers = self.num_groups * self.num_blocks_per_group + + output_blocks = OrderedDict() + output_blocks["relu"] = nn.ReLU() + output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0) + + blocks = OrderedDict() + blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3) + blocks["group_1"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size + ) + blocks["group_2"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size + ) + blocks["group_3"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size + ) + blocks["group_4"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False + ) + blocks["output"] = nn.Sequential(output_blocks) + + self.blocks = nn.Sequential(blocks) + + self.post_init() + + if self.config.freeze: + for param in self.parameters(): + param.requires_grad = False + + def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor: + f""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing + `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoImageProcessor, FlavaImageCodebook + + >>> model = FlavaImageCodebook.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}") + >>> image_processor = AutoImageProcessor.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt") + >>> inputs = dict(pixel_values=inputs.codebook_pixel_values) + + >>> outputs = model.get_codebook_indices(**inputs) + ``` + """ + z_logits = self.blocks(pixel_values) + return torch.argmax(z_logits, axis=1) + + def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor: + z_logits = self.blocks(pixel_values) + return nn.Softmax(dim=1)(z_logits) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + f""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing + `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details. 
+
+        Examples:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoImageProcessor, FlavaImageCodebook
+
+        >>> model = FlavaImageCodebook.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
+        >>> image_processor = AutoImageProcessor.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
+        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
+
+        >>> outputs = model(**inputs)
+        >>> print(outputs.shape)
+        torch.Size([1, 8192, 14, 14])
+        ```
+        """
+        if len(pixel_values.shape) != 4:
+            raise ValueError(f"input shape {pixel_values.shape} is not 4d")
+        if pixel_values.shape[1] != self.input_channels:
+            raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
+        return self.blocks(pixel_values)
+
+
+class FlavaPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class FlavaMaskedPredictionHead(nn.Module):
+    def __init__(self, config, weight=None):
+        super().__init__()
+        self.config = config
+        self.transform = FlavaPredictionHeadTransform(config)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+        if weight is not None:
+            self.decoder.weight = weight
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
+    def forward(self, x):
+        x = self.transform(x)
+        x = self.decoder(x)
+        return x
+
+
+class FlavaITMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pooler = FlavaPooler(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, x):
+        x = self.pooler(x)
+        x = self.seq_relationship(x)
+        return x
+
+
+class FlavaGlobalContrastiveHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.global_backprop_contrastive = config.global_backprop_contrastive
+
+    def forward(self, image_embeddings, text_embeddings, logit_scale):
+        temperature = torch.exp(logit_scale)
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+            labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
+            image_embeddings_all = [image_embeddings]
+            text_embeddings_all = [text_embeddings]
+        else:
+            local_batch_size = image_embeddings.size(0)
+            world_size = torch.distributed.get_world_size()
+
+            if self.global_backprop_contrastive:
+                # `torch.distributed.nn.functional.all_gather` backpropagates through all active workers,
+                # whereas `torch.distributed.all_gather` only backpropagates on the current worker.
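+                # Worked example (hypothetical numbers): with world_size=4 and a local batch of 8,
+                # each rank gathers 4 * 8 = 32 embeddings, and its positive targets are the slice
+                # labels = 8 * rank + arange(8) computed below.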
+ image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings) + text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings) + else: + image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)] + text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)] + torch.distributed.all_gather(image_embeddings_all, image_embeddings) + torch.distributed.all_gather(text_embeddings_all, text_embeddings) + + labels = local_batch_size * torch.distributed.get_rank() + torch.arange( + local_batch_size, device=image_embeddings.device + ) + + image_embeddings_all = torch.cat(image_embeddings_all) + text_embeddings_all = torch.cat(text_embeddings_all) + + logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature + logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature + + return logits_per_image, logits_per_text, labels + + +@auto_docstring( + custom_intro=""" + The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs. + """ +) +class FlavaForPreTraining(FlavaPreTrainedModel): + # Those are linked to xxx.bias + _tied_weights_keys = [ + "mmm_text_head.decoder.bias", + "mmm_image_head.decoder.bias", + "mlm_head.decoder.bias", + "mim_head.decoder.bias", + ] + + def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None): + r""" + image_codebook ([`nn.Module`]): + If passed, the image codebook will be set to this. Otherwise, it will be initialized using the + image_codebook_config defined in the config first as the first parameter. + """ + super().__init__(config) + self.flava = FlavaModel(config) + + self.image_codebook = image_codebook + if self.image_codebook is None and config.init_codebook: + self.image_codebook = FlavaImageCodebook(config.image_codebook_config) + + # Levarage text and image encoder configs to create the masked + # head since it has the right vocab + self.mim_head = FlavaMaskedPredictionHead(config.image_config) + self.mlm_head = FlavaMaskedPredictionHead(config.text_config) + self.itm_head = FlavaITMHead(config) + self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config) + self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config) + self.global_contrastive_head = FlavaGlobalContrastiveHead(config) + + self.image_vocab_size = config.image_config.vocab_size + self.text_vocab_size = config.text_config.vocab_size + self.mlm_weight = config.mlm_weight + self.mim_weight = config.mim_weight + self.global_contrastive_weight = config.global_contrastive_weight + self.ce_ignore_index = config.ce_ignore_index + self.itm_weight = config.itm_weight + self.mmm_image_weight = config.mmm_image_weight + self.mmm_text_weight = config.mmm_text_weight + self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder + + self.post_init() + + def _resize_to_2d(self, x: torch.Tensor): + if x.dim() > 2: + x = x.view(x.size(0), -1) + return x + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_ids_masked: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + codebook_pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_attention_mask: 
+        skip_unmasked_multimodal_encoder: Optional[bool] = None,
+        mlm_labels: Optional[torch.Tensor] = None,
+        mim_labels: Optional[torch.Tensor] = None,
+        itm_labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: bool = True,
+        return_dict: Optional[bool] = None,
+        return_loss: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], FlavaForPreTrainingOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        input_ids_masked (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
+            Indices of input sequence tokens in the vocabulary. These are the masked version of the original
+            sequence, to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
+            [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
+        codebook_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_image_patches, patch_size, patch_size, 3)`, *optional*):
+            Pixel values for image patches that are used to compute the image codebook labels for masked image modeling.
+        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+            [What are token type IDs?](../glossary#token-type-ids)
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        image_attention_mask (`torch.FloatTensor` of shape `(batch_size, image_num_patches)`, *optional*):
+            Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
+            in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        skip_unmasked_multimodal_encoder (`bool`, *optional*):
+            Skip any calculations for the multimodal encoder on unmasked inputs. FLAVA pretraining doesn't need
+            unmasked multimodal embeddings or outputs as of now.
+        mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
+            Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
+            Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
+            indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
+            ..., text_config.vocab_size - 1]`.
+        mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
+            Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
+            image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
+            computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
+            generated automatically using the image codebook assigned to the model. By default, it uses
+            [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate `mim_labels`.
+        itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
+            The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
+        return_loss (`bool`, *optional*, defaults to `None`):
+            Whether to return the calculated loss or not.
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import FlavaForPreTraining, AutoProcessor

+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)

+        >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
+        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

+        >>> text = ["a photo of a cat"]

+        >>> inputs = processor(
+        ...     images=[image],
+        ...     text=text,
+        ...     return_masks=True,
+        ...     return_codebook_pixels=True,
+        ...     padding=True,
+        ...     max_length=77,
+        ...     return_tensors="pt",
+        ... )


+        >>> output = model(**inputs)
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_loss = return_loss if return_loss is not None else self.config.return_loss
+
+        skip_unmasked_multimodal_encoder = (
+            skip_unmasked_multimodal_encoder
+            if skip_unmasked_multimodal_encoder is not None
+            else self.skip_unmasked_multimodal_encoder
+        )
+
+        if input_ids_masked is None and input_ids is not None:
+            logger.warning(
+                "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctly. Setting it to"
+                " `input_ids` so that the model can work. Please pass it if this is unintentional. This is usually OKAY if"
+                " you are doing inference on unmasked text..."
+            )
+            input_ids_masked = input_ids
+
+        flava_output = self.flava(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            image_attention_mask=image_attention_mask,
+            # Don't need unmasked multimodal embedding for anything so skip it
+            # NOTE: ITM uses masked version
+            skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            # Pass true to have deterministic outputs
+            return_dict=True,
+        )
+
+        flava_masked_output = self.flava(
+            input_ids=input_ids_masked,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            image_attention_mask=image_attention_mask,
+            bool_masked_pos=bool_masked_pos,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        pos_mask = None
+
+        image_embeddings = flava_output.image_embeddings
+        text_embeddings = flava_output.text_embeddings
+        image_masked_embeddings = flava_masked_output.image_embeddings
+        text_masked_embeddings = flava_masked_output.text_embeddings
+        multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
+
+        total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
+        mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
+        itm_logits = logits_per_image = logits_per_text = None
+
+        # Calculate mim_labels if necessary from the image_codebook
+        if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
+            if mim_labels is None and return_loss:
+                if self.image_codebook is None:
+                    raise RuntimeError(
+                        "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
+                        "have been passed. Reinstantiate the model with `init_codebook` set to True or "
+                        "pass in your custom `mim_labels`"
+                    )
+                if codebook_pixel_values is None:
+                    raise ValueError(
+                        "`codebook_pixel_values` are required to generate `mim_labels` if loss is expected. 
" + "Call `AutoProcessor` with `return_codebook_pixels` set to True" + ) + mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values) + # Unimodal MIM Loss + # If multimodal embeddings are present, we will calculate MMM loss + if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None: + sequence_for_image = image_masked_embeddings + + if mim_labels is not None: + mim_labels = self._resize_to_2d(mim_labels) + bool_masked_pos = self._resize_to_2d(bool_masked_pos) + mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index + + sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :] + masked_tokens = mim_labels.ne(self.ce_ignore_index) + mim_labels_filtered = mim_labels[masked_tokens] + sequence_for_image = sequence_for_image[masked_tokens, :] + mim_logits = self.mim_head(sequence_for_image) + if return_loss: + mim_loss = nn.functional.cross_entropy( + mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1) + ) + mim_loss *= self.mim_weight + else: + mim_logits = self.mim_head(sequence_for_image) + + # Unimodal MLM Loss + if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None: + sequence_for_text = text_masked_embeddings + if mlm_labels is not None: + mlm_labels = self._resize_to_2d(mlm_labels) + sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :] + masked_tokens = mlm_labels.ne(self.ce_ignore_index) + mlm_labels_filtered = mlm_labels[masked_tokens] + sequence_for_text = sequence_for_text[masked_tokens, :] + mlm_logits = self.mlm_head(sequence_for_text) + if return_loss: + mlm_loss = nn.functional.cross_entropy( + mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1) + ) + mlm_loss *= self.mlm_weight + else: + mlm_logits = self.mlm_head(sequence_for_text) + + # ITM Loss + if self.itm_weight > 0 and multimodal_masked_embeddings is not None: + itm_logits = self.itm_head(multimodal_masked_embeddings) + + if itm_labels is not None: + pos_pairs = itm_labels.ne(0) + pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True])) + if return_loss: + itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels) + itm_loss *= self.itm_weight + + if multimodal_masked_embeddings is not None: + multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask] + + if mlm_labels is not None: + mlm_labels = mlm_labels[pos_mask] + + if mim_labels is not None: + mim_labels = mim_labels[pos_mask] + bool_masked_pos = bool_masked_pos[pos_mask] + + # MMM Image Loss + if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0: + sequence_for_image = multimodal_masked_embeddings + end_index = image_masked_embeddings.size(1) - 1 + sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :] + + if mim_labels is not None: + mim_labels = self._resize_to_2d(mim_labels) + bool_masked_pos = self._resize_to_2d(bool_masked_pos) + mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index + + masked_tokens = mim_labels.ne(self.ce_ignore_index) + mim_labels_filtered = mim_labels[masked_tokens] + sequence_for_image = sequence_for_image[masked_tokens, :] + mmm_image_logits = self.mmm_image_head(sequence_for_image) + if return_loss: + mmm_image_loss = nn.functional.cross_entropy( + mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1) + ) + mmm_image_loss *= self.mmm_image_weight + else: + mmm_image_logits = self.mmm_image_head(sequence_for_image) + + # MMM Text Loss + if 
multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
+            sequence_for_text = multimodal_masked_embeddings
+            sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
+
+            if mlm_labels is not None:
+                mlm_labels = self._resize_to_2d(mlm_labels)
+                masked_tokens = mlm_labels.ne(self.ce_ignore_index)
+                mlm_labels_filtered = mlm_labels[masked_tokens]
+                sequence_for_text = sequence_for_text[masked_tokens, :]
+                mmm_text_logits = self.mmm_text_head(sequence_for_text)
+                if return_loss:
+                    mmm_text_loss = nn.functional.cross_entropy(
+                        mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
+                    )
+                    mmm_text_loss *= self.mmm_text_weight
+            else:
+                mmm_text_logits = self.mmm_text_head(sequence_for_text)
+
+        # Global Contrastive Loss
+        if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
+            text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
+            text_embedding = nn.functional.normalize(text_embedding, dim=-1)
+
+            image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
+            image_embedding = nn.functional.normalize(image_embedding, dim=-1)
+
+            self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
+
+            logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
+                image_embedding, text_embedding, self.flava.logit_scale
+            )
+
+            # Apply ITM negative mask if any
+            if pos_mask is not None:
+                logits_per_image = logits_per_image[pos_mask]
+                logits_per_text = logits_per_text[pos_mask]
+                gc_labels = gc_labels[pos_mask]
+
+            if return_loss:
+                gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
+                gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
+                gc_loss = (gc_loss_image + gc_loss_text) / 2
+                gc_loss *= self.global_contrastive_weight
+
+        flava_losses = FlavaLosses(
+            mim=mim_loss,
+            mlm=mlm_loss,
+            itm=itm_loss,
+            global_contrastive=gc_loss,
+            mmm_image=mmm_image_loss,
+            mmm_text=mmm_text_loss,
+        )
+
+        if return_loss and not flava_losses.all_none():
+            total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())
+
+        if not return_dict:
+            output = (
+                image_embeddings,
+                flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
+                text_embeddings,
+                flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
+                flava_output.multimodal_embeddings,
+                flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
+                image_masked_embeddings,
+                flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
+                text_masked_embeddings,
+                flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
+                multimodal_masked_embeddings,
+                flava_masked_output.multimodal_output.to_tuple()
+                if flava_masked_output.multimodal_output is not None
+                else None,
+                mim_logits,
+                mlm_logits,
+                itm_logits,
+                logits_per_image,
+                logits_per_text,
+                mmm_image_logits,
+                mmm_text_logits,
+            )
+            if return_loss and not flava_losses.all_none():
+                output = (
+                    total_loss,
+                    flava_losses,
+                ) + output
+
+            # Filter None as transformer by default won't handle it
+            return tuple(x for x in output if x is not None)
+
+        return FlavaForPreTrainingOutput(
+            loss=total_loss,
+            loss_info=flava_losses,
+            image_embeddings=image_embeddings,
+            image_output=flava_output.image_output,
+            text_embeddings=text_embeddings,
+            text_output=flava_output.text_output,
+
multimodal_embeddings=flava_output.multimodal_embeddings, + multimodal_output=flava_output.multimodal_output, + image_masked_embeddings=image_masked_embeddings, + image_masked_output=flava_masked_output.image_output, + text_masked_embeddings=text_masked_embeddings, + text_masked_output=flava_masked_output.text_output, + multimodal_masked_embeddings=multimodal_masked_embeddings, + multimodal_masked_output=flava_masked_output.multimodal_output, + mim_logits=mim_logits, + mlm_logits=mlm_logits, + itm_logits=itm_logits, + contrastive_logits_per_image=logits_per_image, + contrastive_logits_per_text=logits_per_text, + mmm_image_logits=mmm_image_logits, + mmm_text_logits=mmm_text_logits, + ) + + +__all__ = [ + "FlavaForPreTraining", + "FlavaImageCodebook", + "FlavaImageModel", + "FlavaModel", + "FlavaMultimodalModel", + "FlavaPreTrainedModel", + "FlavaTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/flava/processing_flava.py b/venv/lib/python3.13/site-packages/transformers/models/flava/processing_flava.py new file mode 100644 index 0000000000000000000000000000000000000000..ceebdb6efa49134c3d078a8dd03b08da7c59fb9d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/flava/processing_flava.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for FLAVA +""" + +import warnings +from collections.abc import Iterable +from typing import Optional, Union + +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin + + +class FlavaImagesKwargs(ImagesKwargs): + # Mask related params + return_image_mask: Optional[bool] + input_size_patches: Optional[int] + total_mask_patches: Optional[int] + mask_group_min_patches: Optional[int] + mask_group_max_patches: Optional[int] + mask_group_min_aspect_ratio: Optional[float] + mask_group_max_aspect_ratio: Optional[float] + # Codebook related params + return_codebook_pixels: Optional[bool] + codebook_do_resize: Optional[bool] + codebook_size: Optional[bool] + codebook_resample: Optional[int] + codebook_do_center_crop: Optional[bool] + codebook_crop_size: Optional[int] + codebook_do_rescale: Optional[bool] + codebook_rescale_factor: Optional[Union[int, float]] + codebook_do_map_pixels: Optional[bool] + codebook_do_normalize: Optional[bool] + codebook_image_mean: Optional[Union[float, Iterable[float]]] + codebook_image_std: Optional[Union[float, Iterable[float]]] + + +class FlavaProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: FlavaImagesKwargs + _defaults = {} + + +class FlavaProcessor(ProcessorMixin): + r""" + Constructs a FLAVA processor which wraps a FLAVA image processor and a FLAVA tokenizer into a single processor. + + [`FlavaProcessor`] offers all the functionalities of [`FlavaImageProcessor`] and [`BertTokenizerFast`]. See the + [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information. 
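+
+    Example (a minimal usage sketch; the checkpoint name and image URL are reused from the
+    pretraining example above and are assumptions, not requirements of this class):
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import FlavaProcessor

+    >>> processor = FlavaProcessor.from_pretrained("facebook/flava-full")

+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)

+    >>> # Tokenizes the text and prepares pixel values (plus optional masks/codebook pixels) in one call
+    >>> inputs = processor(images=[image], text=["a photo of a cat"], padding=True, return_tensors="pt")
+    ```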
+ + Args: + image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input. + tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "FlavaImageProcessor" + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + valid_processor_kwargs = FlavaProcessorKwargs + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["FlavaProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/florence2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/florence2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..686295d1fbd952e124bab885980ac682261a07fb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/florence2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_florence2 import * + from .modeling_florence2 import * + from .processing_florence2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/florence2/configuration_florence2.py b/venv/lib/python3.13/site-packages/transformers/models/florence2/configuration_florence2.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbd4b3a03e9f2c591e77689e1606c60214ab430 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/florence2/configuration_florence2.py @@ -0,0 +1,215 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/florence2/modular_florence2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the
+# modular_florence2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Florence2VisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
+    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        in_channels (`int`, *optional*, defaults to 3):
+            Number of input image channels.
+        depths (`Tuple[int]`, *optional*, defaults to `(1, 1, 9, 1)`):
+            The number of blocks in each stage of the model.
+        patch_size (`Tuple[int]`, *optional*, defaults to `(7, 3, 3, 3)`):
+            The patch size of the image.
+        patch_stride (`Tuple[int]`, *optional*, defaults to `(4, 2, 2, 2)`):
+            The patch stride of the image.
+        patch_padding (`Tuple[int]`, *optional*, defaults to `(3, 1, 1, 1)`):
+            The patch padding of the image.
+        patch_prenorm (`Tuple[bool]`, *optional*, defaults to `(False, True, True, True)`):
+            Whether to apply layer normalization before the patch embedding layer.
+        embed_dim (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 1024)`):
+            The dimension of the embedding layer.
+        num_heads (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`):
+            The number of attention heads.
+        num_groups (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`):
+            The number of groups.
+        window_size (`int`, *optional*, defaults to 12):
+            The window size of the model.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            The dropout rate of the drop path layer.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            If True, add a learnable bias to query, key, value.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        projection_dim (`int`, *optional*, defaults to 1024):
+            The dimension of the projection layer.
+        max_temporal_embeddings (`int`, *optional*, defaults to 100):
+            The maximum sequence length supported by the 1D cosine temporal embedding.
+        max_position_embeddings (`int`, *optional*, defaults to 50):
+            The maximum number of positions per spatial axis for the learned 2D image position embedding.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+    Example:
+
+    ```python
+    >>> from transformers import Florence2VisionConfig, Florence2VisionModel

+    >>> # Initializing a Florence2 Vision style configuration
+    >>> configuration = Florence2VisionConfig()

+    >>> # Initializing a model (with random weights)
+    >>> model = Florence2VisionModel(configuration)

+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "florence_vision"
+
+    def __init__(
+        self,
+        in_channels=3,
+        depths=(1, 1, 9, 1),
+        patch_size=(7, 3, 3, 3),
+        patch_stride=(4, 2, 2, 2),
+        patch_padding=(3, 1, 1, 1),
+        patch_prenorm=(False, True, True, True),
+        embed_dim=(128, 256, 512, 1024),
+        num_heads=(4, 8, 16, 32),
+        num_groups=(4, 8, 16, 32),
+        window_size=12,
+        drop_path_rate=0.1,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        activation_function="gelu",
+        projection_dim=1024,
+        max_temporal_embeddings=100,
+        max_position_embeddings=50,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        self.in_channels = in_channels
+        self.depths = list(depths)
+        self.patch_size = list(patch_size)
+        self.patch_stride = list(patch_stride)
+        self.patch_padding = list(patch_padding)
+        self.patch_prenorm = list(patch_prenorm)
+        self.embed_dim = list(embed_dim)
+        self.num_heads = list(num_heads)
+        self.num_groups = list(num_groups)
+        self.window_size = window_size
+        self.drop_path_rate = drop_path_rate
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.projection_dim = projection_dim
+        self.max_temporal_embeddings = max_temporal_embeddings
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.activation_function = activation_function
+
+        super().__init__(**kwargs)
+
+
+class Florence2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate a
+    Florence-2 model according to the specified arguments, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Florence-2
+    [microsoft/Florence-2-base](https://huggingface.co/microsoft/Florence-2-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`AutoConfig`].
+        vision_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`Florence2VisionConfig`].
+        image_token_id (`int`, *optional*, defaults to 51289):
+            The image token index to encode the image prompt.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+
+    Example:
+
+    ```python
+    >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig

+    >>> # Initializing a clip-like vision config
+    >>> vision_config = CLIPVisionConfig()

+    >>> # Initializing a Bart config
+    >>> text_config = BartConfig()

+    >>> # Initializing a Florence-2 configuration
+    >>> configuration = Florence2Config(vision_config=vision_config, text_config=text_config)

+    >>> # Initializing a model from the florence-2 configuration
+    >>> model = Florence2ForConditionalGeneration(configuration)

+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "florence2"
+    sub_configs = {
+        "text_config": AutoConfig,
+        "vision_config": Florence2VisionConfig,
+    }
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=51289,
+        is_encoder_decoder=True,
+        **kwargs,
+    ):
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config.get("model_type", "bart")
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["bart"]()
+
+        if isinstance(vision_config, dict):
+            vision_config = Florence2VisionConfig(**vision_config)
+        elif vision_config is None:
+            logger.info("vision_config is None. Initializing the Florence2VisionConfig with default values.")
+            vision_config = Florence2VisionConfig()
+
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.image_token_id = image_token_id
+
+        super().__init__(
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+
+__all__ = ["Florence2Config", "Florence2VisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/florence2/modeling_florence2.py b/venv/lib/python3.13/site-packages/transformers/models/florence2/modeling_florence2.py
new file mode 100644
index 0000000000000000000000000000000000000000..763756faf73f6b9172238ff95ab18285f465c28d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/florence2/modeling_florence2.py
@@ -0,0 +1,1055 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/florence2/modular_florence2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_florence2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_torch_available, + logging, +) +from ..auto import AutoModel +from .configuration_florence2 import Florence2Config, Florence2VisionConfig + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Florence2VisionDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class Florence2VisionLearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, config: Florence2Config): + super().__init__() + num_pos = config.vision_config.max_position_embeddings + embedding_dim = config.vision_config.embed_dim[-1] + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +class Florence2VisionPositionalEmbeddingCosine1D(nn.Module): + """ + This module generates 1D cosine positional embeddings using precomputed sinusoidal functions. 
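+
+    The precomputed buffer follows the standard transformer sinusoid scheme (see
+    `get_sinusoid_embeddings` below), interleaving sines and cosines along the channel axis:
+
+        pos_idx_to_embed[p, 2i]     = sin(p / 10000^(i / half_dim))
+        pos_idx_to_embed[p, 2i + 1] = cos(p / 10000^(i / half_dim))
+
+    where `half_dim = embed_dim // 2`.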
+ """ + + def __init__(self, config: Florence2Config): + super().__init__() + self.embed_dim = config.vision_config.embed_dim[-1] + self.max_seq_len = config.vision_config.max_temporal_embeddings + pos_idx_to_embed = torch.empty((self.max_seq_len, self.embed_dim)) + sine, cosine = self.get_sinusoid_embeddings( + max_positions=self.max_seq_len, + embed_dim=self.embed_dim, + ) + pos_idx_to_embed[:, 0::2] = sine + pos_idx_to_embed[:, 1::2] = cosine + # Save the positional embeddings in a constant buffer. + self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + + @staticmethod + def get_sinusoid_embeddings(max_positions: int, embed_dim: int): + half_dim = embed_dim // 2 + emb = math.log(10000) / half_dim + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + return torch.sin(emb), torch.cos(emb) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + len_seq = seq_embeds.size(1) + if len_seq > self.max_seq_len: + raise ValueError(f"Maximum sequence length {self.max_seq_len}, got {len_seq}") + pos_embeds = self.pos_idx_to_embed[0:len_seq, :] + return pos_embeds + + +class Florence2VisionMLP(nn.Module): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = nn.Linear(config.embed_dim[stage_idx], int(config.embed_dim[stage_idx] * config.mlp_ratio)) + self.fc2 = nn.Linear(int(config.embed_dim[stage_idx] * config.mlp_ratio), config.embed_dim[stage_idx]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Florence2VisionConvEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.stage_idx = stage_idx + self.patch_size = config.patch_size[stage_idx] + self.in_channels = config.in_channels if stage_idx == 0 else config.embed_dim[stage_idx - 1] + self.embed_dim = config.embed_dim[stage_idx] + self.stride = config.patch_stride[stage_idx] + self.padding = config.patch_padding[stage_idx] + self.pre_norm = config.patch_prenorm[stage_idx] + + self.conv = nn.Conv2d( + self.in_channels, + self.embed_dim, + kernel_size=self.patch_size, + stride=self.stride, + padding=self.padding, + ) + + dim_norm = self.in_channels if self.pre_norm else self.embed_dim + self.norm = nn.LayerNorm(dim_norm) + + def forward(self, hidden_states: torch.Tensor): + if self.norm and self.pre_norm: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + + hidden_states = self.conv(hidden_states) + + if self.norm and not self.pre_norm: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + return hidden_states + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * 
scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Florence2VisionChannelAttention(nn.Module): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.dim = config.embed_dim[stage_idx] + self.groups = config.num_groups[stage_idx] + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) + self.proj = nn.Linear(self.dim, self.dim) + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor): + batch_size, num_tokens, hidden_size = hidden_states.shape + + # Reshape for grouped channel attention + qkv = self.qkv(hidden_states).reshape(batch_size, num_tokens, 3, self.groups, hidden_size // self.groups) + qkv = qkv.permute(2, 0, 3, 4, 1) + query, key, value = qkv.unbind(0) + + scale = num_tokens**-0.5 + # Channel-to-channel attention within groups: + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + hidden_states, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + scaling=scale, + ) + hidden_states = hidden_states.permute(0, 3, 2, 1) + hidden_states = hidden_states.reshape(batch_size, num_tokens, hidden_size) + + # Final projection + hidden_states = self.proj(hidden_states) + return hidden_states + + +class Florence2VisionChannelBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + drop_path_rate: float, + ): + super().__init__() + + self.config = config + dim_in = config.embed_dim[stage_idx] + + self.conv1 = nn.Conv2d( + dim_in, + dim_in, + kernel_size=3, + padding=1, + groups=dim_in, + ) + self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.channel_attn = Florence2VisionChannelAttention(config=config, stage_idx=stage_idx) + self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + self.conv2 = nn.Conv2d( + dim_in, + dim_in, + kernel_size=3, + padding=1, + groups=dim_in, + ) + self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) + self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, hidden_states: torch.Tensor): + batch_size, embed_dim, height, width = hidden_states.shape + + # First channel block: Depthwise Conv + Channel Attention + hidden_states = self.conv1(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # Channel group attention self-attention mechanism + hidden_states = self.norm1(hidden_states) + hidden_states = self.channel_attn(hidden_states) + hidden_states = residual + self.drop_path1(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + # Second channel block: Depthwise Conv + FFN + hidden_states = self.conv2(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # FFN + 
hidden_states = self.norm2(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + self.drop_path2(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + return hidden_states + + +class Florence2VisionWindowAttention(nn.Module): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.dim = config.embed_dim[stage_idx] + self.window_size = config.window_size + self.num_heads = config.num_heads[stage_idx] + head_dim = self.dim // self.num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) + self.proj = nn.Linear(self.dim, self.dim) + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor): + batch_size, height, width, embed_dim = hidden_states.shape + + # Pad the input if necessary + pad_left = pad_top = 0 + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + hidden_states = F.pad(hidden_states, (0, 0, pad_left, pad_right, pad_top, pad_bottom)) + _, padded_height, padded_width, _ = hidden_states.shape + + # Partition input into non-overlapping windows (for local spatial attention in DaViT) + hidden_states = hidden_states.view( + batch_size, + padded_height // self.window_size, + self.window_size, + padded_width // self.window_size, + self.window_size, + embed_dim, + ) + windowed_hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() + windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size * self.window_size, embed_dim) + + # Generate Q, K, V for each window + num_windows_per_batch, num_tokens_per_window, embed_dim = windowed_hidden_states.shape + qkv = self.qkv(windowed_hidden_states).reshape( + num_windows_per_batch, num_tokens_per_window, 3, self.num_heads, embed_dim // self.num_heads + ) + qkv = qkv.permute(2, 0, 3, 1, 4) + query, key, value = qkv.unbind(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + windowed_hidden_states, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + scaling=self.scale, + ) + windowed_hidden_states = windowed_hidden_states.view(num_windows_per_batch, num_tokens_per_window, embed_dim) + windowed_hidden_states = self.proj(windowed_hidden_states) + + # Merge windows back to original spatial layout + windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size, self.window_size, embed_dim) + hidden_states = windowed_hidden_states.view( + -1, + padded_height // self.window_size, + padded_width // self.window_size, + self.window_size, + self.window_size, + embed_dim, + ) + hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() + hidden_states = hidden_states.view(-1, padded_height, padded_width, embed_dim) + hidden_states = hidden_states[:, :height, :width, :].contiguous() + hidden_states = hidden_states.view(batch_size, height * width, embed_dim) + + return hidden_states + + +class Florence2VisionSpatialBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + drop_path_rate: float, + ): + super().__init__() + + self.conv1 = nn.Conv2d( + config.embed_dim[stage_idx], + config.embed_dim[stage_idx], + kernel_size=3, + padding=1, + groups=config.embed_dim[stage_idx], + ) + 
self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.window_attn = Florence2VisionWindowAttention(config=config, stage_idx=stage_idx) + self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + self.conv2 = nn.Conv2d( + config.embed_dim[stage_idx], + config.embed_dim[stage_idx], + kernel_size=3, + padding=1, + groups=config.embed_dim[stage_idx], + ) + self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) + self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, hidden_states: torch.Tensor): + batch_size, embed_dim, height, width = hidden_states.shape + + # First spatial mixing block: Conv + Window Attention + hidden_states = self.conv1(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # Spatial Window-based self-attention mechanism + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states.view(batch_size, height, width, embed_dim) + hidden_states = self.window_attn(hidden_states) + hidden_states = residual + self.drop_path1(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + # Second spatial mixing block: Conv + FFN + hidden_states = self.conv2(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # FFN + hidden_states = self.norm2(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + self.drop_path2(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + return hidden_states + + +class Florence2VisionBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + spatial_drop_path_rate: float, + channel_drop_path_rate: float, + ): + super().__init__() + self.spatial_block = Florence2VisionSpatialBlock( + config=config, + stage_idx=stage_idx, + drop_path_rate=spatial_drop_path_rate, + ) + self.channel_block = Florence2VisionChannelBlock( + config=config, + stage_idx=stage_idx, + drop_path_rate=channel_drop_path_rate, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.spatial_block(hidden_states) + hidden_states = self.channel_block(hidden_states) + return hidden_states + + +@auto_docstring +class Florence2VisionPreTrainedModel(PreTrainedModel): + config_class = Florence2VisionConfig + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + +@auto_docstring +class Florence2VisionBackbone(Florence2VisionPreTrainedModel): + def __init__(self, config: Florence2VisionConfig): + super().__init__(config) + self.config = config + + self.embed_dim = config.embed_dim + self.num_heads = config.num_heads + self.num_groups = config.num_groups + self.num_stages = len(self.embed_dim) + + if not (self.num_stages == len(self.num_heads) == len(self.num_groups)): + raise ValueError( + f"Expected self.num_stages ({self.num_stages}) == " + f"len(self.num_heads) ({len(self.num_heads)}) == " + f"len(self.num_groups) ({len(self.num_groups)})" + ) + + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths) * 2, device="cpu")] + depth_offset = 0 + + convs = [] + blocks = [] + for stage_idx in range(self.num_stages): + conv_embed = 
Florence2VisionConvEmbed( + config=config, + stage_idx=stage_idx, + ) + convs.append(conv_embed) + + block = nn.ModuleList( + Florence2VisionBlock( + config=config, + stage_idx=stage_idx, + spatial_drop_path_rate=dpr[depth_offset + block_idx * 2], + channel_drop_path_rate=dpr[depth_offset + block_idx * 2 + 1], + ) + for block_idx in range(config.depths[stage_idx]) + ) + blocks.append(block) + depth_offset += config.depths[stage_idx] * 2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, hidden_states: torch.Tensor): + for conv, block in zip(self.convs, self.blocks): + hidden_states = conv(hidden_states) + for layer in block: + hidden_states = layer(hidden_states) + return hidden_states + + +class Florence2MultiModalProjector(nn.Module): + def __init__(self, config: Florence2Config): + super().__init__() + self.vision_embedding_dim = config.vision_config.embed_dim[-1] + self.vision_projection_dim = config.vision_config.projection_dim + self.image_projection = nn.Linear(self.vision_embedding_dim, self.vision_projection_dim, bias=False) + self.image_proj_norm = nn.LayerNorm(self.vision_projection_dim) + self.image_position_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(config=config) + self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(config=config) + + def forward(self, image_features): + position_features = image_features + self.image_position_embed(image_features) + position_features = position_features.flatten(2).transpose(1, 2) + temporal_features = self.visual_temporal_embed(position_features[:, :1, :]) + temporal_features = temporal_features.unsqueeze(1) + visual_token_features = position_features + temporal_features + visual_token_features = visual_token_features.unsqueeze(1) + spatial_image_features = visual_token_features.mean(dim=2) + temporal_image_features = visual_token_features.mean(dim=1) + image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1) + image_features = self.image_projection(image_features) + image_features = self.image_proj_norm(image_features) + return image_features + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Florence-2 model's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. 
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    """
+
+    image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@auto_docstring
+class Florence2PreTrainedModel(PreTrainedModel):
+    config: Florence2Config
+    base_model_prefix = ""
+    supports_gradient_checkpointing = True
+    _skip_keys_device_placement = "past_key_values"
+
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    _can_compile_fullgraph = True
+    _supports_flex_attn = True
+
+    _supports_attention_backend = False
+    config_class = Florence2Config
+
+
+@auto_docstring(
+    custom_intro="""
+    Florence-2 is a vision model for captioning, detection, and segmentation.
+    """
+)
+class Florence2Model(Florence2PreTrainedModel):
+    _checkpoint_conversion_mapping = {}
+    _tied_weights_keys = [
+        "language_model.encoder.embed_tokens.weight",
+        "language_model.decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config: Florence2Config):
+        super().__init__(config)
+        self.vision_tower = Florence2VisionBackbone(config=config.vision_config)
+
+        self.multi_modal_projector = Florence2MultiModalProjector(config)
+        self.language_model = AutoModel.from_config(config.text_config)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def get_image_features(self, pixel_values: torch.Tensor, **kwargs):
+        """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        image_features = self.vision_tower(pixel_values, **kwargs)
+        image_embeds = self.multi_modal_projector(image_features)
+        return image_embeds
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+        equal to the length of multimodal features. If the lengths are different, an error is raised.
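+
+        For example, with the default `image_token_id` (51289) and `input_ids = [[0, 51289, 51289, 2]]`
+        (illustrative values), the returned mask selects the two placeholder positions, broadcast over
+        the hidden dimension, so that `inputs_embeds.masked_scatter(mask, image_features)` can write
+        the projected image features into exactly those slots.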
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Florence2Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + encoder_outputs = self.language_model.encoder( + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + + decoder_outputs = self.language_model.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + 
inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=True, + ) + + return Florence2Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def get_encoder(self): + return self.language_model.get_encoder() + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@auto_docstring( + custom_intro=""" + Florence-2 is a vision model for captioning, detection, and segmentation. + """ +) +class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = [ + "model.language_model.encoder.embed_tokens.weight", + "model.language_model.decoder.embed_tokens.weight", + "lm_head.weight", + ] + + def __init__(self, config: Florence2Config): + super().__init__(config) + self.model = Florence2Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values: torch.Tensor, **kwargs): + return self.model.get_image_features(pixel_values=pixel_values, **kwargs) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = 
None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Florence2Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration + + >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") + >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") + + >>> prompt = "" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=100) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A green car parked in front of a yellow building." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.text_config.pad_token_id, self.config.text_config.decoder_start_token_id + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + # **kwargs, ## TODO: add back when Bart attention is refactored and takes kwargs + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return Florence2Seq2SeqLMOutput( + loss=loss, + 
logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + def get_encoder(self): + return self.model.get_encoder() + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + return self.model.get_placeholder_mask( + input_ids=input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str], + generation_config, + ) -> dict[str, Any]: + # override to handle merging image and text embeddings before passing to language encoder + inputs_embeds = model_kwargs.pop("inputs_embeds", None) + pixel_values = model_kwargs.pop("pixel_values", None) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(inputs_tensor) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + inputs_tensor, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + model_kwargs["inputs_embeds"] = inputs_embeds + model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation( + None, model_kwargs, model_input_name, generation_config + ) + model_kwargs.pop("inputs_embeds", None) + return model_kwargs + + +__all__ = [ + "Florence2Model", + "Florence2ForConditionalGeneration", + "Florence2PreTrainedModel", + "Florence2VisionBackbone", + "Florence2VisionPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/florence2/modular_florence2.py b/venv/lib/python3.13/site-packages/transformers/models/florence2/modular_florence2.py new file mode 100644 index 0000000000000000000000000000000000000000..f8732257f10249a847f39fc94630d2f40718cfc9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/florence2/modular_florence2.py @@ -0,0 +1,1806 @@ +# coding=utf-8 +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import re +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import MultiModalData, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, logging +from ..auto import CONFIG_MAPPING, AutoConfig +from ..bart.modeling_bart import eager_attention_forward, shift_tokens_right +from ..beit.modeling_beit import BeitDropPath +from ..llama4.modeling_llama4 import Llama4VisionMLP +from ..llava.modeling_llava import LlavaForConditionalGeneration, LlavaModel, LlavaPreTrainedModel +from ..llava.processing_llava import LlavaProcessorKwargs + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class Florence2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Florence2VisionModel architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + in_channels (`int`, *optional*, defaults to 3): + Number of input image channels. + depths (`Tuple[int]`, *optional*, defaults to `(1, 1, 9, 1)`): + The depth of the model. + patch_size (`Tuple[int]`, *optional*, defaults to `(7, 3, 3, 3)`): + The patch size of the image. + patch_stride (`Tuple[int]`, *optional*, defaults to `(4, 2, 2, 2)`): + The patch stride of the image. + patch_padding (`Tuple[int]`, *optional*, defaults to `(3, 1, 1, 1)`): + The patch padding of the image. + patch_prenorm (`Tuple[bool]`, *optional*, defaults to `(False, True, True, True)`): + Whether to apply layer normalization before the patch embedding layer. + embed_dim (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 1024)`): + The dimension of the embedding layer. + num_heads (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`): + The number of attention heads. + num_groups (`Tuple[int]`, *optional*, defaults to `(4, 8, 16, 32)`): + The number of groups. + window_size (`int`, *optional*, defaults to 12): + The window size of the model. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout rate of the drop path layer. + mlp_ratio (`int`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (`bool`, *optional*, defaults to `True`): + If True, add a learnable bias to query, key, value. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + projection_dim (`int`, *optional*, defaults to 1024): + The dimension of the projection layer. + max_temporal_embeddings (`int`, *optional*, defaults to 100): + The configuration of the visual temporal embedding. + max_position_embeddings (`int`, *optional*, defaults to 50): + The configuration of the image position embedding. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Example: + + ```python + >>> from transformers import Florence2VisionConfig, Florence2VisionModel + + >>> # Initializing a Florence2 Vision style configuration + >>> configuration = Florence2VisionConfig() + + >>> # Initializing a model (with random weights) + >>> model = Florence2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence_vision" + + def __init__( + self, + in_channels=3, + depths=(1, 1, 9, 1), + patch_size=(7, 3, 3, 3), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 1, 1, 1), + patch_prenorm=(False, True, True, True), + embed_dim=(128, 256, 512, 1024), + num_heads=(4, 8, 16, 32), + num_groups=(4, 8, 16, 32), + window_size=12, + drop_path_rate=0.1, + mlp_ratio=4.0, + qkv_bias=True, + activation_function="gelu", + projection_dim=1024, + max_temporal_embeddings=100, + max_position_embeddings=50, + initializer_range=0.02, + **kwargs, + ): + self.in_channels = in_channels + self.depths = list(depths) + self.patch_size = list(patch_size) + self.patch_stride = list(patch_stride) + self.patch_padding = list(patch_padding) + self.patch_prenorm = list(patch_prenorm) + self.embed_dim = list(embed_dim) + self.num_heads = list(num_heads) + self.num_groups = list(num_groups) + self.window_size = window_size + self.drop_path_rate = drop_path_rate + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.projection_dim = projection_dim + self.max_temporal_embeddings = max_temporal_embeddings + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.activation_function = activation_function + + super().__init__(**kwargs) + + +class Florence2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an + Florence-2 model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the Florence-2 + [microsoft/Florence-2-base](https://huggingface.co/microsoft/Florence-2-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`AutoConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Florence2VisionConfig`]. + image_token_id (`int`, *optional*, defaults to 51289): + The image token index to encode the image prompt. 
+ is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + + Example: + + ```python + >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig + + >>> # Initializing a clip-like vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Bart config + >>> text_config = BartConfig() + + >>> # Initializing a Florence-2 configuration + >>> configuration = Florence2Config(vision_config, text_config) + + >>> # Initializing a model from the florence-2 configuration + >>> model = Florence2ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence2" + sub_configs = { + "text_config": AutoConfig, + "vision_config": Florence2VisionConfig, + } + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=51289, + is_encoder_decoder=True, + **kwargs, + ): + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "bart") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["bart"]() + + if isinstance(vision_config, dict): + vision_config = Florence2VisionConfig(**vision_config) + elif vision_config is None: + logger.info("vision_config is None. Initializing the Florence2VisionConfig with default values.") + vision_config = Florence2VisionConfig() + + self.text_config = text_config + self.vision_config = vision_config + self.image_token_id = image_token_id + + super().__init__( + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class Florence2ProcessorKwargs(LlavaProcessorKwargs): + pass + + +class Florence2Processor(ProcessorMixin): + r""" + Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor. + + [`Florence2Processor`] offers all the functionalities of [`AutoImageProcessor`] and [`BartTokenizerFast`]. See the + [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information. + + Args: + image_processor (`AutoImageProcessor`, *optional*): + The image processor is a required input. + tokenizer (`Union[BartTokenizer, BartTokenizerFast]`, *optional*): + The tokenizer is a required input. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. + post_processor_config (`dict`, *optional*, defaults to 0): + Task-specific parsing rules for [`Florence2PostProcessor`], e.g. regex patterns, + thresholds, or banned tokens. 
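+
+    Example (an illustrative sketch, not taken from this file; the checkpoint name, the `image` variable and the
+    task prompt are assumptions):
+
+    ```python
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base")
+    >>> inputs = processor(images=image, text="<OD>", return_tensors="pt")
+    >>> list(inputs.keys())  # typically ["pixel_values", "input_ids", "attention_mask"]
+    ```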
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        num_additional_image_tokens: int = 0,
+        post_processor_config: Optional[dict] = None,
+        **kwargs,
+    ):
+        self.tasks_answer_post_processing_type = {
+            "<OCR>": "pure_text",
+            "<OCR_WITH_REGION>": "ocr",
+            "<CAPTION>": "pure_text",
+            "<DETAILED_CAPTION>": "pure_text",
+            "<MORE_DETAILED_CAPTION>": "pure_text",
+            "<OD>": "description_with_bboxes",
+            "<DENSE_REGION_CAPTION>": "description_with_bboxes",
+            "<CAPTION_TO_PHRASE_GROUNDING>": "phrase_grounding",
+            "<REFERRING_EXPRESSION_SEGMENTATION>": "polygons",
+            "<REGION_TO_SEGMENTATION>": "polygons",
+            "<OPEN_VOCABULARY_DETECTION>": "description_with_bboxes_or_polygons",
+            "<REGION_TO_CATEGORY>": "pure_text",
+            "<REGION_TO_DESCRIPTION>": "pure_text",
+            "<REGION_TO_OCR>": "pure_text",
+            "<REGION_PROPOSAL>": "bboxes",
+        }
+
+        self.task_prompts_without_inputs = {
+            "<OCR>": "What is the text in the image?",
+            "<OCR_WITH_REGION>": "What is the text in the image, with regions?",
+            "<CAPTION>": "What does the image describe?",
+            "<DETAILED_CAPTION>": "Describe in detail what is shown in the image.",
+            "<MORE_DETAILED_CAPTION>": "Describe with a paragraph what is shown in the image.",
+            "<OD>": "Locate the objects with category name in the image.",
+            "<DENSE_REGION_CAPTION>": "Locate the objects in the image, with their descriptions.",
+            "<REGION_PROPOSAL>": "Locate the region proposals in the image.",
+        }
+
+        self.task_prompts_with_input = {
+            "<CAPTION_TO_PHRASE_GROUNDING>": "Locate the phrases in the caption: {input}",
+            "<REFERRING_EXPRESSION_SEGMENTATION>": "Locate {input} in the image with mask",
+            "<REGION_TO_SEGMENTATION>": "What is the polygon mask of region {input}",
+            "<OPEN_VOCABULARY_DETECTION>": "Locate {input} in the image.",
+            "<REGION_TO_CATEGORY>": "What is the region {input}?",
+            "<REGION_TO_DESCRIPTION>": "What does the region {input} describe?",
+            "<REGION_TO_OCR>": "What text is in the region {input}?",
+        }
+
+        self.num_image_tokens = image_processor.image_seq_length
+        self.num_additional_image_tokens = num_additional_image_tokens
+        self.post_processor_config = post_processor_config
+        self.post_processor = Florence2PostProcessor(config=post_processor_config, tokenizer=tokenizer)
+        self.image_token = tokenizer.image_token
+        self.image_token_id = tokenizer.image_token_id
+
+        super().__init__(image_processor, tokenizer, **kwargs)
+
+    def _construct_prompts(self, text: Union[str, list[str]]) -> list[str]:
+        """
+        Construct prompts by replacing task tokens with their corresponding prompt strings.
+        """
+        if isinstance(text, str):
+            text = [text]
+
+        prompts = []
+        for prompt in text:
+            # Check for tasks without inputs
+            for task_token, task_prompt in self.task_prompts_without_inputs.items():
+                if task_token in prompt:
+                    if prompt != task_token:
+                        raise ValueError(f"Task token {task_token} should be the only content in the prompt.")
+                    prompt = task_prompt
+                    break
+            # Check for tasks with inputs
+            for task_token, task_prompt in self.task_prompts_with_input.items():
+                if task_token in prompt:
+                    input_text = prompt.replace(task_token, "").strip()
+                    prompt = task_prompt.format(input=input_text)
+                    break
+            prompts.append(prompt)
+        return prompts
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+        **kwargs: Unpack[Florence2ProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+ + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + output_kwargs = self._merge_kwargs( + Florence2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = {} + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text is None: + logger.warning_once("You are using Florence-2 without a text prefix.") + text = [""] * (1 if not isinstance(images, list) else len(images)) + elif isinstance(text, str): + text = [text] + + if not isinstance(text, list) or not all(isinstance(token, str) for token in text): + raise ValueError("`text` must be a string or list of strings.") + + if isinstance(images, list) and len(images) != len(text): + raise ValueError(f"Number of images ({len(images)}) must match number of texts ({len(text)}).") + + prompt_strings = self._construct_prompts(text) + + # Add image tokens and special tokens if images are provided + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + expanded_image_prompts = [] + for sample in prompt_strings: + sample = ( + self.image_token * self.num_image_tokens + + self.tokenizer.bos_token + + sample + + self.tokenizer.eos_token + ) + expanded_image_prompts.append(sample) + prompt_strings = expanded_image_prompts + + # Construct and tokenize prompts + output_kwargs["text_kwargs"].pop("add_special_tokens", None) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer( + prompt_strings, **output_kwargs["text_kwargs"], add_special_tokens=False, return_tensors=None + ) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + 
mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**image_inputs, **text_inputs}, tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=False, **kwargs): + """ + Post-processes the output of `FuyuForConditionalGeneration` to only return the text output. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + containing the token ids of the generated sequences. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text output. + """ + return self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + + def post_process_generation(self, text=None, sequence=None, task=None, image_size=None) -> dict[str, Any]: + """ + Post-process generation outputs based on the task. + + Args: + text (`str`, *optional*): + Generated text. + sequence (`Union[List[int], torch.Tensor]`, *optional*): + Generated token sequence. + task (`str`, *optional*): + The task for post-processing. + image_size (`Tuple[int, int]`, *optional*): + Image size for dequantization. + + Returns: + `Dict[str, Any]`: Post-processed results keyed by task. 
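+
+        Example (an illustrative sketch; the task token, variable names and image size are assumptions):
+
+        ```python
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+        >>> parsed = processor.post_process_generation(generated_text, task="<OD>", image_size=(width, height))
+        >>> parsed["<OD>"]  # e.g. {"bboxes": [[xmin, ymin, xmax, ymax], ...], "labels": ["car", ...]}
+        ```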
+ """ + if task is None: + raise ValueError("`task` must be provided for post-processing.") + + post_proc_type = self.tasks_answer_post_processing_type.get(task, "pure_text") + parsed = self.post_processor( + text=text, + sequence=sequence, + image_size=image_size, + parse_tasks=[post_proc_type], + )[post_proc_type] + + if post_proc_type == "pure_text": + final_answer = parsed.replace("", "").replace("", "").strip() + elif post_proc_type in ["description_with_bboxes", "bboxes"]: + bboxes = [inst["bbox"] for inst in parsed] + labels = [inst["cat_name"] for inst in parsed] + final_answer = {"bboxes": bboxes, "labels": labels} + if parsed and "score" in parsed[0]: + final_answer["scores"] = [inst["score"] for inst in parsed] + elif post_proc_type == "ocr": + quad_boxes = [inst["quad_box"] for inst in parsed] + labels = [inst["text"] for inst in parsed] + final_answer = {"quad_boxes": quad_boxes, "labels": labels} + elif post_proc_type == "phrase_grounding": + bboxes = [] + labels = [] + for inst in parsed: + for bbox in inst["bbox"]: + bboxes.append(bbox) + labels.append(inst["cat_name"]) + final_answer = {"bboxes": bboxes, "labels": labels} + elif post_proc_type in ["description_with_polygons", "polygons"]: + polygons = [inst["polygons"] for inst in parsed] + labels = [inst["cat_name"] for inst in parsed] + final_answer = {"polygons": polygons, "labels": labels} + elif post_proc_type == "description_with_bboxes_or_polygons": + bboxes = [] + bboxes_labels = [] + polygons = [] + polygons_labels = [] + for inst in parsed: + label = inst["cat_name"] + if "polygons" in inst: + polygons.append(inst["polygons"]) + polygons_labels.append(label) + else: + bboxes.append(inst["bbox"]) + bboxes_labels.append(label) + final_answer = { + "bboxes": bboxes, + "bboxes_labels": bboxes_labels, + "polygons": polygons, + "polygons_labels": polygons_labels, + } + else: + raise ValueError(f"Unknown post-processing type: {post_proc_type}") + + return {task: final_answer} + + +class Florence2PostProcessor: + """ + Post-processor for Florence-2 model outputs. Parses generated text into structured results for various tasks + like object detection, OCR, phrase grounding, etc. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used for decoding model outputs. + """ + + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.parse_task_config = config or {} + self.banned_grounding_tokens = set( + self.parse_task_config.get("phrase_grounding", {}).get("banned_grounding_tokens", []) + ) + self.all_special_tokens = set(self.tokenizer.all_special_tokens) + self.quantize_bins = (1000, 1000) + + def quantize(self, locations: "torch.Tensor", size: tuple[int, int]) -> "torch.Tensor": + """ + Quantize locations. + + Args: + locations (`torch.Tensor`): + Tensor of shape (N, 4) for boxes or (N, 2) for points/coordinates. + size (`tuple[int, int]`): + Original image size (width, height). + + Returns: + `torch.Tensor`: Quantized locations as integers. 
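+
+        Example (an illustrative sketch, assuming the default 1000x1000 quantization grid, a 640x480 image, and
+        that `post_processor` is a `Florence2PostProcessor` instance such as `processor.post_processor`):
+
+        ```python
+        >>> import torch
+
+        >>> boxes = torch.tensor([[100.0, 100.0, 300.0, 200.0]])
+        >>> post_processor.quantize(boxes, size=(640, 480))
+        tensor([[156, 208, 468, 416]], dtype=torch.int32)
+        ```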
+ """ + bins_w, bins_h = self.quantize_bins + size_w, size_h = size + per_bin_w = size_w / bins_w + per_bin_h = size_h / bins_h + + if locations.shape[-1] == 4: # Bounding boxes: [xmin, ymin, xmax, ymax] + xmin, ymin, xmax, ymax = locations.split(1, dim=-1) + q_xmin = (xmin / per_bin_w).floor().clamp(0, bins_w - 1) + q_ymin = (ymin / per_bin_h).floor().clamp(0, bins_h - 1) + q_xmax = (xmax / per_bin_w).floor().clamp(0, bins_w - 1) + q_ymax = (ymax / per_bin_h).floor().clamp(0, bins_h - 1) + return torch.cat([q_xmin, q_ymin, q_xmax, q_ymax], dim=-1).int() + + elif locations.shape[-1] == 2: # Points/coordinates: [x, y] + x, y = locations.split(1, dim=-1) + q_x = (x / per_bin_w).floor().clamp(0, bins_w - 1) + q_y = (y / per_bin_h).floor().clamp(0, bins_h - 1) + return torch.cat([q_x, q_y], dim=-1).int() + + else: + raise ValueError(f"Unsupported location shape: last dim must be 2 or 4, got {locations.shape[-1]}.") + + def dequantize(self, locations: "torch.Tensor", size: tuple[int, int]) -> "torch.Tensor": + """ + Dequantize locations back to original scale. + + Args: + locations (`torch.Tensor`): + Quantized tensor of shape (N, 4) for boxes or (N, 2) for points/coordinates. + size (`tuple[int, int]`): + Original image size (width, height). + + Returns: + `torch.Tensor`: Dequantized locations as floats. + """ + bins_w, bins_h = self.quantize_bins + size_w, size_h = size + per_bin_w = size_w / bins_w + per_bin_h = size_h / bins_h + + # Add 0.5 to use the center position of the bin as the coordinate. + if locations.shape[-1] == 4: # Bounding boxes + xmin, ymin, xmax, ymax = locations.split(1, dim=-1) + dq_xmin = (xmin + 0.5) * per_bin_w + dq_ymin = (ymin + 0.5) * per_bin_h + dq_xmax = (xmax + 0.5) * per_bin_w + dq_ymax = (ymax + 0.5) * per_bin_h + return torch.cat([dq_xmin, dq_ymin, dq_xmax, dq_ymax], dim=-1).int() + + elif locations.shape[-1] == 2: # Points/coordinates + x, y = locations.split(1, dim=-1) + dq_x = (x + 0.5) * per_bin_w + dq_y = (y + 0.5) * per_bin_h + return torch.cat([dq_x, dq_y], dim=-1).int() + + else: + raise ValueError(f"Unsupported location shape: last dim must be 2 or 4, got {locations.shape[-1]}.") + + def decode_with_spans(self, token_ids: list[int]) -> tuple[str, list[tuple[int, int]]]: + """ + Decode token IDs to text and compute character spans. + + Args: + token_ids (`list[int]`): + list of token IDs to decode. + + Returns: + `tuple[str, list[tuple[int, int]]]`: Decoded text and list of spans (start, end) for each token. + """ + filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False) + text = "" + spans = [] + for token in filtered_tokens: + if token in self.all_special_tokens: + sub_text = token + else: + sub_text = self.tokenizer.convert_tokens_to_string([token]) + span = (len(text), len(text) + len(sub_text)) + text += sub_text + spans.append(span) + return text, spans + + def parse_ocr_from_text_and_spans( + self, text: str, pattern: Optional[str], image_size: tuple[int, int], area_threshold: float = 0.0 + ) -> list[dict[str, Any]]: + """ + Parse OCR results with quadrilateral boxes. + + Args: + text (`str`): + The generated text. + pattern (`str`): + Regex pattern for matching. + image_size (`tuple[int, int]`): + Image size (width, height). + area_threshold (`float`, *optional*, defaults to 0.0): + Minimum area threshold for filtering boxes. + + Returns: + `list[dict[str, Any]]`: list of instances with 'quad_box' and 'text'. 
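+
+        The returned structure looks like `[{"quad_box": [x1, y1, x2, y2, x3, y3, x4, y4], "text": "STOP"}]` (an
+        illustrative sketch; the text and coordinates depend on the generated output and the image size).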
+ """ + text = text.replace("", "").replace("", "").replace("", "") + if pattern is None: + pattern = r"(.+?)" + + matches = re.findall(pattern, text) + instances = [] + width, height = image_size + + for content, *quad_str in matches: + quad_bins = [int(i) for i in quad_str] + quad_box = self.dequantize(torch.tensor(quad_bins).reshape(-1, 2), size=image_size).flatten().tolist() + + if area_threshold > 0: + x_coords = quad_box[0::2] + y_coords = quad_box[1::2] + # Apply the Shoelace formula + area = 0.5 * abs( + sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)) + ) + + if area < (width * height) * area_threshold: + continue + + instances.append({"quad_box": quad_box, "text": content.strip()}) + return instances + + def parse_phrase_grounding_from_text_and_spans( + self, text: str, image_size: tuple[int, int] + ) -> list[dict[str, Any]]: + """ + Parse phrase grounding results. + + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + + Returns: + `list[dict[str, Any]]`: list of instances with 'bbox' and 'cat_name'. + """ + text = text.replace("", "").replace("", "").replace("", "") + phrase_pattern = r"([^<]+(?:){4,})" + phrases = re.findall(phrase_pattern, text) + text_pattern = r"^\s*(.*?)(?=||||||" + + instances = [] + for phrase_text in phrases: + phrase_text = phrase_text.replace("", "", 1).replace("", "", 1) + if not phrase_text: + continue + match = re.search(text_pattern, phrase_text) + if not match: + continue + phrase = match.group().strip() + if phrase in self.banned_grounding_tokens: + continue + boxes_matches = list(re.finditer(box_pattern, phrase_text)) + if not boxes_matches: + continue + bbox_bins = [[int(m.group(j)) for j in range(1, 5)] for m in boxes_matches] + bboxes = self.dequantize(torch.tensor(bbox_bins), size=image_size).tolist() + phrase = phrase.encode("ascii", "ignore").decode("ascii") + instances.append({"bbox": bboxes, "cat_name": phrase}) + return instances + + def _find_matched_token_indices(self, cur_span: tuple[int, int], token_spans: list[tuple[int, int]]) -> list[int]: + return [i for i, span in enumerate(token_spans) if not (span[1] <= cur_span[0] or span[0] >= cur_span[1])] + + def parse_description_with_bboxes_from_text_and_spans( + self, + text: str, + image_size: tuple[int, int], + allow_empty_phrase: bool = False, + ) -> list[dict[str, Any]]: + """ + Parse descriptions with bounding boxes. + + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + allow_empty_phrase (`bool`, *optional*, defaults to `False`): + Allow phrases without text. + + Returns: + `list[dict[str, Any]]`: list of instances with 'bbox', 'cat_name', and optional 'score'. 
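+
+        The returned structure looks like `[{"bbox": [xmin, ymin, xmax, ymax], "cat_name": "car"}, ...]` (an
+        illustrative sketch; category names come from the generated text and coordinates are dequantized to the
+        given image size).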
+ """ + text = text.replace("", "").replace("", "").replace("", "") + + if allow_empty_phrase: + pattern = r"(?:(?:){4,})" + else: + pattern = r"([^<]+(?:){4,})" + phrases = re.findall(pattern, text) + + text_pattern = r"^\s*(.*?)(?=||||||" + + instances = [] + for phrase_text in phrases: + phrase_text = phrase_text.replace("", "", 1).replace("", "", 1) + if not phrase_text and not allow_empty_phrase: + continue + match = re.search(text_pattern, phrase_text) + if not match: + continue + phrase = match.group().strip() + boxes_matches = list(re.finditer(box_pattern, phrase_text)) + if not boxes_matches: + continue + bbox_bins = [[int(m.group(j)) for j in range(1, 5)] for m in boxes_matches] + bboxes = self.dequantize(torch.tensor(bbox_bins), size=image_size).tolist() + + phrase = phrase.encode("ascii", "ignore").decode("ascii") + for bbox in bboxes: + instance = {"bbox": bbox, "cat_name": phrase} + instances.append(instance) + + return instances + + def parse_description_with_polygons_from_text_and_spans( + self, + text: str, + image_size: tuple[int, int], + allow_empty_phrase: bool = False, + polygon_sep_token: str = "", + polygon_start_token: str = "", + polygon_end_token: str = "", + with_box_at_start: bool = False, + ) -> list[dict[str, Any]]: + """ + Parse descriptions with polygons. + + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + allow_empty_phrase (`bool`, *optional*, defaults to `False`): + Allow phrases without text. + polygon_sep_token (`str`, *optional*, defaults to ""): + Token separating polygons. + polygon_start_token (`str`, *optional*, defaults to ""): + Start token for polygons. + polygon_end_token (`str`, *optional*, defaults to ""): + End token for polygons. + with_box_at_start (`bool`, *optional*, defaults to `False`): + Whether a bounding box is at the start of polygons. + + Returns: + `list[dict[str, Any]]`: list of instances with 'polygons', 'cat_name', and optional 'bbox'. 
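+
+        The returned structure looks like `[{"polygons": [[x1, y1, x2, y2, ...]], "cat_name": "person"}]` (an
+        illustrative sketch; a `"bbox"` entry is only added when `with_box_at_start` is `True` and a box is found).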
+ """ + text = text.replace("", "").replace("", "").replace("", "") + + if allow_empty_phrase: + pattern = rf"(?:(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + else: + pattern = rf"([^<]+(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + phrases = re.findall(pattern, text) + phrase_pattern = r"^\s*(.*?)(?=||||||)" + poly_instance_pattern = rf"{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}" + box_pattern = rf"((?:)+)(?:{re.escape(polygon_sep_token)}|$)" + + instances = [] + for phrase_text in phrases: + phrase_text_strip = re.sub(r"^", "", phrase_text, count=1) + if not phrase_text_strip and not allow_empty_phrase: + continue + match = re.search(phrase_pattern, phrase_text_strip) + if not match: + continue + phrase = match.group().strip() + + if polygon_start_token in phrase_text and polygon_end_token in phrase_text: + poly_instances = [m.group(1) for m in re.finditer(poly_instance_pattern, phrase_text)] + else: + poly_instances = [phrase_text] + + for poly_inst in poly_instances: + poly_matches = list(re.finditer(box_pattern, poly_inst)) + if len(poly_matches) == 0: + continue + bbox = [] + polygons = [] + for poly_match in poly_matches: + poly_str = poly_match.group(1) + poly_bins = [int(m.group(1)) for m in re.finditer(r"", poly_str)] + if with_box_at_start and not bbox: + if len(poly_bins) > 4: + bbox = poly_bins[:4] + poly_bins = poly_bins[4:] + else: + bbox = [0, 0, 0, 0] + if len(poly_bins) % 2 == 1: + poly_bins = poly_bins[:-1] + poly_coords = ( + self.dequantize(torch.tensor(poly_bins).reshape(-1, 2), size=image_size).flatten().tolist() + ) + polygons.append(poly_coords) + + instance = {"cat_name": phrase, "polygons": polygons} + if bbox: + instance["bbox"] = self.dequantize(torch.tensor([bbox]), size=image_size)[0].tolist() + instances.append(instance) + return instances + + def __call__(self, text=None, sequence=None, image_size=None, parse_tasks=None) -> dict[str, Any]: + """ + Process model output and parse into task-specific results. + + Args: + text (`Optional[str]`, *optional*): + Generated text. Either this or `sequence` must be provided. + sequence (`Optional[Union[list[int], torch.Tensor]]`, *optional*): + Token sequence. Either this or `text` must be provided. + image_size (`Optional[tuple[int, int]]`, *optional*): + Image size (width, height) required for dequantization. + parse_tasks (`Optional[Union[str, list[str]]]`, *optional*): + Specific tasks to parse. If None, parse all supported tasks. + + Returns: + `dict[str, Any]`: Parsed results for each task, including the raw 'text'. 
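+
+        Example (an illustrative sketch; assumes the `description_with_bboxes` task is enabled in the post-processor
+        config and that `generated_text`, `width` and `height` are defined by the caller):
+
+        ```python
+        >>> parsed = post_processor(text=generated_text, image_size=(width, height), parse_tasks="description_with_bboxes")
+        >>> parsed["text"]  # the raw input text
+        >>> parsed["description_with_bboxes"]  # list of {"bbox": [...], "cat_name": ...} dicts
+        ```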
+ """ + if parse_tasks is not None: + parse_tasks = [parse_tasks] if isinstance(parse_tasks, str) else parse_tasks + for task in parse_tasks: + if task not in self.parse_task_config.keys(): + raise ValueError(f"Unsupported parse task: {task}") + + if (text is None and sequence is None) or (text is not None and sequence is not None): + raise ValueError("Exactly one of 'text' or 'sequence' must be provided.") + + if sequence is not None: + if isinstance(sequence, torch.Tensor): + sequence = sequence.tolist() + sequence = sequence[1:] if sequence[0] == self.tokenizer.bos_token_id else sequence # Skip BOS if present + text, _ = self.decode_with_spans(sequence) + + parsed_dict = {"text": text} + + tasks_to_parse = parse_tasks or self.parse_task_config.keys() + for task in tasks_to_parse: + config = self.parse_task_config[task] + pattern = config.get("PATTERN") + + if task == "ocr": + parsed_dict["ocr"] = self.parse_ocr_from_text_and_spans( + text, pattern=pattern, image_size=image_size, area_threshold=config.get("AREA_THRESHOLD", 0.0) + ) + elif task == "phrase_grounding": + parsed_dict["phrase_grounding"] = self.parse_phrase_grounding_from_text_and_spans( + text, image_size=image_size + ) + elif task == "pure_text": + parsed_dict["pure_text"] = text + elif task == "description_with_bboxes": + parsed_dict["description_with_bboxes"] = self.parse_description_with_bboxes_from_text_and_spans( + text, image_size=image_size + ) + elif task == "description_with_polygons": + parsed_dict["description_with_polygons"] = self.parse_description_with_polygons_from_text_and_spans( + text, image_size=image_size + ) + elif task == "polygons": + parsed_dict["polygons"] = self.parse_description_with_polygons_from_text_and_spans( + text, image_size=image_size, allow_empty_phrase=True + ) + elif task == "bboxes": + parsed_dict["bboxes"] = self.parse_description_with_bboxes_from_text_and_spans( + text, image_size=image_size, allow_empty_phrase=True + ) + elif task == "description_with_bboxes_or_polygons": + if "" in text: + instances = self.parse_description_with_polygons_from_text_and_spans(text, image_size=image_size) + else: + instances = self.parse_description_with_bboxes_from_text_and_spans(text, image_size=image_size) + parsed_dict["description_with_bboxes_or_polygons"] = instances + else: + raise ValueError(f"task {task} is not supported") + + return parsed_dict + + +class Florence2VisionDropPath(BeitDropPath): + pass + + +class Florence2VisionLearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. 
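+
+    Row and column embeddings are looked up separately and concatenated for every spatial location, so the forward
+    pass returns a tensor of shape `(batch_size, embed_dim, height, width)` that is added directly to the image
+    feature map by `Florence2MultiModalProjector`.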
+ """ + + def __init__(self, config: Florence2Config): + super().__init__() + num_pos = config.vision_config.max_position_embeddings + embedding_dim = config.vision_config.embed_dim[-1] + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +class Florence2VisionPositionalEmbeddingCosine1D(nn.Module): + """ + This module generates 1D cosine positional embeddings using precomputed sinusoidal functions. + """ + + def __init__(self, config: Florence2Config): + super().__init__() + self.embed_dim = config.vision_config.embed_dim[-1] + self.max_seq_len = config.vision_config.max_temporal_embeddings + pos_idx_to_embed = torch.empty((self.max_seq_len, self.embed_dim)) + sine, cosine = self.get_sinusoid_embeddings( + max_positions=self.max_seq_len, + embed_dim=self.embed_dim, + ) + pos_idx_to_embed[:, 0::2] = sine + pos_idx_to_embed[:, 1::2] = cosine + # Save the positional embeddings in a constant buffer. + self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + + @staticmethod + def get_sinusoid_embeddings(max_positions: int, embed_dim: int): + half_dim = embed_dim // 2 + emb = math.log(10000) / half_dim + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + return torch.sin(emb), torch.cos(emb) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + len_seq = seq_embeds.size(1) + if len_seq > self.max_seq_len: + raise ValueError(f"Maximum sequence length {self.max_seq_len}, got {len_seq}") + pos_embeds = self.pos_idx_to_embed[0:len_seq, :] + return pos_embeds + + +class Florence2VisionMLP(Llama4VisionMLP): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__(config) + self.fc1 = nn.Linear(config.embed_dim[stage_idx], int(config.embed_dim[stage_idx] * config.mlp_ratio)) + self.activation_fn = ACT2FN[config.activation_function] + self.fc2 = nn.Linear(int(config.embed_dim[stage_idx] * config.mlp_ratio), config.embed_dim[stage_idx]) + + +class Florence2VisionConvEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.stage_idx = stage_idx + self.patch_size = config.patch_size[stage_idx] + self.in_channels = config.in_channels if stage_idx == 0 else config.embed_dim[stage_idx - 1] + self.embed_dim = config.embed_dim[stage_idx] + self.stride = config.patch_stride[stage_idx] + self.padding = config.patch_padding[stage_idx] + self.pre_norm = config.patch_prenorm[stage_idx] + + self.conv = nn.Conv2d( + self.in_channels, + self.embed_dim, + kernel_size=self.patch_size, + stride=self.stride, + padding=self.padding, + ) + + dim_norm = self.in_channels if self.pre_norm else self.embed_dim + self.norm = nn.LayerNorm(dim_norm) + + def forward(self, hidden_states: torch.Tensor): + if self.norm and 
self.pre_norm: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + + hidden_states = self.conv(hidden_states) + + if self.norm and not self.pre_norm: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + return hidden_states + + +class Florence2VisionChannelAttention(nn.Module): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.dim = config.embed_dim[stage_idx] + self.groups = config.num_groups[stage_idx] + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) + self.proj = nn.Linear(self.dim, self.dim) + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor): + batch_size, num_tokens, hidden_size = hidden_states.shape + + # Reshape for grouped channel attention + qkv = self.qkv(hidden_states).reshape(batch_size, num_tokens, 3, self.groups, hidden_size // self.groups) + qkv = qkv.permute(2, 0, 3, 4, 1) + query, key, value = qkv.unbind(0) + + scale = num_tokens**-0.5 + # Channel-to-channel attention within groups: + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + hidden_states, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + scaling=scale, + ) + hidden_states = hidden_states.permute(0, 3, 2, 1) + hidden_states = hidden_states.reshape(batch_size, num_tokens, hidden_size) + + # Final projection + hidden_states = self.proj(hidden_states) + return hidden_states + + +class Florence2VisionChannelBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + drop_path_rate: float, + ): + super().__init__() + + self.config = config + dim_in = config.embed_dim[stage_idx] + + self.conv1 = nn.Conv2d( + dim_in, + dim_in, + kernel_size=3, + padding=1, + groups=dim_in, + ) + self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.channel_attn = Florence2VisionChannelAttention(config=config, stage_idx=stage_idx) + self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + self.conv2 = nn.Conv2d( + dim_in, + dim_in, + kernel_size=3, + padding=1, + groups=dim_in, + ) + self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) + self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, hidden_states: torch.Tensor): + batch_size, embed_dim, height, width = hidden_states.shape + + # First channel block: Depthwise Conv + Channel Attention + hidden_states = self.conv1(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # Channel group attention self-attention mechanism + hidden_states = self.norm1(hidden_states) + hidden_states = self.channel_attn(hidden_states) + hidden_states = residual + self.drop_path1(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + # Second channel block: Depthwise Conv + FFN + hidden_states = self.conv2(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # FFN + hidden_states = self.norm2(hidden_states) + hidden_states = 
self.ffn(hidden_states) + hidden_states = residual + self.drop_path2(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + return hidden_states + + +class Florence2VisionWindowAttention(nn.Module): + def __init__(self, config: Florence2VisionConfig, stage_idx: int): + super().__init__() + self.config = config + self.dim = config.embed_dim[stage_idx] + self.window_size = config.window_size + self.num_heads = config.num_heads[stage_idx] + head_dim = self.dim // self.num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) + self.proj = nn.Linear(self.dim, self.dim) + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor): + batch_size, height, width, embed_dim = hidden_states.shape + + # Pad the input if necessary + pad_left = pad_top = 0 + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + hidden_states = F.pad(hidden_states, (0, 0, pad_left, pad_right, pad_top, pad_bottom)) + _, padded_height, padded_width, _ = hidden_states.shape + + # Partition input into non-overlapping windows (for local spatial attention in DaViT) + hidden_states = hidden_states.view( + batch_size, + padded_height // self.window_size, + self.window_size, + padded_width // self.window_size, + self.window_size, + embed_dim, + ) + windowed_hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() + windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size * self.window_size, embed_dim) + + # Generate Q, K, V for each window + num_windows_per_batch, num_tokens_per_window, embed_dim = windowed_hidden_states.shape + qkv = self.qkv(windowed_hidden_states).reshape( + num_windows_per_batch, num_tokens_per_window, 3, self.num_heads, embed_dim // self.num_heads + ) + qkv = qkv.permute(2, 0, 3, 1, 4) + query, key, value = qkv.unbind(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + windowed_hidden_states, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + scaling=self.scale, + ) + windowed_hidden_states = windowed_hidden_states.view(num_windows_per_batch, num_tokens_per_window, embed_dim) + windowed_hidden_states = self.proj(windowed_hidden_states) + + # Merge windows back to original spatial layout + windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size, self.window_size, embed_dim) + hidden_states = windowed_hidden_states.view( + -1, + padded_height // self.window_size, + padded_width // self.window_size, + self.window_size, + self.window_size, + embed_dim, + ) + hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() + hidden_states = hidden_states.view(-1, padded_height, padded_width, embed_dim) + hidden_states = hidden_states[:, :height, :width, :].contiguous() + hidden_states = hidden_states.view(batch_size, height * width, embed_dim) + + return hidden_states + + +class Florence2VisionSpatialBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + drop_path_rate: float, + ): + super().__init__() + + self.conv1 = nn.Conv2d( + config.embed_dim[stage_idx], + config.embed_dim[stage_idx], + kernel_size=3, + padding=1, + groups=config.embed_dim[stage_idx], + ) + self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) + 
self.window_attn = Florence2VisionWindowAttention(config=config, stage_idx=stage_idx) + self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + self.conv2 = nn.Conv2d( + config.embed_dim[stage_idx], + config.embed_dim[stage_idx], + kernel_size=3, + padding=1, + groups=config.embed_dim[stage_idx], + ) + self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) + self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) + self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, hidden_states: torch.Tensor): + batch_size, embed_dim, height, width = hidden_states.shape + + # First spatial mixing block: Conv + Window Attention + hidden_states = self.conv1(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # Spatial Window-based self-attention mechanism + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states.view(batch_size, height, width, embed_dim) + hidden_states = self.window_attn(hidden_states) + hidden_states = residual + self.drop_path1(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + # Second spatial mixing block: Conv + FFN + hidden_states = self.conv2(hidden_states) + hidden_states + hidden_states = hidden_states.flatten(2).transpose(1, 2) + residual = hidden_states + + # FFN + hidden_states = self.norm2(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + self.drop_path2(hidden_states) + hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) + + return hidden_states + + +class Florence2VisionBlock(nn.Module): + def __init__( + self, + config: Florence2VisionConfig, + stage_idx: int, + spatial_drop_path_rate: float, + channel_drop_path_rate: float, + ): + super().__init__() + self.spatial_block = Florence2VisionSpatialBlock( + config=config, + stage_idx=stage_idx, + drop_path_rate=spatial_drop_path_rate, + ) + self.channel_block = Florence2VisionChannelBlock( + config=config, + stage_idx=stage_idx, + drop_path_rate=channel_drop_path_rate, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.spatial_block(hidden_states) + hidden_states = self.channel_block(hidden_states) + return hidden_states + + +@auto_docstring +class Florence2VisionPreTrainedModel(PreTrainedModel): + config_class = Florence2VisionConfig + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + +@auto_docstring +class Florence2VisionBackbone(Florence2VisionPreTrainedModel): + def __init__(self, config: Florence2VisionConfig): + super().__init__(config) + self.config = config + + self.embed_dim = config.embed_dim + self.num_heads = config.num_heads + self.num_groups = config.num_groups + self.num_stages = len(self.embed_dim) + + if not (self.num_stages == len(self.num_heads) == len(self.num_groups)): + raise ValueError( + f"Expected self.num_stages ({self.num_stages}) == " + f"len(self.num_heads) ({len(self.num_heads)}) == " + f"len(self.num_groups) ({len(self.num_groups)})" + ) + + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths) * 2, device="cpu")] + depth_offset = 0 + + convs = [] + blocks = [] + for stage_idx in range(self.num_stages): + conv_embed = Florence2VisionConvEmbed( + config=config, + stage_idx=stage_idx, + ) + 
+            convs.append(conv_embed)
+
+            block = nn.ModuleList(
+                Florence2VisionBlock(
+                    config=config,
+                    stage_idx=stage_idx,
+                    spatial_drop_path_rate=dpr[depth_offset + block_idx * 2],
+                    channel_drop_path_rate=dpr[depth_offset + block_idx * 2 + 1],
+                )
+                for block_idx in range(config.depths[stage_idx])
+            )
+            blocks.append(block)
+            depth_offset += config.depths[stage_idx] * 2
+
+        self.convs = nn.ModuleList(convs)
+        self.blocks = nn.ModuleList(blocks)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(self, hidden_states: torch.Tensor):
+        for conv, block in zip(self.convs, self.blocks):
+            hidden_states = conv(hidden_states)
+            for layer in block:
+                hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class Florence2MultiModalProjector(nn.Module):
+    def __init__(self, config: Florence2Config):
+        super().__init__()
+        self.vision_embedding_dim = config.vision_config.embed_dim[-1]
+        self.vision_projection_dim = config.vision_config.projection_dim
+        self.image_projection = nn.Linear(self.vision_embedding_dim, self.vision_projection_dim, bias=False)
+        self.image_proj_norm = nn.LayerNorm(self.vision_projection_dim)
+        self.image_position_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(config=config)
+        self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(config=config)
+
+    def forward(self, image_features):
+        position_features = image_features + self.image_position_embed(image_features)
+        position_features = position_features.flatten(2).transpose(1, 2)
+        temporal_features = self.visual_temporal_embed(position_features[:, :1, :])
+        temporal_features = temporal_features.unsqueeze(1)
+        visual_token_features = position_features + temporal_features
+        visual_token_features = visual_token_features.unsqueeze(1)
+        spatial_image_features = visual_token_features.mean(dim=2)
+        temporal_image_features = visual_token_features.mean(dim=1)
+        image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1)
+        image_features = self.image_projection(image_features)
+        image_features = self.image_proj_norm(image_features)
+        return image_features
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Florence-2 base model's outputs that also contains pre-computed hidden states that can speed up
+    sequential decoding.
+    """
+)
+class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput):
+    r"""
+    image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    """
+
+    image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Florence-2 model's outputs that also contains pre-computed hidden states that can speed up
+    sequential decoding.
+    """
+)
+class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`.
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """ + + image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@auto_docstring +class Florence2PreTrainedModel(LlavaPreTrainedModel): + config_class = Florence2Config + + _supports_attention_backend = False + + +@auto_docstring( + custom_intro=""" + Florence-2 is a vision model for captioning, detection, and segmentation. + """ +) +class Florence2Model(LlavaModel): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = [ + "language_model.encoder.embed_tokens.weight", + "language_model.decoder.embed_tokens.weight", + ] + + def __init__(self, config: Florence2Config): + super().__init__(config) + self.vision_tower = Florence2VisionBackbone(config=config.vision_config) + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_image_features(self, pixel_values: torch.Tensor, **kwargs): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_features = self.vision_tower(pixel_values, **kwargs) + image_embeds = self.multi_modal_projector(image_features) + return image_embeds + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Florence2Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + encoder_outputs = self.language_model.encoder( + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=True, + ) + + if decoder_input_ids is None: + decoder_start_token_id = self.config.text_config.decoder_start_token_id + decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) + decoder_input_ids *= decoder_start_token_id + + decoder_outputs = self.language_model.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=True, + ) + + return Florence2Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + Florence-2 is a vision model for captioning, detection, and segmentation. + """ +) +class Florence2ForConditionalGeneration(LlavaForConditionalGeneration): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = [ + "model.language_model.encoder.embed_tokens.weight", + "model.language_model.decoder.embed_tokens.weight", + "lm_head.weight", + ] + + def get_encoder(self): + return self.model.get_encoder() + + def get_image_features(self, pixel_values: torch.Tensor, **kwargs): + return self.model.get_image_features(pixel_values=pixel_values, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Florence2Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+        >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
+        >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
+
+        >>> prompt = "<CAPTION>"
+        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=100)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "A green car parked in front of a yellow building."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.text_config.pad_token_id, self.config.text_config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cache_position=cache_position,
+            # **kwargs,  ## TODO: add back when Bart attention is refactored and takes kwargs
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+            )
+
+        return Florence2Seq2SeqLMOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        return self.model.get_placeholder_mask(
+            input_ids=input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+        )
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self,
+        inputs_tensor: torch.Tensor,
model_kwargs, + model_input_name: Optional[str], + generation_config, + ) -> dict[str, Any]: + # override to handle merging image and text embeddings before passing to language encoder + inputs_embeds = model_kwargs.pop("inputs_embeds", None) + pixel_values = model_kwargs.pop("pixel_values", None) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(inputs_tensor) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + inputs_tensor, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + model_kwargs["inputs_embeds"] = inputs_embeds + model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation( + None, model_kwargs, model_input_name, generation_config + ) + model_kwargs.pop("inputs_embeds", None) + return model_kwargs + + +__all__ = [ + "Florence2Config", + "Florence2Processor", + "Florence2VisionConfig", + "Florence2Model", + "Florence2ForConditionalGeneration", + "Florence2PreTrainedModel", + "Florence2VisionBackbone", + "Florence2VisionPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/florence2/processing_florence2.py b/venv/lib/python3.13/site-packages/transformers/models/florence2/processing_florence2.py new file mode 100644 index 0000000000000000000000000000000000000000..91b63e9da7db8495ff9172f67a8fbd7597156a20 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/florence2/processing_florence2.py @@ -0,0 +1,802 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/florence2/modular_florence2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_florence2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
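The generation override above merges image and text embeddings before the text encoder ever runs: placeholder image tokens in the prompt are replaced, position for position, by the projected vision features via `masked_scatter`, and the merged embeddings are handed to the encoder in place of `input_ids`. A minimal sketch of that merge step with toy tensors (the placeholder token id and all shapes below are illustrative assumptions, not the model's real values):

```python
import torch

# Toy setup: a 6-token prompt whose first two positions are image placeholders.
image_token_id = 5                                # hypothetical placeholder id
input_ids = torch.tensor([[5, 5, 11, 12, 13, 14]])
inputs_embeds = torch.randn(1, 6, 4)              # (batch, seq_len, hidden)
image_features = torch.randn(2, 4)                # one projected feature per placeholder

# Mark the placeholder positions and broadcast the mask over the hidden dim.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter fills the masked positions, in order, from the flattened source.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
```

This is also why the override pops `pixel_values` from `model_kwargs` and passes `None` as the inputs tensor to the parent implementation: once the features are scattered in, only `inputs_embeds` is needed downstream.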
+import re
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+class Florence2ProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
+        "images_kwargs": {},
+    }
+
+
+class Florence2Processor(ProcessorMixin):
+    r"""
+    Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
+
+    [`Florence2Processor`] offers all the functionalities of [`AutoImageProcessor`] and [`BartTokenizerFast`]. See the
+    [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
+
+    Args:
+        image_processor (`AutoImageProcessor`, *optional*):
+            The image processor is a required input.
+        tokenizer (`Union[BartTokenizer, BartTokenizerFast]`, *optional*):
+            The tokenizer is a required input.
+        num_additional_image_tokens (`int`, *optional*, defaults to 0):
+            Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
+            extra tokens appended, no need to set this arg.
+        post_processor_config (`dict`, *optional*, defaults to `None`):
+            Task-specific parsing rules for [`Florence2PostProcessor`], e.g. regex patterns,
+            thresholds, or banned tokens.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        num_additional_image_tokens: int = 0,
+        post_processor_config: Optional[dict] = None,
+        **kwargs,
+    ):
+        self.tasks_answer_post_processing_type = {
+            "<OCR>": "pure_text",
+            "<OCR_WITH_REGION>": "ocr",
+            "<CAPTION>": "pure_text",
+            "<DETAILED_CAPTION>": "pure_text",
+            "<MORE_DETAILED_CAPTION>": "pure_text",
+            "<OD>": "description_with_bboxes",
+            "<DENSE_REGION_CAPTION>": "description_with_bboxes",
+            "<CAPTION_TO_PHRASE_GROUNDING>": "phrase_grounding",
+            "<REFERRING_EXPRESSION_SEGMENTATION>": "polygons",
+            "<REGION_TO_SEGMENTATION>": "polygons",
+            "<OPEN_VOCABULARY_DETECTION>": "description_with_bboxes_or_polygons",
+            "<REGION_TO_CATEGORY>": "pure_text",
+            "<REGION_TO_DESCRIPTION>": "pure_text",
+            "<REGION_TO_OCR>": "pure_text",
+            "<REGION_PROPOSAL>": "bboxes",
+        }
+
+        self.task_prompts_without_inputs = {
+            "<OCR>": "What is the text in the image?",
+            "<OCR_WITH_REGION>": "What is the text in the image, with regions?",
+            "<CAPTION>": "What does the image describe?",
+            "<DETAILED_CAPTION>": "Describe in detail what is shown in the image.",
+            "<MORE_DETAILED_CAPTION>": "Describe with a paragraph what is shown in the image.",
+            "<OD>": "Locate the objects with category name in the image.",
+            "<DENSE_REGION_CAPTION>": "Locate the objects in the image, with their descriptions.",
+            "<REGION_PROPOSAL>": "Locate the region proposals in the image.",
+        }
+
+        self.task_prompts_with_input = {
+            "<CAPTION_TO_PHRASE_GROUNDING>": "Locate the phrases in the caption: {input}",
+            "<REFERRING_EXPRESSION_SEGMENTATION>": "Locate {input} in the image with mask",
+            "<REGION_TO_SEGMENTATION>": "What is the polygon mask of region {input}",
+            "<OPEN_VOCABULARY_DETECTION>": "Locate {input} in the image.",
+            "<REGION_TO_CATEGORY>": "What is the region {input}?",
+            "<REGION_TO_DESCRIPTION>": "What does the region {input} describe?",
+            "<REGION_TO_OCR>": "What text is in the region {input}?",
+        }
+
+        self.num_image_tokens = image_processor.image_seq_length
+        self.num_additional_image_tokens = num_additional_image_tokens
+        self.post_processor_config = post_processor_config
+        self.post_processor = Florence2PostProcessor(config=post_processor_config, tokenizer=tokenizer)
+        self.image_token = tokenizer.image_token
+        self.image_token_id = tokenizer.image_token_id
+
+        super().__init__(image_processor, tokenizer, **kwargs)
+
+    def _construct_prompts(self, text: Union[str, list[str]]) -> list[str]:
+        """
+        Construct prompts by replacing task tokens with corresponding prompt strings.
+        """
+        if isinstance(text, str):
+            text = [text]
+
+        prompts = []
+        for prompt in text:
+            # Check for tasks without inputs
+            for task_token, task_prompt in self.task_prompts_without_inputs.items():
+                if task_token in prompt:
+                    if prompt != task_token:
+                        raise ValueError(f"Task token {task_token} should be the only content in the prompt.")
+                    prompt = task_prompt
+                    break
+            # Check for tasks with inputs
+            for task_token, task_prompt in self.task_prompts_with_input.items():
+                if task_token in prompt:
+                    input_text = prompt.replace(task_token, "").strip()
+                    prompt = task_prompt.format(input=input_text)
+                    break
+            prompts.append(prompt)
+        return prompts
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+        **kwargs: Unpack[Florence2ProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + output_kwargs = self._merge_kwargs( + Florence2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = {} + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text is None: + logger.warning_once("You are using Florence-2 without a text prefix.") + text = [""] * (1 if not isinstance(images, list) else len(images)) + elif isinstance(text, str): + text = [text] + + if not isinstance(text, list) or not all(isinstance(token, str) for token in text): + raise ValueError("`text` must be a string or list of strings.") + + if isinstance(images, list) and len(images) != len(text): + raise ValueError(f"Number of images ({len(images)}) must match number of texts ({len(text)}).") + + prompt_strings = self._construct_prompts(text) + + # Add image tokens and special tokens if images are provided + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + expanded_image_prompts = [] + for sample in prompt_strings: + sample = ( + self.image_token * self.num_image_tokens + + self.tokenizer.bos_token + + sample + + self.tokenizer.eos_token + ) + expanded_image_prompts.append(sample) + prompt_strings = expanded_image_prompts + + # Construct and tokenize prompts + output_kwargs["text_kwargs"].pop("add_special_tokens", None) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer( + prompt_strings, **output_kwargs["text_kwargs"], add_special_tokens=False, return_tensors=None + ) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**image_inputs, **text_inputs}, tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=False, **kwargs): + """ + Post-processes the output of `FuyuForConditionalGeneration` to only return the text output. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + containing the token ids of the generated sequences. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text output. + """ + return self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + + def post_process_generation(self, text=None, sequence=None, task=None, image_size=None) -> dict[str, Any]: + """ + Post-process generation outputs based on the task. + + Args: + text (`str`, *optional*): + Generated text. + sequence (`Union[List[int], torch.Tensor]`, *optional*): + Generated token sequence. + task (`str`, *optional*): + The task for post-processing. + image_size (`Tuple[int, int]`, *optional*): + Image size for dequantization. + + Returns: + `Dict[str, Any]`: Post-processed results keyed by task. + """ + if task is None: + raise ValueError("`task` must be provided for post-processing.") + + post_proc_type = self.tasks_answer_post_processing_type.get(task, "pure_text") + parsed = self.post_processor( + text=text, + sequence=sequence, + image_size=image_size, + parse_tasks=[post_proc_type], + )[post_proc_type] + + if post_proc_type == "pure_text": + final_answer = parsed.replace("", "").replace("", "").strip() + elif post_proc_type in ["description_with_bboxes", "bboxes"]: + bboxes = [inst["bbox"] for inst in parsed] + labels = [inst["cat_name"] for inst in parsed] + final_answer = {"bboxes": bboxes, "labels": labels} + if parsed and "score" in parsed[0]: + final_answer["scores"] = [inst["score"] for inst in parsed] + elif post_proc_type == "ocr": + quad_boxes = [inst["quad_box"] for inst in parsed] + labels = [inst["text"] for inst in parsed] + final_answer = {"quad_boxes": quad_boxes, "labels": labels} + elif post_proc_type == "phrase_grounding": + bboxes = [] + labels = [] + for inst in parsed: + for bbox in inst["bbox"]: + bboxes.append(bbox) + labels.append(inst["cat_name"]) + final_answer = {"bboxes": bboxes, "labels": labels} + elif post_proc_type in ["description_with_polygons", "polygons"]: + polygons = [inst["polygons"] for inst in parsed] + labels = [inst["cat_name"] for inst in parsed] + final_answer = {"polygons": polygons, "labels": labels} + elif post_proc_type == "description_with_bboxes_or_polygons": + bboxes = [] + bboxes_labels = [] + polygons = [] + polygons_labels = [] + for inst in parsed: + label = inst["cat_name"] + if "polygons" in inst: + polygons.append(inst["polygons"]) + polygons_labels.append(label) + else: + bboxes.append(inst["bbox"]) + bboxes_labels.append(label) + final_answer = { + "bboxes": bboxes, + "bboxes_labels": bboxes_labels, + "polygons": polygons, + "polygons_labels": polygons_labels, + 
} + else: + raise ValueError(f"Unknown post-processing type: {post_proc_type}") + + return {task: final_answer} + + +class Florence2PostProcessor: + """ + Post-processor for Florence-2 model outputs. Parses generated text into structured results for various tasks + like object detection, OCR, phrase grounding, etc. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used for decoding model outputs. + """ + + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.parse_task_config = config or {} + self.banned_grounding_tokens = set( + self.parse_task_config.get("phrase_grounding", {}).get("banned_grounding_tokens", []) + ) + self.all_special_tokens = set(self.tokenizer.all_special_tokens) + self.quantize_bins = (1000, 1000) + + def quantize(self, locations: "torch.Tensor", size: tuple[int, int]) -> "torch.Tensor": + """ + Quantize locations. + + Args: + locations (`torch.Tensor`): + Tensor of shape (N, 4) for boxes or (N, 2) for points/coordinates. + size (`tuple[int, int]`): + Original image size (width, height). + + Returns: + `torch.Tensor`: Quantized locations as integers. + """ + bins_w, bins_h = self.quantize_bins + size_w, size_h = size + per_bin_w = size_w / bins_w + per_bin_h = size_h / bins_h + + if locations.shape[-1] == 4: # Bounding boxes: [xmin, ymin, xmax, ymax] + xmin, ymin, xmax, ymax = locations.split(1, dim=-1) + q_xmin = (xmin / per_bin_w).floor().clamp(0, bins_w - 1) + q_ymin = (ymin / per_bin_h).floor().clamp(0, bins_h - 1) + q_xmax = (xmax / per_bin_w).floor().clamp(0, bins_w - 1) + q_ymax = (ymax / per_bin_h).floor().clamp(0, bins_h - 1) + return torch.cat([q_xmin, q_ymin, q_xmax, q_ymax], dim=-1).int() + + elif locations.shape[-1] == 2: # Points/coordinates: [x, y] + x, y = locations.split(1, dim=-1) + q_x = (x / per_bin_w).floor().clamp(0, bins_w - 1) + q_y = (y / per_bin_h).floor().clamp(0, bins_h - 1) + return torch.cat([q_x, q_y], dim=-1).int() + + else: + raise ValueError(f"Unsupported location shape: last dim must be 2 or 4, got {locations.shape[-1]}.") + + def dequantize(self, locations: "torch.Tensor", size: tuple[int, int]) -> "torch.Tensor": + """ + Dequantize locations back to original scale. + + Args: + locations (`torch.Tensor`): + Quantized tensor of shape (N, 4) for boxes or (N, 2) for points/coordinates. + size (`tuple[int, int]`): + Original image size (width, height). + + Returns: + `torch.Tensor`: Dequantized locations as floats. + """ + bins_w, bins_h = self.quantize_bins + size_w, size_h = size + per_bin_w = size_w / bins_w + per_bin_h = size_h / bins_h + + # Add 0.5 to use the center position of the bin as the coordinate. + if locations.shape[-1] == 4: # Bounding boxes + xmin, ymin, xmax, ymax = locations.split(1, dim=-1) + dq_xmin = (xmin + 0.5) * per_bin_w + dq_ymin = (ymin + 0.5) * per_bin_h + dq_xmax = (xmax + 0.5) * per_bin_w + dq_ymax = (ymax + 0.5) * per_bin_h + return torch.cat([dq_xmin, dq_ymin, dq_xmax, dq_ymax], dim=-1).int() + + elif locations.shape[-1] == 2: # Points/coordinates + x, y = locations.split(1, dim=-1) + dq_x = (x + 0.5) * per_bin_w + dq_y = (y + 0.5) * per_bin_h + return torch.cat([dq_x, dq_y], dim=-1).int() + + else: + raise ValueError(f"Unsupported location shape: last dim must be 2 or 4, got {locations.shape[-1]}.") + + def decode_with_spans(self, token_ids: list[int]) -> tuple[str, list[tuple[int, int]]]: + """ + Decode token IDs to text and compute character spans. + + Args: + token_ids (`list[int]`): + list of token IDs to decode. 
+ + Returns: + `tuple[str, list[tuple[int, int]]]`: Decoded text and list of spans (start, end) for each token. + """ + filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False) + text = "" + spans = [] + for token in filtered_tokens: + if token in self.all_special_tokens: + sub_text = token + else: + sub_text = self.tokenizer.convert_tokens_to_string([token]) + span = (len(text), len(text) + len(sub_text)) + text += sub_text + spans.append(span) + return text, spans + + def parse_ocr_from_text_and_spans( + self, text: str, pattern: Optional[str], image_size: tuple[int, int], area_threshold: float = 0.0 + ) -> list[dict[str, Any]]: + """ + Parse OCR results with quadrilateral boxes. + + Args: + text (`str`): + The generated text. + pattern (`str`): + Regex pattern for matching. + image_size (`tuple[int, int]`): + Image size (width, height). + area_threshold (`float`, *optional*, defaults to 0.0): + Minimum area threshold for filtering boxes. + + Returns: + `list[dict[str, Any]]`: list of instances with 'quad_box' and 'text'. + """ + text = text.replace("", "").replace("", "").replace("", "") + if pattern is None: + pattern = r"(.+?)" + + matches = re.findall(pattern, text) + instances = [] + width, height = image_size + + for content, *quad_str in matches: + quad_bins = [int(i) for i in quad_str] + quad_box = self.dequantize(torch.tensor(quad_bins).reshape(-1, 2), size=image_size).flatten().tolist() + + if area_threshold > 0: + x_coords = quad_box[0::2] + y_coords = quad_box[1::2] + # Apply the Shoelace formula + area = 0.5 * abs( + sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)) + ) + + if area < (width * height) * area_threshold: + continue + + instances.append({"quad_box": quad_box, "text": content.strip()}) + return instances + + def parse_phrase_grounding_from_text_and_spans( + self, text: str, image_size: tuple[int, int] + ) -> list[dict[str, Any]]: + """ + Parse phrase grounding results. + + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + + Returns: + `list[dict[str, Any]]`: list of instances with 'bbox' and 'cat_name'. + """ + text = text.replace("", "").replace("", "").replace("", "") + phrase_pattern = r"([^<]+(?:){4,})" + phrases = re.findall(phrase_pattern, text) + text_pattern = r"^\s*(.*?)(?=||||||" + + instances = [] + for phrase_text in phrases: + phrase_text = phrase_text.replace("", "", 1).replace("", "", 1) + if not phrase_text: + continue + match = re.search(text_pattern, phrase_text) + if not match: + continue + phrase = match.group().strip() + if phrase in self.banned_grounding_tokens: + continue + boxes_matches = list(re.finditer(box_pattern, phrase_text)) + if not boxes_matches: + continue + bbox_bins = [[int(m.group(j)) for j in range(1, 5)] for m in boxes_matches] + bboxes = self.dequantize(torch.tensor(bbox_bins), size=image_size).tolist() + phrase = phrase.encode("ascii", "ignore").decode("ascii") + instances.append({"bbox": bboxes, "cat_name": phrase}) + return instances + + def _find_matched_token_indices(self, cur_span: tuple[int, int], token_spans: list[tuple[int, int]]) -> list[int]: + return [i for i, span in enumerate(token_spans) if not (span[1] <= cur_span[0] or span[0] >= cur_span[1])] + + def parse_description_with_bboxes_from_text_and_spans( + self, + text: str, + image_size: tuple[int, int], + allow_empty_phrase: bool = False, + ) -> list[dict[str, Any]]: + """ + Parse descriptions with bounding boxes. 
+ + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + allow_empty_phrase (`bool`, *optional*, defaults to `False`): + Allow phrases without text. + + Returns: + `list[dict[str, Any]]`: list of instances with 'bbox', 'cat_name', and optional 'score'. + """ + text = text.replace("", "").replace("", "").replace("", "") + + if allow_empty_phrase: + pattern = r"(?:(?:){4,})" + else: + pattern = r"([^<]+(?:){4,})" + phrases = re.findall(pattern, text) + + text_pattern = r"^\s*(.*?)(?=||||||" + + instances = [] + for phrase_text in phrases: + phrase_text = phrase_text.replace("", "", 1).replace("", "", 1) + if not phrase_text and not allow_empty_phrase: + continue + match = re.search(text_pattern, phrase_text) + if not match: + continue + phrase = match.group().strip() + boxes_matches = list(re.finditer(box_pattern, phrase_text)) + if not boxes_matches: + continue + bbox_bins = [[int(m.group(j)) for j in range(1, 5)] for m in boxes_matches] + bboxes = self.dequantize(torch.tensor(bbox_bins), size=image_size).tolist() + + phrase = phrase.encode("ascii", "ignore").decode("ascii") + for bbox in bboxes: + instance = {"bbox": bbox, "cat_name": phrase} + instances.append(instance) + + return instances + + def parse_description_with_polygons_from_text_and_spans( + self, + text: str, + image_size: tuple[int, int], + allow_empty_phrase: bool = False, + polygon_sep_token: str = "", + polygon_start_token: str = "", + polygon_end_token: str = "", + with_box_at_start: bool = False, + ) -> list[dict[str, Any]]: + """ + Parse descriptions with polygons. + + Args: + text (`str`): + The generated text. + image_size (`tuple[int, int]`): + Image size (width, height). + allow_empty_phrase (`bool`, *optional*, defaults to `False`): + Allow phrases without text. + polygon_sep_token (`str`, *optional*, defaults to ""): + Token separating polygons. + polygon_start_token (`str`, *optional*, defaults to ""): + Start token for polygons. + polygon_end_token (`str`, *optional*, defaults to ""): + End token for polygons. + with_box_at_start (`bool`, *optional*, defaults to `False`): + Whether a bounding box is at the start of polygons. + + Returns: + `list[dict[str, Any]]`: list of instances with 'polygons', 'cat_name', and optional 'bbox'. 
+ """ + text = text.replace("", "").replace("", "").replace("", "") + + if allow_empty_phrase: + pattern = rf"(?:(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + else: + pattern = rf"([^<]+(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})" + phrases = re.findall(pattern, text) + phrase_pattern = r"^\s*(.*?)(?=||||||)" + poly_instance_pattern = rf"{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}" + box_pattern = rf"((?:)+)(?:{re.escape(polygon_sep_token)}|$)" + + instances = [] + for phrase_text in phrases: + phrase_text_strip = re.sub(r"^", "", phrase_text, count=1) + if not phrase_text_strip and not allow_empty_phrase: + continue + match = re.search(phrase_pattern, phrase_text_strip) + if not match: + continue + phrase = match.group().strip() + + if polygon_start_token in phrase_text and polygon_end_token in phrase_text: + poly_instances = [m.group(1) for m in re.finditer(poly_instance_pattern, phrase_text)] + else: + poly_instances = [phrase_text] + + for poly_inst in poly_instances: + poly_matches = list(re.finditer(box_pattern, poly_inst)) + if len(poly_matches) == 0: + continue + bbox = [] + polygons = [] + for poly_match in poly_matches: + poly_str = poly_match.group(1) + poly_bins = [int(m.group(1)) for m in re.finditer(r"", poly_str)] + if with_box_at_start and not bbox: + if len(poly_bins) > 4: + bbox = poly_bins[:4] + poly_bins = poly_bins[4:] + else: + bbox = [0, 0, 0, 0] + if len(poly_bins) % 2 == 1: + poly_bins = poly_bins[:-1] + poly_coords = ( + self.dequantize(torch.tensor(poly_bins).reshape(-1, 2), size=image_size).flatten().tolist() + ) + polygons.append(poly_coords) + + instance = {"cat_name": phrase, "polygons": polygons} + if bbox: + instance["bbox"] = self.dequantize(torch.tensor([bbox]), size=image_size)[0].tolist() + instances.append(instance) + return instances + + def __call__(self, text=None, sequence=None, image_size=None, parse_tasks=None) -> dict[str, Any]: + """ + Process model output and parse into task-specific results. + + Args: + text (`Optional[str]`, *optional*): + Generated text. Either this or `sequence` must be provided. + sequence (`Optional[Union[list[int], torch.Tensor]]`, *optional*): + Token sequence. Either this or `text` must be provided. + image_size (`Optional[tuple[int, int]]`, *optional*): + Image size (width, height) required for dequantization. + parse_tasks (`Optional[Union[str, list[str]]]`, *optional*): + Specific tasks to parse. If None, parse all supported tasks. + + Returns: + `dict[str, Any]`: Parsed results for each task, including the raw 'text'. 
+ """ + if parse_tasks is not None: + parse_tasks = [parse_tasks] if isinstance(parse_tasks, str) else parse_tasks + for task in parse_tasks: + if task not in self.parse_task_config.keys(): + raise ValueError(f"Unsupported parse task: {task}") + + if (text is None and sequence is None) or (text is not None and sequence is not None): + raise ValueError("Exactly one of 'text' or 'sequence' must be provided.") + + if sequence is not None: + if isinstance(sequence, torch.Tensor): + sequence = sequence.tolist() + sequence = sequence[1:] if sequence[0] == self.tokenizer.bos_token_id else sequence # Skip BOS if present + text, _ = self.decode_with_spans(sequence) + + parsed_dict = {"text": text} + + tasks_to_parse = parse_tasks or self.parse_task_config.keys() + for task in tasks_to_parse: + config = self.parse_task_config[task] + pattern = config.get("PATTERN") + + if task == "ocr": + parsed_dict["ocr"] = self.parse_ocr_from_text_and_spans( + text, pattern=pattern, image_size=image_size, area_threshold=config.get("AREA_THRESHOLD", 0.0) + ) + elif task == "phrase_grounding": + parsed_dict["phrase_grounding"] = self.parse_phrase_grounding_from_text_and_spans( + text, image_size=image_size + ) + elif task == "pure_text": + parsed_dict["pure_text"] = text + elif task == "description_with_bboxes": + parsed_dict["description_with_bboxes"] = self.parse_description_with_bboxes_from_text_and_spans( + text, image_size=image_size + ) + elif task == "description_with_polygons": + parsed_dict["description_with_polygons"] = self.parse_description_with_polygons_from_text_and_spans( + text, image_size=image_size + ) + elif task == "polygons": + parsed_dict["polygons"] = self.parse_description_with_polygons_from_text_and_spans( + text, image_size=image_size, allow_empty_phrase=True + ) + elif task == "bboxes": + parsed_dict["bboxes"] = self.parse_description_with_bboxes_from_text_and_spans( + text, image_size=image_size, allow_empty_phrase=True + ) + elif task == "description_with_bboxes_or_polygons": + if "" in text: + instances = self.parse_description_with_polygons_from_text_and_spans(text, image_size=image_size) + else: + instances = self.parse_description_with_bboxes_from_text_and_spans(text, image_size=image_size) + parsed_dict["description_with_bboxes_or_polygons"] = instances + else: + raise ValueError(f"task {task} is not supported") + + return parsed_dict + + +__all__ = ["Florence2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7fcd5dc9cce65ccc83944e483c5b75a3bbe16cd3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_glm4_moe import * + from .modeling_glm4_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/configuration_glm4_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/configuration_glm4_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..b9937a8ba4943804573e88fe85bfe7414ce11cbc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/configuration_glm4_moe.py @@ -0,0 +1,242 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Glm4MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a + Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151552): + Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Glm4MoeModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 10944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 46): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 96): + Number of attention heads for each attention layer in the Transformer encoder. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + The factor of the partial rotary position. 
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        moe_intermediate_size (`int`, *optional*, defaults to 1408):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of experts per token.
+        n_shared_experts (`int`, *optional*, defaults to 1):
+            Number of shared experts.
+        n_routed_experts (`int`, *optional*, defaults to 128):
+            Number of routed experts.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts.
+        n_group (`int`, *optional*, defaults to 1):
+            Number of groups for routed experts.
+        topk_group (`int`, *optional*, defaults to 1):
+            Number of selected groups for each token (the selected experts for a token are drawn only from
+            `topk_group` groups).
+        first_k_dense_replace (`int`, *optional*, defaults to 1):
+            Number of dense (non-MoE) layers at the shallow end of the model: the first `first_k_dense_replace`
+            decoder layers use a dense MLP, and all subsequent layers use MoE
+            (embed -> k dense layers -> MoE layers -> lm_head).
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+ use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether to use query-key normalization in the attention + ```python + >>> from transformers import Glm4MoeModel, Glm4MoeConfig + + >>> # Initializing a Glm4Moe style configuration + >>> configuration = Glm4MoeConfig() + + >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration + >>> model = Glm4MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Glm4Moe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151552, + hidden_size=4096, + intermediate_size=10944, + num_hidden_layers=46, + num_attention_heads=96, + partial_rotary_factor=0.5, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + moe_intermediate_size=1408, + num_experts_per_tok=8, + n_shared_experts=1, + n_routed_experts=128, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + norm_topk_prob=True, + use_qk_norm=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.partial_rotary_factor = partial_rotary_factor + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
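+ # For example (illustrative values only), a legacy `rope_scaling={"type": "yarn", "factor": 4.0}` + # gains a copied key and becomes `{"type": "yarn", "rope_type": "yarn", "factor": 4.0}` before + # `rope_config_validation(self)` checks it below.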
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.use_qk_norm = use_qk_norm + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Glm4MoeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modeling_glm4_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modeling_glm4_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..77ab940e9ffb015b51de7254072e2f7bdff13aad --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -0,0 +1,616 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_glm4_moe import Glm4MoeConfig + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
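+ + Example (an illustrative shape check added for clarity; `cos`/`sin` here are dummies — in the model they + come from `Glm4MoeRotaryEmbedding` — with a head_dim of 64 and a partial rotary dim of 32): + + ```python + >>> import torch + >>> q = torch.randn(1, 8, 16, 64)   # (batch, num_attention_heads, seq_len, head_dim) + >>> k = torch.randn(1, 2, 16, 64)   # (batch, num_key_value_heads, seq_len, head_dim) + >>> cos = torch.randn(1, 16, 32)    # (batch, seq_len, rotary_dim) + >>> sin = torch.randn(1, 16, 32) + >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin) + >>> q_embed.shape, k_embed.shape + (torch.Size([1, 8, 16, 64]), torch.Size([1, 2, 16, 64])) + ```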
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class Glm4MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.rope_scaling = config.rope_scaling + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + 
query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Glm4MoeMLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Glm4MoeTopkRouter(nn.Module): + def __init__(self, config: Glm4MoeConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Glm4MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Glm4MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + 
self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Glm4MoeMoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = nn.ModuleList( + [ + Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = Glm4MoeTopkRouter(config) + self.shared_experts = Glm4MoeMLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class Glm4MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Glm4MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Glm4MoeMoE(config) + else: + self.mlp = Glm4MoeMLP(config) + + self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, 
_ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Glm4MoePreTrainedModel(PreTrainedModel): + config: Glm4MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Glm4MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Glm4MoeDecoderLayer, + "attentions": Glm4MoeAttention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Glm4MoeTopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +class Glm4MoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Glm4MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Glm4MoeModel(Glm4MoePreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"] + + def __init__(self, config: Glm4MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Glm4MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = 
{"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Glm4MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM + + >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modular_glm4_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modular_glm4_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..20144c8ffc401fc9e4dd25c5736cab463f7479f1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4_moe/modular_glm4_moe.py @@ -0,0 +1,329 @@ +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GLM-4-MOE model.""" + +from typing import Optional + +import torch +from torch import nn + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging +from ..cohere.modeling_cohere import CohereAttention +from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3DecoderLayer, + DeepseekV3ForCausalLM, + DeepseekV3MLP, + DeepseekV3Model, + DeepseekV3PreTrainedModel, + DeepseekV3RMSNorm, + DeepseekV3TopkRouter, +) +from ..gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb # noqa + + +logger = logging.get_logger(__name__) + + +class Glm4MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a + Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151552): + Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Glm4MoeModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 10944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 46): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 96): + Number of attention heads for each attention layer in the Transformer encoder. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + The factor of the partial rotary position. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to the value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (> + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of each routed expert. + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of experts selected per token. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts. + n_routed_experts (`int`, *optional*, defaults to 128): + Number of routed experts. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts. + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. 
+ topk_group (`int`, *optional*, defaults to 1): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + first_k_dense_replace (`int`, *optional*, defaults to 1): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether to use query-key normalization in the attention + ```python + >>> from transformers import Glm4MoeModel, Glm4MoeConfig + + >>> # Initializing a Glm4Moe style configuration + >>> configuration = Glm4MoeConfig() + + >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration + >>> model = Glm4MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Glm4Moe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151552, + hidden_size=4096, + intermediate_size=10944, + num_hidden_layers=46, + num_attention_heads=96, + partial_rotary_factor=0.5, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + moe_intermediate_size=1408, + num_experts_per_tok=8, + n_shared_experts=1, + n_routed_experts=128, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + norm_topk_prob=True, + use_qk_norm=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.partial_rotary_factor = partial_rotary_factor + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
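+ # For example (illustrative values only), a legacy `rope_scaling={"type": "yarn", "factor": 4.0}` + # gains a copied key and becomes `{"type": "yarn", "rope_type": "yarn", "factor": 4.0}` before + # `rope_config_validation(self)` checks it below.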
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.use_qk_norm = use_qk_norm + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Glm4MoeAttention(CohereAttention): + def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.rope_scaling = config.rope_scaling + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + +class Glm4MoeMLP(DeepseekV3MLP): + pass + + +class Glm4MoeTopkRouter(DeepseekV3TopkRouter): + def __init__(self, config: Glm4MoeConfig): + nn.Module.__init__(self) + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + + +class Glm4MoeRMSNorm(DeepseekV3RMSNorm): + pass + + +class Glm4MoeDecoderLayer(DeepseekV3DecoderLayer): + pass + + +class Glm4MoePreTrainedModel(DeepseekV3PreTrainedModel): + _can_compile_fullgraph = False + + +class Glm4MoeModel(DeepseekV3Model): + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"] + + +class Glm4MoeForCausalLM(DeepseekV3ForCausalLM): + pass + + +__all__ = [ + "Glm4MoeConfig", + "Glm4MoePreTrainedModel", + "Glm4MoeModel", + "Glm4MoeForCausalLM", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4216c137fbe200d424bf9283b0da2f6e60c2817c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_glm4v import * + from .modeling_glm4v import * + from .processing_glm4v import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/configuration_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/configuration_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..4c417020fa84666613e09bac4f4d4d428865a279 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/configuration_glm4v.py @@ -0,0 +1,353 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4v.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Glm4vVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vVisionModel`]. It is used to instantiate an Glm4vVisionModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Args: + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + depth (`int`, *optional*, defaults to 24): + Number of layers (depth) in the model. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"selu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + projection_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the projection layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + image_size (`int` or `list[int]`, *optional*, defaults to `[336, 336]`): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to `14`): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_hidden_size (`int`, *optional*, defaults to 4096): + The output hidden size of the vision model. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. + Example: + + ```python + >>> from transformers import Glm4vVisionConfig, Glm4vVisionModel + + >>> # Initializing a Glm4vVisionConfig GLM-4.1V-9B style configuration + >>> configuration = Glm4vVisionConfig() + + >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration + >>> model = Glm4vVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v" + base_config_key = "vision_config" + + def __init__( + self, + depth=24, + hidden_size=1536, + hidden_act="silu", + attention_bias=False, + attention_dropout=0.0, + num_heads=12, + in_channels=3, + image_size=336, + patch_size=14, + rms_norm_eps=1e-05, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=4096, + intermediate_size=13696, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.image_size = image_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.intermediate_size = intermediate_size + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + +class Glm4vTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a + GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + vocab_size (`int`, *optional*, defaults to 151552): + Vocabulary size of the Glm4v model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Glm4vModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. 
If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + image_token_id (`int`, *optional*): + Token index used as placeholder for image embeddings. + video_token_id (`int`, *optional*): + Token index used as placeholder for video embeddings. + + ```python + >>> from transformers import Glm4vTextModel, Glm4vConfig + + >>> # Initializing a GLM-4.1V style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-4.1V style configuration + >>> model = Glm4vTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Glm4v` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151552, + hidden_size=4096, + intermediate_size=13696, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + attention_dropout=0.0, + rope_scaling=None, + image_token_id=None, + video_token_id=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Glm4vConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a + GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151343): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151344): + The video token index to encode the image prompt. + image_start_token_id (`int`, *optional*, defaults to 151339): + The image start token index to encode the start of image. + image_end_token_id (`int`, *optional*, defaults to 151340): + The image end token index to encode the end of image. + video_start_token_id (`int`, *optional*, defaults to 151341): + The video start token index to encode the start of video. + video_end_token_id (`int`, *optional*, defaults to 151342): + The video end token index to encode the end of video. + + ```python + >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig + + >>> # Initializing a GLM-4.1V style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-4.1V style configuration + >>> model = Glm4vForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v" + sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.video_start_token_id = video_start_token_id + self.video_end_token_id = video_end_token_id + self.image_start_token_id = image_start_token_id + self.image_end_token_id = image_end_token_id + + super().__init__(**kwargs) + + +__all__ = ["Glm4vConfig", "Glm4vTextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..8293545deee239d1b5d19355086cb58fb06b8b1a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v.py @@ -0,0 +1,472 @@ +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for GLM-4.1V.""" + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +def smart_resize( + num_frames: int, + height: int, + width: int, + temporal_factor: int = 2, + factor: int = 28, + min_pixels: int = 112 * 112, + max_pixels: int = 14 * 14 * 2 * 2 * 2 * 6144, +): + if num_frames < temporal_factor: + raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}") + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + t_bar = round(num_frames / temporal_factor) * temporal_factor + + if t_bar * h_bar * w_bar > max_pixels: + beta = math.sqrt((num_frames * height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif t_bar * h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (num_frames * height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + return h_bar, w_bar + + +class Glm4vImageProcessor(BaseImageProcessor): + r""" + Constructs a GLM-4V image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. 
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """ + + model_input_names = ["pixel_values", "image_grid_thw"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + elif size is None: + size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000} + self.size = size + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. 
+
+        Args:
+            images (`ImageInput`):
+                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+            vision_info (`List[Dict]`, *optional*):
+                Optional list of dictionaries containing additional information about vision inputs.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Scale factor to use if rescaling the image.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spatial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        images = make_flat_list_of_images(images)
+
+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
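+            # Note: the channel dimension format is inferred from the first image only and then reused for
+            # every image processed in this call, so mixed channel layouts within one batch are not supported here.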
+ input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + num_frames=temporal_patch_size, + height=height, + width=width, + temporal_factor=temporal_patch_size, + factor=patch_size * merge_size, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + patches = np.array(processed_images) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] % temporal_patch_size != 0: + repeats = np.repeat( + patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0 + ) + patches = np.concatenate([patches, repeats], axis=0) + channel = patches.shape[1] + grid_t = patches.shape[0] // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + patches = patches.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + videos: Optional[VideoInput] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`): + Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. 
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                `True`.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spatial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ + """ + # Try to use config values if set, otherwise fallback to global defaults + size = size if size is not None else self.size + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + elif size is None: + size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000} + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size + merge_size = merge_size if merge_size is not None else self.merge_size + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + data = {} + if images is not None: + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + merge_size=merge_size, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of image patches per image. 
+ """ + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", self.merge_size) + size = images_kwargs.get("size", {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}) + + factor = patch_size * merge_size + resized_height, resized_width = smart_resize( + num_frames=self.temporal_patch_size, + height=height, + width=width, + factor=factor, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], + temporal_factor=self.temporal_patch_size, + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * grid_w + + +__all__ = ["Glm4vImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v_fast.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdf31a437ae305fbfc399bd276476f9b453ed6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/image_processing_glm4v_fast.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for GLM-4.1V.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import ( + BatchFeature, +) +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, +) +from .image_processing_glm4v import smart_resize + + +logger = logging.get_logger(__name__) + + +class Glm4vFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. 
+ """ + + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +@auto_docstring +class Glm4vImageProcessorFast(BaseImageProcessorFast): + do_resize = True + resample = PILImageResampling.BICUBIC + size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000} + do_rescale = True + do_normalize = True + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + do_convert_rgb = True + patch_size = 14 + temporal_patch_size = 2 + merge_size = 2 + valid_kwargs = Glm4vFastImageProcessorKwargs + model_input_names = ["pixel_values", "image_grid_thw"] + + def __init__(self, **kwargs: Unpack[Glm4vFastImageProcessorKwargs]): + super().__init__(**kwargs) + if self.size is not None and ( + self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None + ): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + return super()._further_process_kwargs(size=size, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + patch_size: int, + temporal_patch_size: int, + merge_size: int, + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. 
+ """ + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + height, width = stacked_images.shape[-2:] + if do_resize: + resized_height, resized_width = smart_resize( + num_frames=temporal_patch_size, + height=height, + width=width, + temporal_factor=temporal_patch_size, + factor=patch_size * merge_size, + min_pixels=size.shortest_edge, + max_pixels=size.longest_edge, + ) + stacked_images = self.resize( + stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + resized_images_grouped[shape] = stacked_images + + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_grids = {} + + for shape, stacked_images in grouped_images.items(): + resized_height, resized_width = stacked_images.shape[-2:] + + patches = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if patches.ndim == 4: # (B, C, H, W) + patches = patches.unsqueeze(1) # (B, T=1, C, H, W) + + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat( + 1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1 + ) + patches = torch.cat([patches, repeats], dim=1) + + batch_size, t_len, channel = patches.shape[:3] + grid_t = t_len // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + + patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + # (B, grid_t, gh, gw, mh, mw, C, tp, ph, pw) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + + flatten_patches = patches.reshape( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + processed_images_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_grids = reorder_images(processed_grids, grouped_images_index) + + pixel_values = torch.cat(processed_images, dim=0) + image_grid_thw = torch.tensor(processed_grids) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors + ) + + @auto_docstring + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[Glm4vFastImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + +__all__ = ["Glm4vImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/modeling_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/modeling_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..da19208119885e0274ebf679e757654d1c98a2e1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/modeling_glm4v.py @@ -0,0 +1,1636 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_glm4v.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.generic import check_model_inputs +from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class Glm4vRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Glm4vRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Glm4VisionMlp(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.out_hidden_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Glm4vVisionPatchEmbed(nn.Module): + def __init__(self, config: Glm4vVisionConfig) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, 
self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Glm4vVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Glm4vVisionPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None: + super().__init__() + self.proj = nn.Linear(dim, dim, bias=bias) + self.post_projection_norm = LayerNorm(dim) + self.gate_proj = nn.Linear(dim, context_dim, bias=bias) + self.up_proj = nn.Linear(dim, context_dim, bias=bias) + self.down_proj = nn.Linear(context_dim, dim, bias=bias) + self.act1 = nn.GELU() + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.proj(hidden_state) + hidden_state = self.act1(self.post_projection_norm(hidden_state)) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Glm4vVisionEmbeddings(nn.Module): + def __init__(self, config: Glm4vVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: + """ + Forward pass with integrated position encoding adaptation using 2D interpolation. + + Args: + embeddings: Input embeddings tensor + lengths (torch.Tensor): Sequence lengths for each image in the batch. + image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w). + h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch. + w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch. + + Returns: + torch.Tensor: Embeddings with adapted position encoding added. 
+ """ + # Get position embedding parameters + pos_embed_weight = self.position_embedding.weight + hidden_size = pos_embed_weight.shape[1] + total_seq = h_coords.shape[0] + device = pos_embed_weight.device + + # Move coordinates to correct device + h_coords, w_coords = h_coords.to(device), w_coords.to(device) + + # Handle empty sequence case + if total_seq == 0: + adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype) + else: + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + if not isinstance(image_shapes, torch.Tensor): + image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) + + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + + # Normalize coordinates to [-1, 1] range for grid_sample + h_coords = h_coords.to(device=device, dtype=torch.float32) + w_coords = w_coords.to(device=device, dtype=torch.float32) + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" + ) + + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + + # Add adapted position encoding to embeddings + embeddings = embeddings + adapted_pos_embed + return embeddings + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Glm4vVisionAttention(nn.Module): + def __init__(self, config: Glm4vVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = config.attention_dropout + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + 
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Glm4vVisionBlock(GradientCheckpointingLayer): + def __init__(self, config) -> None: + super().__init__() + self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = Glm4vVisionAttention(config) + self.mlp = Glm4VisionMlp(config, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Glm4vTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Glm4vTextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Glm4vText has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half_llm(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
+ + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + + return q_embed, k_embed + + +class Glm4vTextAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: Glm4vTextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Glm4vTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +class Glm4vTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Glm4vTextConfig, layer_idx: int): + 
super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Glm4vTextAttention(config, layer_idx) + self.mlp = Glm4vTextMLP(config) + self.input_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Glm4vModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Glm4vPreTrainedModel(PreTrainedModel): + config: Glm4vConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Glm4vTextDecoderLayer, + "attentions": Glm4vTextAttention, + } + + +class Glm4vVisionModel(Glm4vPreTrainedModel): + config: Glm4vVisionConfig + _no_split_modules = ["Glm4vVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + + self.embeddings = Glm4vVisionEmbeddings(config) + self.patch_embed = Glm4vVisionPatchEmbed(config) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Glm4vVisionBlock(config) for _ in range(config.depth)]) + self.merger = Glm4vVisionPatchMerger( + dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act + ) + + self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.downsample = nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.out_hidden_size, + kernel_size=config.spatial_merge_size, + stride=config.spatial_merge_size, + ) + self.post_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + self.post_init() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb, pos_ids + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
+ """ + hidden_states = self.patch_embed(hidden_states) + hidden_states = self.post_conv_layernorm(hidden_states) + + rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + + hidden_states = self.post_layernorm(hidden_states) + + hidden_states = hidden_states.view( + -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1] + ) + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size) + + hidden_states = self.merger(hidden_states) + return hidden_states + + +@auto_docstring +class Glm4vTextModel(Glm4vPreTrainedModel): + config: Glm4vTextConfig + + def __init__(self, config: Glm4vTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Glm4vTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Glm4vTextRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Glm4vModel(Glm4vPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Glm4vConfig + _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Glm4vVisionModel._from_config(config.vision_config) + self.language_model = Glm4vTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. 
So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_start_token_id = self.config.video_start_token_id + video_end_token_id = self.config.video_end_token_id + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + video_group_index = 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + input_tokens = input_ids.tolist() + + input_token_type = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if token == image_token_id and not video_check_flg: + input_token_type.append("image") + elif token == image_token_id and video_check_flg: + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group = [] + for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]): + group = list(group) + start_index = group[0][0] + end_index = group[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + llm_pos_ids_list = [] + video_frame_num = 1 + for modality_type, start_idx, end_idx in
input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + + if modality_type == "image": + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + image_index += 1 + video_frame_num = 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 + + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + video_frame_num = 1 + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + # GLM-4.1V and GLM-4.5V special_video_mask is special_image_mask + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Glm4vModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return Glm4vModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, 
+ attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Glm4v causal language model (or autoregressive) outputs. + """ +) +class Glm4vCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + + def __init__(self, config): + super().__init__(config) + self.model = Glm4vModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: 
Unpack[TransformersKwargs], + ) -> Union[tuple, Glm4vCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration + + >>> model = Glm4vForConditionalGeneration.from_pretrained("THUDM/GLM-4.1V-9B-Thinking") + >>> processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Glm4vCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # GLM-4.1V position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary.
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + + if inputs_embeds is not None: + is_image = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_start = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_end = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + is_image = input_ids == self.config.image_start_token_id + is_video_start = input_ids == self.config.video_start_token_id + is_video_end = input_ids == self.config.video_end_token_id + + # Cumulative sum to track if we're inside a video span + # We'll assume well-formed video tags (i.e. matching starts and ends) + video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1) + inside_video = video_level > 0 # shape (batch_size, seq_length) + + # Mask out image tokens that are inside video spans + standalone_images = is_image & (~inside_video) + + # Count per batch + image_counts = standalone_images.sum(dim=1) + video_counts = is_video_start.sum(dim=1) + + return image_counts, video_counts + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( 
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Glm4vForConditionalGeneration", "Glm4vModel", "Glm4vPreTrainedModel", "Glm4vTextModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/modular_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/modular_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..ab624a0d0fb1eb0e86ec1dd8c40e1dd94dc06f84 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/modular_glm4v.py @@ -0,0 +1,1699 @@ +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
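(Editor's note, not part of the upstream diff: the modular file below defines an interleaved rotary helper, `rotate_half_llm`, and interleaves the cos/sin tables inside `apply_multimodal_rotary_pos_emb`. The following self-contained sketch simply replays those two snippets on toy tensors to show the interleaved channel layout; it is illustrative only.)

```python
import torch


def rotate_half_llm(x):
    # Same helper as defined later in this file: adjacent channel pairs
    # (x0, x1), (x2, x3), ... are rotated to (-x1, x0), (-x3, x2), ...
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half_llm(x))  # tensor([-2., 1., -4., 3.])

# The cos/sin tables are first truncated to half the rotary dim and then each
# frequency is duplicated for its interleaved pair, mirroring the
# `repeat_interleave(2, dim=-1)` call in apply_multimodal_rotary_pos_emb.
cos = torch.tensor([0.9, 0.8, 0.7, 0.6])
print(cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1))
# tensor([0.9000, 0.9000, 0.8000, 0.8000])
```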
+import itertools +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import ImagesKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import check_model_inputs +from ...video_utils import VideoInput +from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, eager_attention_forward +from ..qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLMLP, + Qwen2_5_VLModel, + Qwen2_5_VLModelOutputWithPast, + Qwen2_5_VLPreTrainedModel, + Qwen2_5_VLRotaryEmbedding, + Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, + Qwen2_5_VLVisionBlock, +) +from ..qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLVideosProcessorKwargs +from ..qwen2_vl.processing_qwen2_vl import ( + Qwen2_VLProcessor, + Qwen2_VLProcessorKwargs, +) + + +logger = logging.get_logger(__name__) + + +class Glm4vVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vVisionModel`]. It is used to instantiate an Glm4vVisionModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Args: + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + depth (`int`, *optional*, defaults to 24): + Number of layers (depth) in the model. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"selu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + projection_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the projection layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ image_size (`int` or `list[int]`, *optional*, defaults to `[336, 336]`): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to `14`): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_hidden_size (`int`, *optional*, defaults to 4096): + The output hidden size of the vision model. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. + Example: + + ```python + >>> from transformers import Glm4vVisionConfig, Glm4vVisionModel + + >>> # Initializing a Glm4vVisionConfig GLM-4.1V-9B style configuration + >>> configuration = Glm4vVisionConfig() + + >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration + >>> model = Glm4vVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v" + base_config_key = "vision_config" + + def __init__( + self, + depth=24, + hidden_size=1536, + hidden_act="silu", + attention_bias=False, + attention_dropout=0.0, + num_heads=12, + in_channels=3, + image_size=336, + patch_size=14, + rms_norm_eps=1e-05, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=4096, + intermediate_size=13696, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.image_size = image_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.intermediate_size = intermediate_size + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + +class Glm4vTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a + GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151552): + Vocabulary size of the Glm4v model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Glm4vModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. 
+ num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + image_token_id (`int`, *optional*): + Token index used as placeholder for image embeddings. + video_token_id (`int`, *optional*): + Token index used as placeholder for video embeddings. 
+ + ```python + >>> from transformers import Glm4vTextModel, Glm4vConfig + + >>> # Initializing a GLM-4.1V style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-4.1V style configuration + >>> model = Glm4vTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Glm4v` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151552, + hidden_size=4096, + intermediate_size=13696, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + attention_dropout=0.0, + rope_scaling=None, + image_token_id=None, + video_token_id=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Glm4vConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a + GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vTextConfig`): + The config object or dictionary of the text backbone. 
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151343): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151344): + The video token index to encode the image prompt. + image_start_token_id (`int`, *optional*, defaults to 151339): + The image start token index to encode the start of image. + image_end_token_id (`int`, *optional*, defaults to 151340): + The image end token index to encode the end of image. + video_start_token_id (`int`, *optional*, defaults to 151341): + The video start token index to encode the start of video. + video_end_token_id (`int`, *optional*, defaults to 151342): + The video end token index to encode the end of video. + + ```python + >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig + + >>> # Initializing a GLM-4.1V style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-4.1V style configuration + >>> model = Glm4vForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v" + sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.video_start_token_id = video_start_token_id + self.video_end_token_id = video_end_token_id + self.image_start_token_id = image_start_token_id + self.image_end_token_id = image_end_token_id + + super().__init__(**kwargs) + + +# Will be used for both Text and Vision modalities +class Glm4vRMSNorm(Glm4RMSNorm): + pass + + +class Glm4VisionMlp(Qwen2_5_VLMLP): + def __init__(self, config, bias: bool = False): + super().__init__(config, bias) + self.intermediate_size = config.out_hidden_size + + +class Glm4vVisionPatchEmbed(Qwen2_5_VisionPatchEmbed): + def __init__(self, config: Glm4vVisionConfig) -> None: + nn.Module.__init__(self) + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size) + + +class Glm4vVisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding): + pass + + +class Glm4vVisionPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None: + super().__init__() + self.proj = nn.Linear(dim, dim, bias=bias) + self.post_projection_norm = LayerNorm(dim) + self.gate_proj = nn.Linear(dim, context_dim, bias=bias) + self.up_proj = 
nn.Linear(dim, context_dim, bias=bias) + self.down_proj = nn.Linear(context_dim, dim, bias=bias) + self.act1 = nn.GELU() + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.proj(hidden_state) + hidden_state = self.act1(self.post_projection_norm(hidden_state)) + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Glm4vVisionEmbeddings(nn.Module): + def __init__(self, config: Glm4vVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: + """ + Forward pass with integrated position encoding adaptation using 2D interpolation. + + Args: + embeddings: Input embeddings tensor + lengths (torch.Tensor): Sequence lengths for each image in the batch. + image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w). + h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch. + w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch. + + Returns: + torch.Tensor: Embeddings with adapted position encoding added. + """ + # Get position embedding parameters + pos_embed_weight = self.position_embedding.weight + hidden_size = pos_embed_weight.shape[1] + total_seq = h_coords.shape[0] + device = pos_embed_weight.device + + # Move coordinates to correct device + h_coords, w_coords = h_coords.to(device), w_coords.to(device) + + # Handle empty sequence case + if total_seq == 0: + adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype) + else: + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + if not isinstance(image_shapes, torch.Tensor): + image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) + + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + + # Normalize coordinates to [-1, 1] range for grid_sample + h_coords = h_coords.to(device=device, dtype=torch.float32) + w_coords = w_coords.to(device=device, dtype=torch.float32) + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" + ) + + # Reshape and 
convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + + # Add adapted position encoding to embeddings + embeddings = embeddings + adapted_pos_embed + return embeddings + + +class Glm4vVisionAttention(Qwen2_5_VLVisionAttention): + def __init__(self, config: Glm4vVisionConfig) -> None: + super().__init__(config) + self.attention_dropout = config.attention_dropout + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + + +class Glm4vVisionBlock(Qwen2_5_VLVisionBlock): + def __init__(self, config) -> None: + super().__init__(config) + self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = Glm4vVisionAttention(config) + self.mlp = Glm4VisionMlp(config, bias=False) + + +class Glm4vTextRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): + pass + + +def rotate_half_llm(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + + return q_embed, k_embed + + +class Glm4vTextAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Glm4vTextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + 
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Glm4vTextMLP(Glm4MLP): + pass + + +class Glm4vTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Glm4vTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Glm4vTextAttention(config, layer_idx) + self.mlp = Glm4vTextMLP(config) + self.input_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Glm4vModelOutputWithPast(Qwen2_5_VLModelOutputWithPast): + pass + + +class Glm4vPreTrainedModel(Qwen2_5_VLPreTrainedModel): + _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] + _can_record_outputs = { + "hidden_states": Glm4vTextDecoderLayer, + "attentions": Glm4vTextAttention, + } + + +class Glm4vVisionModel(Glm4vPreTrainedModel): + config: Glm4vVisionConfig + _no_split_modules = ["Glm4vVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + + self.embeddings = Glm4vVisionEmbeddings(config) + self.patch_embed = Glm4vVisionPatchEmbed(config) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Glm4vVisionBlock(config) for _ in range(config.depth)]) + self.merger = Glm4vVisionPatchMerger( + dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act + ) + + self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + self.downsample = nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.out_hidden_size, + kernel_size=config.spatial_merge_size, + stride=config.spatial_merge_size, + ) + self.post_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + self.post_init() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb, pos_ids + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + hidden_states = self.post_conv_layernorm(hidden_states) + + rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + + hidden_states = self.post_layernorm(hidden_states) + + hidden_states = hidden_states.view( + -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1] + ) + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size) + + hidden_states = self.merger(hidden_states) + return hidden_states + + +class Glm4vTextModel(Qwen2_5_VLTextModel): + def __init__(self, config: Glm4vTextConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [Glm4vTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = 
Glm4vTextRotaryEmbedding(config=config) + del self._attn_implementation + del self.has_sliding_layers + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class Glm4vModel(Qwen2_5_VLModel): + _checkpoint_conversion_mapping = {} + _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Glm4vVisionModel._from_config(config.vision_config) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. 
+ temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_start_token_id = self.config.video_start_token_id + video_end_token_id = self.config.video_end_token_id + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + video_group_index = 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + input_tokens = input_ids.tolist() + + input_token_type = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if token == image_token_id and not video_check_flg: + input_token_type.append("image") + elif token == image_token_id and video_check_flg: + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group = [] + for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]): + group = list(group) + start_index = group[0][0] + end_index = group[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + llm_pos_ids_list = [] + video_frame_num = 1 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + + if modality_type == "image": + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + image_index += 1 + video_frame_num = 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 + + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + video_frame_num = 1 + + llm_positions = 
torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
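+        Note that GLM-4V reuses the image placeholder token inside video frames, so when `input_ids` are provided the video mask is also derived from `image_token_id`.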
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + # GLM-4.1V and GLM-4.5V special_video_mask is special_image_mask + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Glm4vModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
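+            It is computed during the prefill step, cached on the model, and then added to the incremental `cache_position`-based positions at decoding time so that the 3D position ids stay consistent across generation steps.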
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return Glm4vModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, 
+ attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +class Glm4vCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast): + pass + + +class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): + _checkpoint_conversion_mapping = {} + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Glm4vCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration + + >>> model = Glm4vForConditionalGeneration.from_pretrained("THUDM/GLM-4.1V-9B-Thinking") + >>> processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
+ ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Glm4vCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # GLM-4.1V position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
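+            inputs_embeds (`torch.Tensor`, *optional*):
+                When provided instead of `input_ids`, the image/video start and end tokens are detected by comparing `inputs_embeds` against the embeddings of the corresponding special token ids.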
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + + if inputs_embeds is not None: + is_image = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_start = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_end = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + is_image = input_ids == self.config.image_start_token_id + is_video_start = input_ids == self.config.video_start_token_id + is_video_end = input_ids == self.config.video_end_token_id + + # Cumulative sum to track if we're inside a video span + # We'll assume well-formed video tags (i.e. matching starts and ends) + video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1) + inside_video = video_level > 0 # shape (batch_size, seq_length) + + # Mask out image tokens that are inside video spans + standalone_images = is_image & (~inside_video) + + # Count per batch + image_counts = standalone_images.sum(dim=1) + video_counts = is_video_start.sum(dim=1) + + return image_counts, video_counts + + +class Glm4vVideosProcessorKwargs(Qwen2_5_VLVideosProcessorKwargs): + pass + + +class Glm4vImagesKwargs(ImagesKwargs): + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Glm4vProcessorKwargs(Qwen2_VLProcessorKwargs): + images_kwargs: Glm4vImagesKwargs + videos_kwargs: Glm4vVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +class Glm4vProcessor(Qwen2_VLProcessor): + r""" + Constructs a GLM-4V processor which wraps a GLM-4V image processor and a GLM-4 tokenizer into a single processor. + [`~Glm4vProcessor.__call__`] and [`~Glm4vProcessor.decode`] for more information. + Args: + image_processor ([`Glm4vProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Glm4vVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: Optional[VideoInput] = None, + **kwargs: Unpack[Glm4vProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). 
This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode + the text. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
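+            - **video_metadata** -- Video metadata for each input video. Returned when `videos` is not `None` and `return_metadata` is passed.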
+ """ + output_kwargs = self._merge_kwargs( + Glm4vProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + video_index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + num_frames = video_grid_thw[video_index][0] + video_structure = "" + + metadata = video_metadata[video_index] + if metadata.fps is None: + logger.warning_once( + "SmolVLM requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+ ) + metadata.fps = 24 if metadata.fps is None else metadata.fps + timestamps = metadata.timestamps[::2] # mrope + + unique_timestamps = [] + for idx in range(0, len(timestamps)): + unique_timestamps.append(timestamps[idx]) + + selected_timestamps = unique_timestamps[:num_frames] + while len(selected_timestamps) < num_frames: + selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + + for frame_idx in range(num_frames): + timestamp_sec = selected_timestamps[frame_idx] + frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{int(timestamp_sec)}" + video_structure += frame_structure + + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + + video_index += 1 + + text[i] = text[i].replace("<|placeholder|>", self.image_token) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + +__all__ = [ + "Glm4vConfig", + "Glm4vTextConfig", + "Glm4vForConditionalGeneration", + "Glm4vModel", + "Glm4vPreTrainedModel", + "Glm4vProcessor", + "Glm4vTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/processing_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/processing_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ebb4d41b497c47a05864a1fe3f53e1273b5dd2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/processing_glm4v.py @@ -0,0 +1,295 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm4v.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
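+# Usage sketch (illustrative; `prompt` and `image` are assumed to be defined by the caller):
+#   processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
+#   inputs = processor(text=[prompt], images=[image], return_tensors="pt")
+# The resulting `BatchFeature` carries `input_ids`, `pixel_values` and `image_grid_thw` as documented in
+# `Glm4vProcessor.__call__` below.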
+from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +class Glm4vVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[list[float], float] + + +class Glm4vImagesKwargs(ImagesKwargs): + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Glm4vProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Glm4vImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "videos_kwargs": {"return_metadata": True}, + } + videos_kwargs: Glm4vVideosProcessorKwargs + + +class Glm4vProcessor(ProcessorMixin): + r""" + Constructs a GLM-4V processor which wraps a GLM-4V image processor and a GLM-4 tokenizer into a single processor. + [`~Glm4vProcessor.__call__`] and [`~Glm4vProcessor.decode`] for more information. + Args: + image_processor ([`Glm4vProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Glm4vVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer", "video_processor"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + + tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: Optional[VideoInput] = None, + **kwargs: Unpack[Glm4vProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode + the text. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. 
+ text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Glm4vProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + video_index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + num_frames = video_grid_thw[video_index][0] + video_structure = "" + + metadata = video_metadata[video_index] + if metadata.fps is None: + logger.warning_once( + "SmolVLM requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. 
" + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." + ) + metadata.fps = 24 if metadata.fps is None else metadata.fps + timestamps = metadata.timestamps[::2] # mrope + + unique_timestamps = [] + for idx in range(0, len(timestamps)): + unique_timestamps.append(timestamps[idx]) + + selected_timestamps = unique_timestamps[:num_frames] + while len(selected_timestamps) < num_frames: + selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + + for frame_idx in range(num_frames): + timestamp_sec = selected_timestamps[frame_idx] + frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{int(timestamp_sec)}" + video_structure += frame_structure + + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + + video_index += 1 + + text[i] = text[i].replace("<|placeholder|>", self.image_token) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = Glm4vProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if video_sizes is not None: + videos_kwargs = Glm4vProcessorKwargs._defaults.get("videos_kwargs", {}) + videos_kwargs.update(kwargs) + num_video_patches = [ + self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) + for video_size in video_sizes + ] + num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] + vision_data["num_video_tokens"] = num_video_tokens + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + +__all__ = ["Glm4vProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v/video_processing_glm4v.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v/video_processing_glm4v.py new file mode 100644 index 0000000000000000000000000000000000000000..0986c414f1d374589c2f14cdcc514917c09344a2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v/video_processing_glm4v.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""video processor class for GLM-4.1V.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch + +from ...image_processing_utils import BatchFeature +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + PILImageResampling, + SizeDict, + get_image_size, +) +from ...processing_utils import Unpack, VideosKwargs +from ...utils import TensorType, add_start_docstrings +from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor +from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos +from .image_processing_glm4v import smart_resize + + +class Glm4vVideoProcessorInitKwargs(VideosKwargs): + max_image_size: dict[str, int] = None + patch_size: Optional[int] = None + temporal_patch_size: Optional[int] = None + merge_size: Optional[int] = None + image_mean: Optional[list[float]] = None + image_std: Optional[list[float]] = None + + +@add_start_docstrings( + "Constructs a fast GLM-4V image processor that dynamically resizes videos based on the original videos.", + BASE_VIDEO_PROCESSOR_DOCSTRING, + """ + patch_size (`int`, *optional*, defaults to 14): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. + """, +) +class Glm4vVideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BICUBIC + size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 2 * 30000} + max_image_size = {"longest_edge": 28 * 28 * 2 * 30000} + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + do_sample_frames = True + patch_size = 14 + temporal_patch_size = 2 + max_duration = 300 + merge_size = 2 + valid_kwargs = Glm4vVideoProcessorInitKwargs + num_frames = 16 + fps = 2 + + model_input_names = ["pixel_values_videos", "video_grid_thw"] + + def __init__(self, **kwargs: Unpack[Glm4vVideoProcessorInitKwargs]): + super().__init__(**kwargs) + if self.size is not None and ( + self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None + ): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + return super()._further_process_kwargs(size=size, **kwargs) + + def sample_frames( + self, + metadata: VideoMetadata, + fps: Optional[Union[int, float]] = None, + **kwargs, + ): + """ + Args: + metadata (`VideoMetadata`): + Metadata of the video containing information about total duration, fps and total number of frames. + fps (`int` or `float`, *optional*): + Target frames to sample per second. Defaults to `self.fps`. + Returns: + np.ndarray: + Indices to sample video frames. + """ + if metadata is None or getattr(metadata, "fps", None) is None: + raise ValueError( + "Asked to sample frames per second but no video metadata was provided which is required when sampling in GLM4V. 
" + "Please pass in `VideoMetadata` object or set `do_sample_frames=False`" + ) + + total_frames = metadata.total_num_frames + requested_fps = fps if fps is not None else self.fps + + max_frame_idx = total_frames - 1 + duration = metadata.duration or round(max_frame_idx / metadata.fps) + 1 + + if duration <= self.max_duration: + n = int(math.floor(duration * requested_fps)) + frame_indices = [min(max_frame_idx, int(math.ceil(i * metadata.fps / requested_fps))) for i in range(n)] + else: + num_samples = int(self.max_duration * requested_fps) + if num_samples >= total_frames: + frame_indices = list(range(total_frames)) + else: + target_seconds = np.linspace(0, duration, num_samples, endpoint=True) + frame_indices = [min(max_frame_idx, int(math.ceil(t * metadata.fps))) for t in target_seconds] + + seen, uniq = set(), [] + for idx in frame_indices: + if idx not in seen: + seen.add(idx) + uniq.append(idx) + + if len(uniq) & 1: + uniq.append(uniq[-1]) + + return np.array(uniq) + + def _preprocess( + self, + videos: list[torch.Tensor], + do_convert_rgb: bool = True, + do_resize: bool = True, + size: Optional[SizeDict] = None, + interpolation: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: float = 1 / 255.0, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ): + grouped_videos, grouped_videos_index = group_videos_by_shape(videos) + resized_videos_grouped = {} + + for shape, stacked_videos in grouped_videos.items(): + B, T, C, H, W = stacked_videos.shape + num_frames, height, width = T, H, W + if do_resize: + resized_height, resized_width = smart_resize( + num_frames=num_frames, + height=height, + width=width, + temporal_factor=temporal_patch_size, + factor=patch_size * merge_size, + min_pixels=size.shortest_edge, + max_pixels=size.longest_edge, + ) + stacked_videos = stacked_videos.view(B * T, C, H, W) + stacked_videos = self.resize( + stacked_videos, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width) + resized_videos_grouped[shape] = stacked_videos + resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) + + # Group videos by size for further processing + # Needed in case do_resize is False, or resize returns videos with different sizes + grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) + processed_videos_grouped = {} + processed_grids = {} + for shape, stacked_videos in grouped_videos.items(): + resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST) + + # Fused rescale and normalize + stacked_videos = self.rescale_and_normalize( + stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + patches = stacked_videos + + # Check that videos have `num_frames` divisible by `temporal_patch_size` + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=1) + batch_size, grid_t, channel = patches.shape[:3] + grid_t = grid_t // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + + 
patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten_patches = patches.reshape( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + processed_videos_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) + processed_grids = reorder_videos(processed_grids, grouped_videos_index) + pixel_values_videos = torch.cat(processed_videos, dim=0) + video_grid_thw = torch.tensor(processed_grids) + data = { + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": video_grid_thw, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Glm4vVideoProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glpn/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/glpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5b38675c34780fd7554db92ea870121f31fd76 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glpn/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_glpn import * + from .feature_extraction_glpn import * + from .image_processing_glpn import * + from .modeling_glpn import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/glpn/configuration_glpn.py b/venv/lib/python3.13/site-packages/transformers/models/glpn/configuration_glpn.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb35bb0b08cdedde15e3afc05ef2d8e3f2dbb5b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glpn/configuration_glpn.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""GLPN model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GLPNConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GLPNModel`]. It is used to instantiate an GLPN + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the GLPN + [vinvino02/glpn-kitti](https://huggingface.co/vinvino02/glpn-kitti) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_encoder_blocks (`int`, *optional*, defaults to 4): + The number of encoder blocks (i.e. stages in the Mix Transformer encoder). + depths (`list[int]`, *optional*, defaults to `[2, 2, 2, 2]`): + The number of layers in each encoder block. + sr_ratios (`list[int]`, *optional*, defaults to `[8, 4, 2, 1]`): + Sequence reduction ratios in each encoder block. + hidden_sizes (`list[int]`, *optional*, defaults to `[32, 64, 160, 256]`): + Dimension of each of the encoder blocks. + patch_sizes (`list[int]`, *optional*, defaults to `[7, 3, 3, 3]`): + Patch size before each encoder block. + strides (`list[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Stride before each encoder block. + num_attention_heads (`list[int]`, *optional*, defaults to `[1, 2, 5, 8]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + mlp_ratios (`list[int]`, *optional*, defaults to `[4, 4, 4, 4]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + decoder_hidden_size (`int`, *optional*, defaults to 64): + The dimension of the decoder. + max_depth (`int`, *optional*, defaults to 10): + The maximum depth of the decoder. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the head. 
+ + Example: + + ```python + >>> from transformers import GLPNModel, GLPNConfig + + >>> # Initializing a GLPN vinvino02/glpn-kitti style configuration + >>> configuration = GLPNConfig() + + >>> # Initializing a model from the vinvino02/glpn-kitti style configuration + >>> model = GLPNModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glpn" + + def __init__( + self, + num_channels=3, + num_encoder_blocks=4, + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + hidden_sizes=[32, 64, 160, 256], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + num_attention_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + drop_path_rate=0.1, + layer_norm_eps=1e-6, + decoder_hidden_size=64, + max_depth=10, + head_in_index=-1, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.num_encoder_blocks = num_encoder_blocks + self.depths = depths + self.sr_ratios = sr_ratios + self.hidden_sizes = hidden_sizes + self.patch_sizes = patch_sizes + self.strides = strides + self.mlp_ratios = mlp_ratios + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.decoder_hidden_size = decoder_hidden_size + self.max_depth = max_depth + self.head_in_index = head_in_index + + +__all__ = ["GLPNConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glpn/feature_extraction_glpn.py b/venv/lib/python3.13/site-packages/transformers/models/glpn/feature_extraction_glpn.py new file mode 100644 index 0000000000000000000000000000000000000000..327fee4a11fd308f980845a971a9fc8335decaf1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glpn/feature_extraction_glpn.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for GLPN.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_glpn import GLPNImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class GLPNFeatureExtractor(GLPNImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class GLPNFeatureExtractor is deprecated and will be removed in version 5 of Transformers. 
Please" + " use GLPNImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["GLPNFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glpn/image_processing_glpn.py b/venv/lib/python3.13/site-packages/transformers/models/glpn/image_processing_glpn.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e0255e2b47c3fcb1a617a5793502d6babb5537 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glpn/image_processing_glpn.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for GLPN.""" + +from typing import TYPE_CHECKING, Optional, Union + +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +import numpy as np +import PIL.Image + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_torch_available, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class GLPNImageProcessor(BaseImageProcessor): + r""" + Constructs a GLPN image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of + `size_divisor`. Can be overridden by `do_resize` in `preprocess`. + size_divisor (`int`, *optional*, defaults to 32): + When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest + multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`. + resample (`PIL.Image` resampling filter, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be + overridden by `do_rescale` in `preprocess`. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size_divisor: int = 32, + resample=PILImageResampling.BILINEAR, + do_rescale: bool = True, + **kwargs, + ) -> None: + self.do_resize = do_resize + self.do_rescale = do_rescale + self.size_divisor = size_divisor + self.resample = resample + super().__init__(**kwargs) + + def resize( + self, + image: np.ndarray, + size_divisor: int, + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image, rounding the (height, width) dimensions down to the closest multiple of size_divisor. + + If the image is of dimension (3, 260, 170) and size_divisor is 32, the image will be resized to (3, 256, 160). + + Args: + image (`np.ndarray`): + The image to resize. + size_divisor (`int`): + The image is resized so its height and width are rounded down to the closest multiple of + `size_divisor`. + resample: + `PIL.Image` resampling filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If `None`, the channel dimension format of the input + image is used. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not set, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The resized image. + """ + height, width = get_image_size(image, channel_dim=input_data_format) + # Rounds the height and width down to the closest multiple of size_divisor + new_h = height // size_divisor * size_divisor + new_w = width // size_divisor * size_divisor + image = resize( + image, + (new_h, new_w), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: Union["PIL.Image.Image", TensorType, list["PIL.Image.Image"], list[TensorType]], + do_resize: Optional[bool] = None, + size_divisor: Optional[int] = None, + resample=None, + do_rescale: Optional[bool] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocess the given images. + + Args: + images (`PIL.Image.Image` or `TensorType` or `list[np.ndarray]` or `list[TensorType]`): + Images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_normalize=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the input such that the (height, width) dimensions are a multiple of `size_divisor`. 
+ size_divisor (`int`, *optional*, defaults to `self.size_divisor`): + When `do_resize` is `True`, images are resized so their height and width are rounded down to the + closest multiple of `size_divisor`. + resample (`PIL.Image` resampling filter, *optional*, defaults to `self.resample`): + `PIL.Image` resampling filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has + an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - `None`: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + resample = resample if resample is not None else self.resample + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # Here, the rescale() method uses a constant rescale_factor. It does not need to be validated + # with a rescale_factor. + validate_preprocess_arguments( + do_resize=do_resize, + size=size_divisor, # Here, size_divisor is used as a parameter for optimal resizing instead of size. + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(img) for img in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, scale=1 / 255, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None, + ) -> list[dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = depth[None, None, ...] + depth = torch.nn.functional.interpolate(depth, size=target_size, mode="bicubic", align_corners=False) + depth = depth.squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["GLPNImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/glpn/modeling_glpn.py b/venv/lib/python3.13/site-packages/transformers/models/glpn/modeling_glpn.py new file mode 100644 index 0000000000000000000000000000000000000000..e326750743a1aa4cb30791cfde1d46b7a3af16e4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/glpn/modeling_glpn.py @@ -0,0 +1,724 @@ +# coding=utf-8 +# Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
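As a quick, self-contained check of the rounding behaviour documented for `GLPNImageProcessor` above; the input resolution here is arbitrary and chosen only for illustration:

```python
import numpy as np
from PIL import Image
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor()  # do_resize=True, size_divisor=32, do_rescale=True by default
image = Image.fromarray(np.random.randint(0, 256, (500, 375, 3), dtype=np.uint8))

inputs = processor(images=image, return_tensors="pt")
# Height and width are rounded down to the nearest multiple of 32 (500 -> 480, 375 -> 352)
# and pixel values are rescaled to the [0, 1] range.
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 480, 352])
```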
+"""PyTorch GLPN model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from .configuration_glpn import GLPNConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerDropPath +class GLPNDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings +class GLPNOverlapPatchEmbeddings(nn.Module): + """Construct the overlapping patch embeddings.""" + + def __init__(self, patch_size, stride, num_channels, hidden_size): + super().__init__() + self.proj = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2, + ) + + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, pixel_values): + embeddings = self.proj(pixel_values) + _, _, height, width = embeddings.shape + # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels) + # this can be fed to a Transformer layer + embeddings = embeddings.flatten(2).transpose(1, 2) + embeddings = self.layer_norm(embeddings) + return embeddings, height, width + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention +class GLPNEfficientSelfAttention(nn.Module): + """SegFormer's efficient self-attention mechanism. 
Employs the sequence reduction process introduced in the [PvT + paper](https://huggingface.co/papers/2102.12122).""" + + def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): + super().__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads})" + ) + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + self.key = nn.Linear(self.hidden_size, self.all_head_size) + self.value = nn.Linear(self.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.sr_ratio = sequence_reduction_ratio + if sequence_reduction_ratio > 1: + self.sr = nn.Conv2d( + hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio + ) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward( + self, + hidden_states, + height, + width, + output_attentions=False, + ): + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if self.sr_ratio > 1: + batch_size, seq_len, num_channels = hidden_states.shape + # Reshape to (batch_size, num_channels, height, width) + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + # Apply sequence reduction + hidden_states = self.sr(hidden_states) + # Reshape back to (batch_size, seq_len, num_channels) + hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) + hidden_states = self.layer_norm(hidden_states) + + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerSelfOutput +class GLPNSelfOutput(nn.Module): + def __init__(self, config, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN +class GLPNAttention(nn.Module): + def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): + super().__init__() + self.self = GLPNEfficientSelfAttention( + config=config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + ) + self.output = GLPNSelfOutput(config, hidden_size=hidden_size) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, height, width, output_attentions=False): + self_outputs = self.self(hidden_states, height, width, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerDWConv +class GLPNDWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, hidden_states, height, width): + batch_size, seq_len, num_channels = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width) + hidden_states = self.dwconv(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + return hidden_states + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerMixFFN with Segformer->GLPN +class GLPNMixFFN(nn.Module): + def __init__(self, config, in_features, hidden_features=None, out_features=None): + super().__init__() + out_features = out_features or in_features + self.dense1 = nn.Linear(in_features, hidden_features) + self.dwconv = GLPNDWConv(hidden_features) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + 
else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, height, width): + hidden_states = self.dense1(hidden_states) + hidden_states = self.dwconv(hidden_states, height, width) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.segformer.modeling_segformer.SegformerLayer with Segformer->GLPN +class GLPNLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio): + super().__init__() + self.layer_norm_1 = nn.LayerNorm(hidden_size) + self.attention = GLPNAttention( + config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + ) + self.drop_path = GLPNDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_norm_2 = nn.LayerNorm(hidden_size) + mlp_hidden_size = int(hidden_size * mlp_ratio) + self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size) + + def forward(self, hidden_states, height, width, output_attentions=False): + self_attention_outputs = self.attention( + self.layer_norm_1(hidden_states), # in GLPN, layernorm is applied before self-attention + height, + width, + output_attentions=output_attentions, + ) + + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection (with stochastic depth) + attention_output = self.drop_path(attention_output) + hidden_states = attention_output + hidden_states + + mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) + + # second residual connection (with stochastic depth) + mlp_output = self.drop_path(mlp_output) + layer_output = mlp_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class GLPNEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + + # patch embeddings + embeddings = [] + for i in range(config.num_encoder_blocks): + embeddings.append( + GLPNOverlapPatchEmbeddings( + patch_size=config.patch_sizes[i], + stride=config.strides[i], + num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + ) + ) + self.patch_embeddings = nn.ModuleList(embeddings) + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + GLPNLayer( + config, + hidden_size=config.hidden_sizes[i], + num_attention_heads=config.num_attention_heads[i], + drop_path=dpr[cur + j], + sequence_reduction_ratio=config.sr_ratios[i], + mlp_ratio=config.mlp_ratios[i], + ) + ) + blocks.append(nn.ModuleList(layers)) + + self.block = nn.ModuleList(blocks) + + # Layer norms + self.layer_norm = nn.ModuleList( + [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)] + ) + + def forward( + 
self, + pixel_values, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + batch_size = pixel_values.shape[0] + + hidden_states = pixel_values + for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)): + embedding_layer, block_layer, norm_layer = x + # first, obtain patch embeddings + hidden_states, height, width = embedding_layer(hidden_states) + # second, send embeddings through blocks + for i, blk in enumerate(block_layer): + layer_outputs = blk(hidden_states, height, width, output_attentions) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + # third, apply layer norm + hidden_states = norm_layer(hidden_states) + # fourth, optionally reshape back to (batch_size, num_channels, height, width) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class GLPNPreTrainedModel(PreTrainedModel): + config: GLPNConfig + base_model_prefix = "glpn" + main_input_name = "pixel_values" + _no_split_modules = [] + + # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class GLPNModel(GLPNPreTrainedModel): + # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.__init__ with Segformer->GLPN + def __init__(self, config): + super().__init__(config) + self.config = config + + # hierarchical Transformer encoder + self.encoder = GLPNEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.forward + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class GLPNSelectiveFeatureFusion(nn.Module): + """ + Selective Feature Fusion module, as explained in the [paper](https://huggingface.co/papers/2201.07436) (section 3.4). This + module adaptively selects and integrates local and global features by attaining an attention map for each feature. + """ + + def __init__(self, in_channel=64): + super().__init__() + + self.convolutional_layer1 = nn.Sequential( + nn.Conv2d(in_channels=int(in_channel * 2), out_channels=in_channel, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(in_channel), + nn.ReLU(), + ) + + self.convolutional_layer2 = nn.Sequential( + nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel / 2), kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(int(in_channel / 2)), + nn.ReLU(), + ) + + self.convolutional_layer3 = nn.Conv2d( + in_channels=int(in_channel / 2), out_channels=2, kernel_size=3, stride=1, padding=1 + ) + + self.sigmoid = nn.Sigmoid() + + def forward(self, local_features, global_features): + # concatenate features along the channel dimension + features = torch.cat((local_features, global_features), dim=1) + # pass through convolutional layers + features = self.convolutional_layer1(features) + features = self.convolutional_layer2(features) + features = self.convolutional_layer3(features) + # apply sigmoid to get two-channel attention map + attn = self.sigmoid(features) + # construct hybrid features by adding element-wise + hybrid_features = local_features * attn[:, 0, :, :].unsqueeze(1) + global_features * attn[ + :, 1, :, : + ].unsqueeze(1) + + return hybrid_features + + +class GLPNDecoderStage(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + should_skip = in_channels == out_channels + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1) if not should_skip else nn.Identity() + self.fusion = GLPNSelectiveFeatureFusion(out_channels) + self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + + def forward(self, hidden_state, residual=None): + hidden_state = self.convolution(hidden_state) + if residual is not None: + hidden_state = self.fusion(hidden_state, residual) + hidden_state = self.upsample(hidden_state) + + return hidden_state + + 
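+        # NOTE: the duplicated upsample/return statements below are unreachable; the method returns above.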
hidden_state = self.upsample(hidden_state) + return hidden_state + + +class GLPNDecoder(nn.Module): + def __init__(self, config): + super().__init__() + # we use features from end -> start + reserved_hidden_sizes = config.hidden_sizes[::-1] + out_channels = config.decoder_hidden_size + + self.stages = nn.ModuleList( + [GLPNDecoderStage(hidden_size, out_channels) for hidden_size in reserved_hidden_sizes] + ) + # don't fuse in first stage + self.stages[0].fusion = None + + self.final_upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + + def forward(self, hidden_states: list[torch.Tensor]) -> list[torch.Tensor]: + stage_hidden_states = [] + stage_hidden_state = None + for hidden_state, stage in zip(hidden_states[::-1], self.stages): + stage_hidden_state = stage(hidden_state, stage_hidden_state) + stage_hidden_states.append(stage_hidden_state) + + stage_hidden_states[-1] = self.final_upsample(stage_hidden_state) + + return stage_hidden_states + + +class SiLogLoss(nn.Module): + r""" + Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://huggingface.co/papers/1406.2283). + + $$L=\frac{1}{n} \sum_{i} d_{i}^{2}-\frac{1}{2 n^{2}}\left(\sum_{i} d_{i}^{2}\right)$$ where $d_{i}=\log y_{i}-\log + y_{i}^{*}$. + + """ + + def __init__(self, lambd=0.5): + super().__init__() + self.lambd = lambd + + def forward(self, pred, target): + valid_mask = (target > 0).detach() + diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask]) + loss = torch.sqrt(torch.pow(diff_log, 2).mean() - self.lambd * torch.pow(diff_log.mean(), 2)) + + return loss + + +class GLPNDepthEstimationHead(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + channels = config.decoder_hidden_size + self.head = nn.Sequential( + nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1), + ) + + def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor: + # use last features of the decoder + hidden_states = hidden_states[self.config.head_in_index] + + hidden_states = self.head(hidden_states) + + predicted_depth = torch.sigmoid(hidden_states) * self.config.max_depth + predicted_depth = predicted_depth.squeeze(dim=1) + + return predicted_depth + + +@auto_docstring( + custom_intro=""" + GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2. + """ +) +class GLPNForDepthEstimation(GLPNPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.glpn = GLPNModel(config) + self.decoder = GLPNDecoder(config) + self.head = GLPNDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti") + >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 255 / predicted_depth.max() + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint8")) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.glpn( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + out = self.decoder(hidden_states) + predicted_depth = self.head(out) + + loss = None + if labels is not None: + loss_fct = SiLogLoss() + loss = loss_fct(predicted_depth, labels) + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["GLPNForDepthEstimation", "GLPNLayer", "GLPNModel", "GLPNPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00b6ccc53fc0efb0fc88c2f95586276cd40010fe --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_got_ocr2 import * + from .image_processing_got_ocr2 import * + from .image_processing_got_ocr2_fast import * + from .modeling_got_ocr2 import * + from .processing_got_ocr2 import * + + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py new file mode 100644 index 0000000000000000000000000000000000000000..eb039f958950c3fbc3cd4401b425348181b13a00 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py @@ -0,0 +1,211 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class GotOcr2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2 + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the SAM ViT-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. 
+ patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + mlp_dim (`int`, *optional*, defaults to 3072): + The dimensionality of the MLP layer in the Transformer encoder. + """ + + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + mlp_dim=3072, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.mlp_dim = mlp_dim + + +class GotOcr2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2ForConditionalGeneration`]. It is used to instantiate a + GotOcr2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of GOT-OCR-2.0. + + e.g [stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 151859): + The image token index to encode the image prompt. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. 
+ pad_token_id (`int`, *optional*, defaults to -1): + Padding token id. + + ```python + >>> from transformers import GotOcr2ForConditionalGeneration, GotOcr2Config + + >>> # Initializing a GotOcr2 style configuration + >>> configuration = GotOcr2Config() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = GotOcr2ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "got_ocr2" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=151859, + image_seq_length=576, + pad_token_id=-1, + **kwargs, + ): + self.image_token_index = image_token_index + self.image_seq_length = image_seq_length + self.pad_token_id = pad_token_id + + if vision_config is None: + self.vision_config = GotOcr2VisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = GotOcr2VisionConfig(**vision_config) + elif isinstance(vision_config, GotOcr2VisionConfig): + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "qwen2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]( + vocab_size=151860, + hidden_size=1024, + intermediate_size=2816, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=True, + rope_theta=1000000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=21, + attention_dropout=0.0, + ) + + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py new file mode 100644 index 0000000000000000000000000000000000000000..209ac88ea2fbecc5ddd2c4d82b92830992caac18 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -0,0 +1,526 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
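+
+# Illustrative usage sketch (added commentary, not part of the upstream file; the checkpoint
+# name is the GOT-OCR-2.0 checkpoint referenced in this model's configuration docstring):
+#   >>> from transformers import AutoImageProcessor
+#   >>> processor = AutoImageProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=False)
+#   >>> inputs = processor(images=image, return_tensors="pt")  # `image`: PIL.Image or np.ndarray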
+"""Image processor class for Got-OCR-2.""" + +from functools import lru_cache +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +# Similar to image_processing_mllama.get_all_supported_aspect_ratios +@lru_cache(maxsize=10) +def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> list[tuple[int, int]]: + """ + Computes all allowed aspect ratios for a given minimum and maximum number of input tiles. + + This function calculates all possible arrangements of tiles that can be formed + within the constraint of the minimum and maximum number of tiles. Each arrangement is + represented by its aspect ratio (width/height) and the corresponding tile configuration. + + Args: + min_image_tiles (`int`): + The minimum number of tiles allowed. + max_image_tiles (`int`): + The maximum number of tiles allowed. + + Returns: + `list[tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height) + configuration in terms of number of tiles. + + Example: + >>> get_all_supported_aspect_ratios(1, 4) + [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)] + + """ + aspect_ratios = [] + for width in range(1, max_image_tiles + 1): + for height in range(1, max_image_tiles + 1): + if width * height <= max_image_tiles and width * height >= min_image_tiles: + aspect_ratios.append((width, height)) + + aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1]) + + return aspect_ratios + + +@lru_cache(maxsize=100) +def get_optimal_tiled_canvas( + original_image_size: tuple[int, int], + target_tile_size: tuple[int, int], + min_image_tiles: int, + max_image_tiles: int, +) -> tuple[int, int]: + """ + Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the + original image aspect ratio. + In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with + more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily + excessive tiling. 
+ """ + possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles) + + original_height, original_width = original_image_size + target_tile_height, target_tile_width = target_tile_size + aspect_ratio = original_width / original_height + area = original_width * original_height + + # find the grid with the best aspect ratio + best_ratio_diff = float("inf") + best_grid = (1, 1) + for grid in possible_tile_arrangements: + grid_aspect_ratio = grid[0] / grid[1] + ratio_diff = abs(aspect_ratio - grid_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_grid = grid + elif ratio_diff == best_ratio_diff: + # if the aspect ratio difference is the same, we favor the grid with more patches + # until the area covered by the patches is more than twice the original image area + if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]: + best_grid = grid + + return best_grid + + +class GotOcr2ImageProcessor(BaseImageProcessor): + r""" + Constructs a GOT_OCR2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + crop_to_patches (`bool`, *optional*, defaults to `False`): + Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the + `preprocess` method. + min_patches (`int`, *optional*, defaults to 1): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method. + max_patches (`int`, *optional*, defaults to 12): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. 
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + crop_to_patches: bool = False, + min_patches: int = 1, + max_patches: int = 12, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 384, "width": 384} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.crop_to_patches = crop_to_patches + self.min_patches = min_patches + self.max_patches = max_patches + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + crop_to_patches: Optional[bool] = None, + min_patches: Optional[int] = None, + max_patches: Optional[int] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`): + Whether to crop the image to patches. + min_patches (`int`, *optional*, defaults to `self.min_patches`): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. + max_patches (`int`, *optional*, defaults to `self.max_patches`): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches + min_patches = min_patches if min_patches is not None else self.min_patches + max_patches = max_patches if max_patches is not None else self.max_patches + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
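+            # (Added note) e.g. an array of shape (3, H, W) is inferred as channels-first,
+            # while (H, W, 3) is inferred as channels-last.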
+ input_data_format = infer_channel_dimension_format(images[0]) + + if crop_to_patches and max_patches > 1: + images = [ + self.crop_image_to_patches( + image, + min_patches=min_patches, + max_patches=max_patches, + patch_size=size, + data_format=input_data_format, + ) + for image in images + ] + num_patches = np.array([len(image) for image in images]) + images = [image for images_list in images for image in images_list] + else: + num_patches = np.array([1] * len(images)) + + for i, image in enumerate(images): + if do_resize: + images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + images[i] = self.normalize( + image=images[i], + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + + images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format) + + encoded_outputs = BatchFeature( + data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors + ) + + return encoded_outputs + + def crop_image_to_patches( + self, + images: np.ndarray, + min_patches: int, + max_patches: int, + use_thumbnail: bool = True, + patch_size: Optional[Union[tuple, int, dict]] = None, + data_format: Optional[ChannelDimension] = None, + ): + """ + Crop the image to patches and return a list of cropped images. + The number of patches and their grid arrangement are determined by the original image size, + the target patch size and the minimum and maximum number of patches. + The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. + + Args: + images (`np.ndarray`): + The image to be cropped. + min_patches (`int`): + The minimum number of patches to be extracted from the image. + max_patches (`int`): + The maximum number of patches to be extracted from the image. + use_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to add a thumbnail image to the list of cropped patches. + patch_size (`int`, `tuple[int, int]`, `dict`, *optional*): + The size of the output patches. + data_format (`ChannelDimension`, *optional*): + The format of the image data. If `None`, the format is inferred from the input image. + + Returns: + list[`PIL.Image.Image`] or list[np.ndarray]: The list of cropped images. 
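+
+        Example (illustrative sketch; a square image twice the patch size selects a 2x2 grid,
+        plus the thumbnail appended by `use_thumbnail`):
+
+        >>> import numpy as np
+        >>> processor = GotOcr2ImageProcessor()
+        >>> image = np.zeros((3, 768, 768), dtype=np.uint8)
+        >>> patches = processor.crop_image_to_patches(
+        ...     image, min_patches=1, max_patches=12, patch_size={"height": 384, "width": 384}
+        ... )
+        >>> len(patches)
+        5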
+ """ + if data_format is None: + data_format = infer_channel_dimension_format(images) + images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format) + patch_size_height, patch_size_width = patch_size["height"], patch_size["width"] + original_height, original_width = images.shape[-2:] + # find the closest aspect ratio to the target + num_columns, num_rows = get_optimal_tiled_canvas( + (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches + ) + + # calculate the target width and height + target_width = patch_size_width * num_columns + target_height = patch_size_height * num_rows + num_blocks = num_columns * num_rows + + # resize the image so that each patch is of patch_size + resized_image = self.resize( + images, + {"height": target_height, "width": target_width}, + data_format=ChannelDimension.FIRST, + input_data_format=ChannelDimension.FIRST, + ) + # split the image into patches + processed_images = [] + for i in range(num_blocks): + column = i % num_columns + row = i // num_columns + box = ( + column * patch_size_width, + row * patch_size_height, + (column + 1) * patch_size_width, + (row + 1) * patch_size_height, + ) + # split the image + patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]] + patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST) + processed_images.append(patch_image) + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = self.resize( + images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST + ) + processed_images.append(thumbnail_img) + + return processed_images + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. + """ + min_patches = images_kwargs.get("min_patches", self.min_patches) + max_patches = images_kwargs.get("max_patches", self.max_patches) + patch_size = images_kwargs.get("patch_size", self.size) + crop_to_patches = images_kwargs.get("crop_to_patches", self.crop_to_patches) + + num_patches = 1 + if crop_to_patches and max_patches > 1: + num_columns, num_rows = get_optimal_tiled_canvas( + (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches + ) + if num_columns * num_rows > 1: + num_patches += num_columns * num_rows + + return num_patches + + +__all__ = ["GotOcr2ImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a47a1422a5dc5dcee6c9d75cb7436a1180c32829 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py @@ -0,0 +1,247 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Got-OCR-2.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) +from .image_processing_got_ocr2 import get_optimal_tiled_canvas + + +class GotOcr2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + crop_to_patches (`bool`, *optional*, defaults to `False`): + Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the + `preprocess` method. + min_patches (`int`, *optional*, defaults to 1): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method. + max_patches (`int`, *optional*, defaults to 12): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method. + """ + + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +@auto_docstring +class GotOcr2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + crop_to_patches = False + min_patches = 1 + max_patches = 12 + valid_kwargs = GotOcr2FastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[GotOcr2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2FastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def crop_image_to_patches( + self, + images: "torch.Tensor", + min_patches: int, + max_patches: int, + use_thumbnail: bool = True, + patch_size: Optional[Union[tuple, int, dict]] = None, + interpolation: Optional["F.InterpolationMode"] = None, + ): + """ + Crop the images to patches and return a list of cropped images. + The number of patches and their grid arrangement are determined by the original image size, + the target patch size and the minimum and maximum number of patches. + The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. + + Args: + images (`torch.Tensor`): + The images to be cropped. + min_patches (`int`): + The minimum number of patches to be extracted from the image. + max_patches (`int`): + The maximum number of patches to be extracted from the image. 
+ use_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to add a thumbnail image to the list of cropped patches. + patch_size (`int`, `tuple[int, int]`, `dict`, *optional*): + The size of the output patches. + The format of the image data. If `None`, the format is inferred from the input image. + + Returns: + list[`PIL.Image.Image`] or list[np.ndarray]: The list of cropped images. + """ + patch_size_height, patch_size_width = patch_size.height, patch_size.width + original_height, original_width = images.shape[-2:] + # find the closest aspect ratio to the target + num_columns, num_rows = get_optimal_tiled_canvas( + (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches + ) + + # calculate the target width and height + target_width = patch_size_width * num_columns + target_height = patch_size_height * num_rows + num_blocks = num_columns * num_rows + + # resize the image so that each patch is of patch_size + resized_image = self.resize( + images, SizeDict(height=target_height, width=target_width), interpolation=interpolation + ) + # split the image into patches + processed_images = [] + for i in range(num_blocks): + column = i % num_columns + row = i // num_columns + box = ( + column * patch_size_width, + row * patch_size_height, + (column + 1) * patch_size_width, + (row + 1) * patch_size_height, + ) + # split the image + patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]] + processed_images.append(patch_image) + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = self.resize(images, patch_size, interpolation=interpolation) + processed_images.append(thumbnail_img) + + processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous() + + return processed_images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + crop_to_patches: bool, + min_patches: int, + max_patches: int, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + if crop_to_patches: + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + num_patches = {} + for shape, stacked_images in grouped_images.items(): + stacked_images = self.crop_image_to_patches( + stacked_images, + min_patches, + max_patches, + patch_size=size, + interpolation=interpolation, + ) + processed_images_grouped[shape] = stacked_images + num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0] + images = reorder_images(processed_images_grouped, grouped_images_index) + images = [image for images_list in images for image in images_list] + num_patches = reorder_images(num_patches, grouped_images_index) + else: + num_patches = [1] * len(images) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = 
reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature( + data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors + ) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. + """ + min_patches = images_kwargs.get("min_patches", self.min_patches) + max_patches = images_kwargs.get("max_patches", self.max_patches) + patch_size = images_kwargs.get("patch_size", self.size) + crop_to_patches = images_kwargs.get("crop_to_patches", self.crop_to_patches) + + num_patches = 1 + if crop_to_patches and max_patches > 1: + num_columns, num_rows = get_optimal_tiled_canvas( + (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches + ) + if num_columns * num_rows > 1: + num_patches += num_columns * num_rows + + return num_patches + + +__all__ = ["GotOcr2ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py new file mode 100644 index 0000000000000000000000000000000000000000..eb22e62bfd7df1f56b8c5b13ac6fcacb90df6696 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -0,0 +1,840 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
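+
+# High-level data flow (added commentary, not part of the upstream file):
+#   pixel_values -> GotOcr2VisionEncoder (SAM-style ViT with windowed + global attention)
+#     -> GotOcr2VisionNeck (1x1 and 3x3 convs down to `output_channels`)
+#     -> GotOcr2MultiModalProjector (two stride-2 convs + linear, into the text hidden size)
+#     -> features are scattered over the image-placeholder positions (`config.image_token_index`)
+#        of the language model built with `AutoModel.from_config(config.text_config)`.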
+ + +import collections +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.utils.generic import check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ..auto import AutoModel +from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig + + +class GotOcr2MLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class GotOcr2VisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. 
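+        # (Added note) when q_size == k_size == S, both scale factors are 1.0, so
+        # relative_coords[i, j] = i - j + (S - 1): integer indices in [0, 2S - 2] that
+        # exactly index the 2S - 1 rows of the (possibly interpolated) rel_pos table.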
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def get_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos + + def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + return attn_output, attn_weights + + +class GotOcr2VisionLayer(GradientCheckpointingLayer): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = GotOcr2VisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, 
eps=config.layer_norm_eps) + self.mlp = GotOcr2MLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]: + """ + Partition into non-overlapping windows, padding the input if needed. + + Args: + hidden_states (tensor): + input tokens with [batch_size, height, width, channel]. + window_size (int): + window size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] + ) -> torch.Tensor: + """ + Reverse the window partition into the original sequence layout and remove any padding. + + Args: + windows (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): + original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
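+
+        For example (illustrative, using the default configuration), with `window_size=14` a
+        64x64 feature map (1024px image, 16px patches) is padded to 70x70 and partitioned into
+        5 * 5 = 25 windows per image; this method reverses that partition and crops the result
+        back to 64x64.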
+ """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + return hidden_states + + +@auto_docstring +class GotOcr2PreTrainedModel(PreTrainedModel): + config: GotOcr2Config + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = False + _supports_sdpa = False + + _can_compile_fullgraph = True + _supports_flex_attn = False + _supports_attention_backend = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GotOcr2VisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + elif isinstance(module, GotOcr2VisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for got_ocr2 vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + """ +) +class GotOcr2VisionEncoderOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class GotOcr2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class GotOcr2LayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class GotOcr2VisionNeck(nn.Module): + def __init__(self, config: GotOcr2VisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class GotOcr2VisionEncoder(GotOcr2PreTrainedModel): + _can_record_outputs = {"hidden_states": GotOcr2VisionLayer, "attentions": GotOcr2VisionAttention} + + def __init__(self, config: GotOcr2VisionConfig): + super().__init__(config) + self.config = config + self.image_size = config.image_size + self.patch_embed = 
GotOcr2PatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = GotOcr2VisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = GotOcr2VisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + @check_model_inputs(tie_last_hidden_states=False) + def forward( + self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs] + ) -> GotOcr2VisionEncoderOutput: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + hidden_states = self.neck(hidden_states) + return GotOcr2VisionEncoderOutput( + last_hidden_state=hidden_states, + ) + + +class GotOcr2MultiModalProjector(nn.Module): + def __init__(self, config: GotOcr2Config): + super().__init__() + vision_output_channels = config.vision_config.output_channels + language_hidden_size = config.text_config.hidden_size + self.conv_upsampler1 = nn.Conv2d( + vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.conv_upsampler2 = nn.Conv2d( + vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False + ) + self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size) + + def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv_upsampler1(vision_embeddings) + hidden_state = self.conv_upsampler2(hidden_state) + hidden_state = hidden_state.flatten(2).permute(0, 2, 1) + hidden_state = self.multimodal_projector(hidden_state) + return hidden_state + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GotOcr2 causal language model (or autoregressive) outputs. + """ +) +class GotOcr2CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GotOcr2 outputs, with hidden states and attentions. + """ +) +class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring( + custom_intro=""" + The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class GotOcr2Model(GotOcr2PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: GotOcr2Config): + super().__init__(config) + self.vision_tower = GotOcr2VisionEncoder(config.vision_config) + + self.multi_modal_projector = GotOcr2MultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values).last_hidden_state + return self.multi_modal_projector(image_outputs) + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, GotOcr2ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return GotOcr2ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The GOT_OCR2 model which consists of a vision backbone and a language model. 
+ """ +) +class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GotOcr2Config): + super().__init__(config) + self.model = GotOcr2Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + **kwargs, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, GotOcr2CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer + + >>> model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf").to("cuda") + >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") + + >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda") + + >>> # Generate + >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) + >>> generate_ids = model.generate( + ... **inputs, + ... do_sample=False, + ... tokenizer = processor.tokenizer, + ... stop_strings='<|im_end|>', + ... streamer=streamer, + ... max_new_tokens=4096, + ... ) + "You should keep in mind what features from the module should be used, especially + when you're planning to sell a template." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GotOcr2CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py 
b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py new file mode 100644 index 0000000000000000000000000000000000000000..0ecf39fcd03b577b406868a639d3ce8ee9425e3d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py @@ -0,0 +1,483 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging +from ..auto import CONFIG_MAPPING, AutoConfig +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, + TransformersKwargs, +) +from ..sam.modeling_sam import ( + SamMLPBlock, + SamPreTrainedModel, + SamVisionAttention, + SamVisionEncoder, + SamVisionLayer, +) + + +logger = logging.get_logger(__name__) + + +class GotOcr2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2 + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the SAM ViT-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + mlp_dim (`int`, *optional*, defaults to 3072): + The dimensionality of the MLP layer in the Transformer encoder. + """ + + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + mlp_dim=3072, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.mlp_dim = mlp_dim + + +class GotOcr2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2ForConditionalGeneration`]. It is used to instantiate a + GotOcr2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of GOT-OCR-2.0. + + e.g [stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 151859): + The image token index to encode the image prompt. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + pad_token_id (`int`, *optional*, defaults to -1): + Padding token id. 
+ + ```python + >>> from transformers import GotOcr2ForConditionalGeneration, GotOcr2Config + + >>> # Initializing a GotOcr2 style configuration + >>> configuration = GotOcr2Config() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = GotOcr2ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "got_ocr2" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=151859, + image_seq_length=576, + pad_token_id=-1, + **kwargs, + ): + self.image_token_index = image_token_index + self.image_seq_length = image_seq_length + self.pad_token_id = pad_token_id + + if vision_config is None: + self.vision_config = GotOcr2VisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = GotOcr2VisionConfig(**vision_config) + elif isinstance(vision_config, GotOcr2VisionConfig): + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "qwen2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]( + vocab_size=151860, + hidden_size=1024, + intermediate_size=2816, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=True, + rope_theta=1000000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=21, + attention_dropout=0.0, + ) + + self.text_config = text_config + + super().__init__(**kwargs) + + +class GotOcr2MLPBlock(SamMLPBlock): + pass + + +class GotOcr2VisionAttention(SamVisionAttention): + pass + + +class GotOcr2VisionLayer(SamVisionLayer): + def __init__(self, config, window_size): + super().__init__(config, window_size) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = GotOcr2VisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = GotOcr2MLPBlock(config) + self.window_size = window_size + + +class GotOcr2PreTrainedModel(SamPreTrainedModel): + pass + + +class GotOcr2VisionEncoder(SamVisionEncoder, GotOcr2PreTrainedModel): + pass + + +class GotOcr2MultiModalProjector(nn.Module): + def __init__(self, config: GotOcr2Config): + super().__init__() + vision_output_channels = config.vision_config.output_channels + language_hidden_size = config.text_config.hidden_size + self.conv_upsampler1 = nn.Conv2d( + vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.conv_upsampler2 = nn.Conv2d( + vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False + ) + self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size) + + def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv_upsampler1(vision_embeddings) + hidden_state = self.conv_upsampler2(hidden_state) + hidden_state = hidden_state.flatten(2).permute(0, 2, 1) + hidden_state = self.multimodal_projector(hidden_state) + return hidden_state + + +class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + 
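+# Illustrative sketch (not part of the upstream module): how `GotOcr2MultiModalProjector` above
+# maps vision features to language-model tokens, assuming the default `GotOcr2VisionConfig`
+# (1024x1024 input, patch size 16, output_channels=256) and the default Qwen2 text config
+# (hidden_size=1024). The SAM-style encoder and neck produce a (batch, 256, 64, 64) feature map;
+# the two stride-2 convolutions downsample it to a 16x16 grid, i.e. 256 visual tokens:
+#
+#     >>> import torch
+#     >>> projector = GotOcr2MultiModalProjector(GotOcr2Config())
+#     >>> features = torch.zeros(1, 256, 64, 64)  # (batch, output_channels, H // patch, W // patch)
+#     >>> projector(features).shape
+#     torch.Size([1, 256, 1024])
+
+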
+class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast): + pass + + +class GotOcr2PreTrainedModel(LlavaPreTrainedModel): + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, GotOcr2VisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + elif isinstance(module, GotOcr2VisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + + +class GotOcr2Model(LlavaModel): + def __init__(self, config: GotOcr2Config): + super().__init__(config) + self.vision_tower = GotOcr2VisionEncoder(config.vision_config) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values).last_hidden_state + return self.multi_modal_projector(image_outputs) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, GotOcr2ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return GotOcr2ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): + @can_return_tuple + 
@auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, GotOcr2CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer + + >>> model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf").to("cuda") + >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") + + >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda") + + >>> # Generate + >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) + >>> generate_ids = model.generate( + ... **inputs, + ... do_sample=False, + ... tokenizer = processor.tokenizer, + ... stop_strings='<|im_end|>', + ... streamer=streamer, + ... max_new_tokens=4096, + ... ) + "You should keep in mind what features from the module should be used, especially + when you're planning to sell a template." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GotOcr2CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + +__all__ = [ + "GotOcr2VisionConfig", + "GotOcr2Config", + "GotOcr2PreTrainedModel", + "GotOcr2Model", + "GotOcr2ForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py new file mode 100644 index 0000000000000000000000000000000000000000..16c062ec63ade1310971f8797291439b59bad5ac --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py @@ -0,0 +1,261 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import Optional, Union
+
+import numpy as np
+
+from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...utils import is_vision_available, logging
+
+
+if is_vision_available():
+    from ...image_utils import load_images
+
+logger = logging.get_logger(__name__)
+
+
+class GotOcr2TextKwargs(TextKwargs, total=False):
+    format: Optional[bool]
+
+
+class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
+    box: Optional[Union[list, tuple[float, float], tuple[float, float, float, float]]]
+    color: Optional[str]
+    num_image_tokens: Optional[int]
+    multi_page: Optional[bool]
+    crop_to_patches: Optional[bool]
+    min_patches: Optional[int]
+    max_patches: Optional[int]
+
+
+class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: GotOcr2TextKwargs
+    images_kwargs: GotOcr2ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "format": False,
+        },
+        "images_kwargs": {
+            "num_image_tokens": 256,
+            "multi_page": False,
+            "crop_to_patches": False,
+            "min_patches": 1,
+            "max_patches": 12,
+        },
+    }
+
+
+def preprocess_box_annotation(box: Union[list, tuple], image_size: tuple[int, int]) -> list:
+    """
+    Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000].
+    """
+    width, height = image_size
+    if len(box) == 4:
+        box[0] = int(box[0] / width * 1000)
+        box[1] = int(box[1] / height * 1000)
+        box[2] = int(box[2] / width * 1000)
+        box[3] = int(box[3] / height * 1000)
+    else:
+        raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].")
+
+    return list(box)
+
+
+class GotOcr2Processor(ProcessorMixin):
+    r"""
+    Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and
+    [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
+    tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information.
+    Args:
+        image_processor ([`GotOcr2ImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "PreTrainedTokenizerFast"
+
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+        self.message_start_token = "<|im_start|>"
+        self.message_end_token = "<|im_end|>"
+        self.img_start_token = "<img>"
+        self.img_end_token = "</img>"
+        self.img_pad_token = "<imgpad>"
+        self.image_token = "<imgpad>"  # keep the above for BC, but we need to call it `image_token`
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail."
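+
+    # Illustrative example (not part of the upstream module): the module-level
+    # `preprocess_box_annotation` helper above rescales pixel-space boxes to the
+    # [0, 1000] coordinate range the model expects, e.g. for a 1280x960 (width, height) image:
+    #
+    #     >>> preprocess_box_annotation([128, 96, 640, 480], (1280, 960))
+    #     [100, 100, 500, 500]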
+ + def _make_list_of_inputs(self, images, text, box, color, multi_page): + if not isinstance(images, (list, tuple)): + images = [images] + if multi_page: + logger.warning("Multi-page inference is enabled but only one image is passed.") + images = [images] + elif isinstance(images[0], (list, tuple)) and not multi_page: + raise ValueError("Nested images are only supported with `multi_page` set to `True`.") + elif not isinstance(images[0], (list, tuple)) and multi_page: + images = [images] + + if isinstance(text, str): + text = [text] + + if not isinstance(box[0], (list, tuple)): + # Use the same box for all images + box = [box for _ in range(len(images))] + if not isinstance(color, (list, tuple)): + color = [color for _ in range(len(images))] + + return images, text, box, color + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[GotOcr2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text` + is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and + `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to + GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + format (`bool`, *optional*): + If set, will add the format token to the query, and the model will return the OCR result with formatting. + box (`list[float]`, `list[tuple[float, float]]`, `list[tuple[float, float, float, float]]`, *optional*): + The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it + will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the + form (x1, y1, x2, y2). + color (`str`, *optional*): + The color annotation to be added to the query. The model will return the OCR result within the box with + the specified color. + multi_page (`bool`, *optional*): + If set, will enable multi-page inference. The model will return the OCR result across multiple pages. + crop_to_patches (`bool`, *optional*): + If set, will crop the image to patches. The model will return the OCR result upon the patch reference. + min_patches (`int`, *optional*): + The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. + max_patches (`int`, *optional*): + The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. 
+ + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + output_kwargs = self._merge_kwargs( + GotOcr2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + format_output = output_kwargs["text_kwargs"].pop("format") + num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens") + box = output_kwargs["images_kwargs"].pop("box", [None]) + color = output_kwargs["images_kwargs"].pop("color", None) + multi_page = output_kwargs["images_kwargs"].pop("multi_page") + + crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches") + images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page) + if multi_page: + # save the number of pages per batch + num_pages_per_batch = [len(image_group) for image_group in images] + # flatten the list of images + images = [image for image_group in images for image in image_group] + else: + num_pages_per_batch = [1 for _ in range(len(images))] + # Load images as we need to know the image size + images = load_images(images) + image_sizes = [image.size for image in images] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + num_patches_array = image_inputs.pop("num_patches") + if text is None: + text = [] + patch_indices = np.cumsum(num_pages_per_batch) + for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)): + current_patch_index = patch_indices[index - 1] if index > 0 else 0 + num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages]) + if box_single[0] is not None: + box_single = preprocess_box_annotation(box_single, image_sizes[index]) + query = ( + f"{f'[{color_single}] ' if color_single is not None else ''}" + f"{str(box_single) if box_single[0] is not None else ''} " + "OCR" + f"{' with format' if format_output else ''}" + f"{' across multi pages' if multi_page else ''}" + f"{' upon the patch reference' if crop_to_patches else ''}" + ": " + ) + prompt = ( + self.message_start_token + + self.system_query + + self.message_end_token + + self.message_start_token + + "user\n" + + self.img_start_token + + self.img_pad_token * num_image_tokens * num_patches + + self.img_end_token + + "\n" + + query + + self.message_end_token + + self.message_start_token + + "assistant\n" + ) + text.append(prompt) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + +__all__ = ["GotOcr2Processor"] diff --git 
a/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..578577f22882cdc5eea08928e274a18725cf4615 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gpt_neo import * + from .modeling_flax_gpt_neo import * + from .modeling_gpt_neo import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py new file mode 100644 index 0000000000000000000000000000000000000000..875a170277d2048dcadda9cd8f57205a11742797 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GPT Neo model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import Any, Optional + +from ... import PreTrainedTokenizer, TensorType, is_torch_available +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GPTNeoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTNeoModel`]. It is used to instantiate a GPT + Neo model according to the specified arguments, defining the model architecture. Instantiating a configuration with + the defaults will yield a similar configuration to that of the GPTNeo + [EleutherAI/gpt-neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTNeoModel`]. Vocabulary size of the model. Defines the different + tokens that can be represented by the *inputs_ids* passed to the forward method of [`GPTNeoModel`]. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_size (`int`, *optional*, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + attention_types (`List`, *optional*, defaults to `[[['global', 'local'], 12]]`): + The type of attention for each layer in a `List` of the following format `[[["attention_type"], + num_layerss]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]` Choose the + value of `attention_type` from `["global", "local"]` + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + window_size (`int`, *optional*, defaults to 256): + The size of the sliding window for local attention. + activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + resid_dropout (`float`, *optional*, defaults to 0.0): + Residual dropout used in the attention pattern. + embed_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + classifier_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The + dropout ratio for the hidden layer. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + bos_token_id (`int`, *optional*, defaults to 50256): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + The id of the end of sentence token in the vocabulary. 
+ + Example: + + ```python + >>> from transformers import GPTNeoConfig, GPTNeoModel + + >>> # Initializing a GPTNeo EleutherAI/gpt-neo-1.3B style configuration + >>> configuration = GPTNeoConfig() + + >>> # Initializing a model (with random weights) from the EleutherAI/gpt-neo-1.3B style configuration + >>> model = GPTNeoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt_neo" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=50257, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=24, + attention_types=[[["global", "local"], 12]], + num_heads=16, + intermediate_size=None, + window_size=256, + activation_function="gelu_new", + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + classifier_dropout=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_size = intermediate_size + self.window_size = window_size + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.attention_types = attention_types + self.attention_layers = self.expand_attention_types_params(attention_types) + + if len(self.attention_layers) != self.num_layers: + raise ValueError( + "Configuration for convolutional module is incorrect. " + "It is required that `len(config.attention_layers)` == `config.num_layers` " + f"but is `len(config.attention_layers) = {len(self.attention_layers)}`, " + f"`config.num_layers = {self.num_layers}`. " + "`config.attention_layers` is prepared using `config.attention_types`. " + "Please verify the value of `config.attention_types` argument." + ) + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + @staticmethod + def expand_attention_types_params(attention_types): + attentions = [] + for item in attention_types: + for _ in range(item[1]): + attentions.extend(item[0]) + return attentions + + +def custom_unfold(input, dimension, size, step): + """Custom torch.Tensor.unfold implementation to enable the export to ONNX.""" + import torch + + shape = input.size() + rank = len(shape) + sizedim = shape[dimension] + + low_indices = torch.arange(0, sizedim, step) + min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1 + indices = torch.arange(size) + low_indices[:min_length][:, None] + + s = [slice(None)] * rank + s[dimension] = indices + sliced = input[s] + + perm = list(range(0, rank + 1)) + perm.append(perm.pop(dimension + 1)) + + return sliced.permute(perm) + + +def custom_get_block_length_and_num_blocks(seq_length, window_size): + """ + Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as + original implementation uses Python variables and control flow. 
+ """ + import torch + + candidates = torch.arange(1, window_size) + remainders = torch.remainder(seq_length, candidates) + divisor_indices = remainders == 0 + divisors = candidates[divisor_indices] + largest_divisor = torch.max(divisors) + return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor") + + +class GPTNeoOnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_attention_heads(self) -> int: + return self._config.num_heads + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 + + +__all__ = ["GPTNeoConfig", "GPTNeoOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py new file mode 100644 index 0000000000000000000000000000000000000000..a6cdc50b359b0415fd165b05311eeb9bc07c7526 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -0,0 +1,687 @@ +# coding=utf-8 +# Copyright 2021 The Eleuther AI and The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +from typing import Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_gpt_neo import GPTNeoConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GPTNeoConfig" +_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B" + + +GPT_NEO_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +GPT_NEO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxGPTNeoSelfAttention(nn.Module): + config: GPTNeoConfig + attention_type: str + dtype: jnp.dtype = jnp.float32 + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and " + f"`num_heads`: {self.num_heads})." + ) + + self.attn_dropout = nn.Dropout(config.attention_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + dense = partial( + nn.Dense, + self.embed_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False) + self.out_proj = dense() + + self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") + if self.attention_type == "local": + self.causal_mask = self.causal_mask ^ jnp.tril(self.causal_mask, -config.window_size) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
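        # For reference (illustrative sizes, not taken from a specific config): the cache variables created
        # below have layout
        #   cached_key / cached_value : (batch, max_length, num_heads, head_dim), zero-initialised,
        #   cache_index               : a scalar int32 counting how many positions are already filled.
        # With batch=1, max_length=1024, 16 heads of head_dim 128 this gives cached_key.shape == (1, 1024, 16, 128).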
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + query = self.q_proj(hidden_states) * jnp.sqrt(self.head_dim).astype(self.dtype) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query) + key = self._split_heads(key) + value = self._split_heads(value) + + query_length, key_length = query.shape[1], key.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + + dropout_rng = None + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
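        # Illustrative walk-through (assumed sizes): with max_length=128 and cache_index=5, decoding one
        # new token gives key/value of shape (batch, 1, num_heads, head_dim); _concatenate_to_cache writes
        # them into row 5 of the cache, advances cache_index to 6, and builds a pad mask so the query can
        # only attend to cached positions 0..5.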
+ if self.has_variable("cache", "cached_key") or init_cache: + key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + + # usual dot product attention + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxGPTNeoAttention(nn.Module): + config: GPTNeoConfig + layer_id: int = 0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + attention_type = self.config.attention_layers[self.layer_id] + self.attention = FlaxGPTNeoSelfAttention(self.config, attention_type, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + return self.attention( + hidden_states, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + + +class FlaxGPTNeoMLP(nn.Module): + config: GPTNeoConfig + intermediate_size: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + embed_dim = self.config.hidden_size + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) + self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) + self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) + self.act = ACT2FN[self.config.activation_function] + self.dropout = nn.Dropout(rate=self.config.resid_dropout) + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxGPTNeoBlock(nn.Module): + config: GPTNeoConfig + layer_id: int = 0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + hidden_size = self.config.hidden_size + inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + self.attn = FlaxGPTNeoAttention(self.config, layer_id=self.layer_id, dtype=self.dtype) + self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + self.mlp = FlaxGPTNeoMLP(self.config, inner_dim, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + # residual connection + attn_output = outputs[0] + hidden_states = 
attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return (hidden_states,) + outputs[1:] + + +class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTNeoConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: GPTNeoConfig, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: Optional[dict] = None, + past_key_values: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTNeoAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxGPTNeoBlockCollection(nn.Module): + config: GPTNeoConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.blocks = [ + FlaxGPTNeoBlock(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTNeoModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxGPTNeoModule(nn.Module): + config: GPTNeoConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embed_dim = self.config.hidden_size + embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) + self.wte = nn.Embed( + self.config.vocab_size, + self.embed_dim, + embedding_init=embedding_init, + ) + self.wpe = nn.Embed( + self.config.max_position_embeddings, + self.embed_dim, + embedding_init=embedding_init, + ) + self.dropout = nn.Dropout(rate=self.config.embed_dropout) + self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype) + self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.wte(input_ids.astype("i4")) + position_embeds = self.wpe(position_ids.astype("i4")) + + hidden_states = input_embeds + position_embeds + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.h( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if 
output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings( + "The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.", + GPT_NEO_START_DOCSTRING, +) +class FlaxGPTNeoModel(FlaxGPTNeoPreTrainedModel): + module_class = FlaxGPTNeoModule + + +append_call_sample_docstring(FlaxGPTNeoModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) + + +class FlaxGPTNeoForCausalLMModule(nn.Module): + config: GPTNeoConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.transformer = FlaxGPTNeoModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """ + The GPTNeo Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT_NEO_START_DOCSTRING, +) +class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel): + module_class = FlaxGPTNeoForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTNeo uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring(FlaxGPTNeoForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC) + + +__all__ = ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py new file mode 100644 index 0000000000000000000000000000000000000000..69d74565745a578412af97ba686dacde513f0fb9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -0,0 +1,1192 @@ +# coding=utf-8 +# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GPT Neo model.""" + +import os +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + logging, +) +from .configuration_gpt_neo import GPTNeoConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. 
+# It means that the function will not be traced through and simply appear as a node in the graph. +_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt_neo_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + if "global_step" not in name and "adam" not in name: + array = tf.train.load_variable(tf_path, name) + array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy() + name = name.replace("attn/q", "attn/attention/q_proj/w") + name = name.replace("attn/k", "attn/attention/k_proj/w") + name = name.replace("attn/v", "attn/attention/v_proj/w") + name = name.replace("attn/o", "attn/attention/out_proj/w") + name = name.replace("norm_1", "ln_1") + name = name.replace("norm_2", "ln_2") + name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b") + name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w") + name = name.replace("conv1d_main/c_fc/bias", "c_fc/b") + name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w") + name = name.replace("conv1d_main/c_proj/bias", "c_proj/b") + + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name[5:] # skip "gpt2/" + name = name.split("/") + pointer = model.transformer + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: + array = array.transpose() + + if name == ["wte"]: + # if vocab is padded, then trim off the padding embeddings + array = array[: config.vocab_size] + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}") + + print(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + + # init the final linear layer using word embeddings + embs = model.transformer.wte.weight + lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False) + lin.weight = embs + model.set_output_embeddings(lin) + return model + + +class GPTNeoSelfAttention(nn.Module): + def __init__(self, config, attention_type, layer_id=None): + super().__init__() + self.config = config + + max_positions = config.max_position_embeddings + bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view( + 1, 1, max_positions, max_positions + ) + + # local causal self attention is a sliding window where each token can only attend to the 
previous + # window_size tokens. This is implemented by updating the causal mask such that for each token + # all other tokens are masked except the previous window_size tokens. + if attention_type == "local": + bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size)) + + self.register_buffer("bias", bias, persistent=False) + self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) + + self.attn_dropout = nn.Dropout(float(config.attention_dropout)) + self.resid_dropout = nn.Dropout(float(config.resid_dropout)) + self.is_causal = True + self.layer_id = layer_id + + self.embed_dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + # Apply sliding window masking for local attention layers + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key, value = layer_past.update(key, value, self.layer_id, cache_kwargs) + + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + return attn_output, attn_weights + + +class GPTNeoFlashAttention2(GPTNeoSelfAttention): + """ + GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
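        # For intuition (illustrative numbers): with q_len=2 and k_len=5, a bottom-right aligned causal mask
        # lets the two queries attend to keys 0..3 and 0..4 respectively, whereas a top-left aligned mask
        # would only allow them keys 0..0 and 0..1.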
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + bsz, _, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key, value = layer_past.update(key, value, self.layer_id, cache_kwargs) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + device_type = query.device.type if query.device.type != "mps" else "cpu" + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attn_dropout, + softmax_scale=1.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + return attn_output, attn_weights_reshaped + + +GPT_NEO_ATTENTION_CLASSES = { + "eager": GPTNeoSelfAttention, + "flash_attention_2": GPTNeoFlashAttention2, +} + + +class GPTNeoAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.layer_id = layer_id + self.attention_layers = config.attention_layers + self.attention_type = self.attention_layers[layer_id] + + if self.attention_type in ["global", "local"]: + self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation]( + config, self.attention_type, layer_id + ) + else: + raise NotImplementedError( + "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: " + f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only." + ) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + return self.attention( + hidden_states, + attention_mask=attention_mask, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + +class GPTNeoMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(float(config.resid_dropout)) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPTNeoBlock(GradientCheckpointingLayer): + def __init__(self, config, layer_id=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTNeoAttention(config, layer_id) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTNeoMLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output, attn_weights = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + 
hidden_states = residual + feed_forward_hidden_states + + return hidden_states, attn_weights + + +@auto_docstring +class GPTNeoPreTrainedModel(PreTrainedModel): + config: GPTNeoConfig + load_tf_weights = load_tf_weights_in_gpt_neo + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["GPTNeoBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _can_compile_fullgraph = False # TODO: needs a hybrid cache + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class GPTNeoModel(GPTNeoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.drop = nn.Dropout(float(config.embed_dropout)) + self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + seq_length = inputs_embeds.shape[1] + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + output_shape = (-1, seq_length, hidden_states.size(-1)) + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation 
== "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. 
+ cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring( + custom_intro=""" + The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """ +) +class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTNeoModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = lm_logits.to(torch.float32) + + # Flatten the tokens + loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The GPTNeo Model transformer with a sequence classification head on top (linear layer). + + [`GPTNeoForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """ +) +class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTNeoModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class GPTNeoForTokenClassification(GPTNeoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPTNeoModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
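The `problem_type` dispatch in the sequence-classification forward above can be restated as a small helper; this is an illustrative sketch of the same rule, not a library API:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def infer_loss(num_labels: int, labels: torch.Tensor):
    # Illustrative restatement of the problem_type inference used above.
    if num_labels == 1:
        return MSELoss()               # regression
    if labels.dtype in (torch.long, torch.int):
        return CrossEntropyLoss()      # single_label_classification
    return BCEWithLogitsLoss()         # multi_label_classification

infer_loss(1, torch.tensor([0.3]))                    # MSELoss
infer_loss(4, torch.tensor([2]))                      # CrossEntropyLoss
infer_loss(4, torch.tensor([[1.0, 0.0, 1.0, 0.0]]))   # BCEWithLogitsLoss
```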
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTNeoModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "GPTNeoForCausalLM", + "GPTNeoForQuestionAnswering", + "GPTNeoForSequenceClassification", + "GPTNeoForTokenClassification", + "GPTNeoModel", + "GPTNeoPreTrainedModel", + "load_tf_weights_in_gpt_neo", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94ba39d69ad638c706f6ac8491e2dea80e269929 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
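Before moving on to the GPT-NeoX Japanese files, here is a standalone sketch of the span loss assembled by the question-answering head above; shapes and positions are illustrative:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len = 2, 8
logits = torch.randn(batch_size, seq_len, 2)           # output of the qa_outputs projection
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()   # (batch_size, seq_len)
end_logits = end_logits.squeeze(-1).contiguous()

start_positions = torch.tensor([1, 50])  # 50 is deliberately out of range
end_positions = torch.tensor([3, 60])

# Out-of-range targets are clamped to `ignored_index` and then skipped by the loss.
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
```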
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gpt_neox_japanese import * + from .modeling_gpt_neox_japanese import * + from .tokenization_gpt_neox_japanese import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..320157334539b1d7c418c8cf97c8b57dc38629f7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GPTNeoX Japanese model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GPTNeoXJapaneseConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTNeoXModelJapanese`]. It is used to instantiate + a GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTNeoXJapanese + [abeja/gpt-neox-japanese-2.7b](https://huggingface.co/abeja/gpt-neox-japanese-2.7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Default configs is set as 2.7B model + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the GPTNeoXJapanese model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`GPTNeoXJapanese`]. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_multiple_size (`int`, *optional*, defaults to 4): + Dimension of the "intermediate" layer in the Transformer encoder is calculated by hidden_size * + intermediate_multiple_size. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. 
+ rotary_pct (`float`, *optional*, defaults to 1.00): + percentage of hidden dimensions to allocate to rotary embeddings + rotary_emb_base (`int`, *optional*, defaults to 10000) + base for computing rotary embeddings frequency + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden layer. 
+ Example: + + ```python + >>> from transformers import GPTNeoXJapaneseConfig, GPTNeoXJapaneseModel + + >>> # Initializing a GPTNeoXJapanese gpt-neox-japanese-2.7b style configuration + >>> configuration = GPTNeoXJapaneseConfig() + + >>> # Initializing a model (with random weights) from the gpt-neox-japanese-2.7b style configuration + >>> model = GPTNeoXJapaneseModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt_neox_japanese" + + def __init__( + self, + vocab_size=32000, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + intermediate_multiple_size=4, + hidden_act="gelu", + rotary_pct=1.00, + rotary_emb_base=10000, + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + bos_token_id=31996, + eos_token_id=31999, + rope_scaling=None, + attention_dropout=0.1, + hidden_dropout=0.0, + **kwargs, + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_multiple_size = intermediate_multiple_size + self.hidden_act = hidden_act + self.rotary_pct = rotary_pct + self.partial_rotary_factor = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.rope_theta = rotary_emb_base + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + +__all__ = ["GPTNeoXJapaneseConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..70399f376c7553fc5d7a1437b0a4b760732697ba --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -0,0 +1,755 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
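A minimal usage sketch for the configuration class above, including the legacy `{"type": ...}` rope_scaling spelling that the constructor rewrites to `rope_type` before validation (all values here are illustrative):

```python
from transformers import GPTNeoXJapaneseConfig

config = GPTNeoXJapaneseConfig(
    hidden_size=2560,
    num_attention_heads=32,
    intermediate_multiple_size=4,          # MLP width = hidden_size * intermediate_multiple_size
    rotary_pct=1.0,                        # also exposed as partial_rotary_factor
    rotary_emb_base=10000,                 # also exposed as rope_theta
    rope_scaling={"type": "linear", "factor": 2.0},  # legacy key, copied to "rope_type"
)
print(config.rope_scaling["rope_type"])    # "linear"
```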
+"""PyTorch GPTNeoX model.""" + +import math +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): + config: GPTNeoXJapaneseConfig + base_model_prefix = "gpt_neox_japanese" + _no_split_modules = ["GPTNeoXJapaneseLayer"] + _skip_keys_device_placement = "past_key_values" + + _can_compile_fullgraph = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, GPTNeoXJapaneseAttention): + if module.dense_bias is not None: + module.dense_bias.data.zero_() + + +class GPTNeoXJapaneseAttention(nn.Module): + def __init__(self, config, use_bias=False, layer_idx=None): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.layer_idx = layer_idx + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.rope_theta = config.rotary_emb_base + self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config) + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.norm_factor = math.sqrt(self.head_size) + + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # Activate bias if the last layer + self.use_bias = use_bias + self.dense_bias = nn.Parameter(torch.zeros(config.hidden_size)) if use_bias else None + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + cos, sin = position_embeddings + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query = torch.cat((query, query_pass), dim=-1).contiguous() + key = torch.cat((key, key_pass), dim=-1).contiguous() + + # Cache QKV values + if layer_past is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + return attn_output, attn_weights, self.dense_bias + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 
3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + + # [batch_size * num_heads, q_length, kv_length] + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attention_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=1.0 / self.norm_factor, + ) + + attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attention_scores = attention_scores + causal_mask + + attn_weights = nn.functional.softmax(attention_scores, dim=-1) + attn_weights = self.attention_dropout(attn_weights) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoX->GPTNeoXJapanese +class GPTNeoXJapaneseRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GPTNeoXJapaneseConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def bias_dropout_add(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool) -> Tensor: + """add bias to x, apply dropout and residual connection + + Args: + x (Tensor): main path of output + bias (Tensor): None or attn_bias of the last attention layer + residual (Optional[Tensor]): residual value + prob (float): dropout probability + training (bool): whether in training mode or not + + Returns: + Tensor: dropout(x + bias) + residual + """ + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + if residual is not None: + out = residual + out + return out + + +class GPTNeoXJapaneseMLP(nn.Module): + def __init__(self, config): + super().__init__() + intermediate_size = int(config.hidden_size * config.intermediate_multiple_size) + self.dense_h_to_4h = nn.Linear(config.hidden_size, intermediate_size, bias=False) + # Project back to h. 
+ self.dense_4h_to_h = nn.Linear(intermediate_size, config.hidden_size, bias=False) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + intermediate = self.dense_h_to_4h(hidden_states) + intermediate = self.act(intermediate) + output = self.dense_4h_to_h(intermediate) + return output + + +class GPTNeoXJapaneseLayer(nn.Module): + def __init__(self, config, layer_number): + super().__init__() + self.layer_number = layer_number + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # activate bias only last layer + self.attention = GPTNeoXJapaneseAttention( + config=config, use_bias=layer_number == config.num_hidden_layers - 1, layer_idx=layer_number + ) + self.mlp = GPTNeoXJapaneseMLP(config) + self.hidden_dropout = config.hidden_dropout + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + layer_past: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + residual = hidden_states + ln_out = self.input_layernorm(hidden_states) + attn_output, attn_weights, attn_bias = self.attention( + ln_out, + attention_mask=attention_mask, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + # attn_output = (atten_output + bias) + residual + attn_output = bias_dropout_add( + attn_output, + bias=attn_bias.expand_as(residual) if attn_bias is not None else attn_bias, + residual=residual, + prob=self.hidden_dropout, + training=self.training, + ) + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + + # attn_output = (mlp_output + mlp_bias) + atten_output + attn_output = bias_dropout_add( + mlp_output, bias=None, residual=attn_output, prob=self.hidden_dropout, training=self.training + ) + + return attn_output, attn_weights + + +@auto_docstring +class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)] + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_in + + def set_input_embeddings(self, value): + self.embed_in = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXJapaneseModel + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> model = GPTNeoXJapaneseModel.from_pretrained("abeja/gpt-neox-japanese-2.7b") + + >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + seq_length = inputs_embeds.shape[1] + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + layer_past=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = outputs[0] + if output_attentions: + all_attentions = all_attentions + (outputs[1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + # Copied from 
transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. 
+ sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring( + custom_intro=""" + GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning. + """ +) +class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["embed_out.weight"] + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.gpt_neox_japanese = GPTNeoXJapaneseModel(config) + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.embed_out + + def set_output_embeddings(self, new_embeddings): + self.embed_out = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXJapaneseForCausalLM, GPTNeoXJapaneseConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> config = GPTNeoXJapaneseConfig.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> config.is_decoder = True + >>> model = GPTNeoXJapaneseForCausalLM.from_pretrained("abeja/gpt-neox-japanese-2.7b", config=config) + + >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt_neox_japanese( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + + lm_loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "GPTNeoXJapaneseForCausalLM", + "GPTNeoXJapaneseLayer", + "GPTNeoXJapaneseModel", + "GPTNeoXJapanesePreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..584e74a8123e7cfdf31c4738a656a8417085e9a1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py @@ -0,0 +1,369 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
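Conceptually, the causal-LM loss delegated to `self.loss_function` above is the usual shifted cross-entropy with `-100` masking; the following is a standalone sketch of that idea, not the library's exact implementation:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 6, 50
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :2] = -100  # e.g. mask a prompt prefix so it does not contribute to the loss

# Position t predicts token t + 1; -100 labels are ignored.
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
    ignore_index=-100,
)
```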
+"""Tokenization classes for GPTNeoXJapanese.""" + +import collections +import json +import os +import re +import sys +from typing import Optional + +import numpy as np + +from ...tokenization_utils_fast import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"} + + +def load_vocab_and_emoji(vocab_file, emoji_file): + """Loads a vocabulary file and emoji file into a dictionary.""" + with open(emoji_file, "r", encoding="utf-8") as f: + emoji = json.loads(f.read()) + + vocab = collections.OrderedDict() + raw_vocab = collections.OrderedDict() + ids_to_tokens = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as f: + token = f.readlines() + token = [[t.rstrip("\n")] if (t == "," or "," not in t) else t.rstrip("\n").split(",") for t in token] + for idx, b in enumerate(token): + ids_to_tokens[idx] = b + raw_vocab[",".join(b)] = idx + for wd in b: + vocab[wd] = idx + + return vocab, raw_vocab, ids_to_tokens, emoji + + +class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer): + """ + This tokenizer inherits from [`PreTrainedTokenizer`] and is based on Japanese special Sub-Word-Encoding that is + used in this repository (https://github.com/tanreinama/Japanese-BPEEncoder_V2). Check the repository for details. + Japanese has a relatively large vocabulary and there is no separation between words. Furthermore, the language is a + combination of hiragana, katakana, and kanji, and variants such as "1" and "①" are often used. In order to cope + with these, this tokenizer has the following features + - Subword-by-subword segmentation, which is intermediate between byte strings and morphological analysis. + - BPEs are created for each Kanji, Hiragana, and Katakana character, and there are no BPEs that cross character + types, such as Kanji + Hiragana or Hiragana + Katakana. + - All-byte encoding that does not require . + - Independent of UTF codes such as 2-byte and 3-byte characters + - Conversion of heterographs to the same token_id + - Emoji and Emoticon are grouped into 12 types as special tags. + + Example: + + ```python + >>> from transformers import GPTNeoXJapaneseTokenizer + + >>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> # You can confirm both 慶応 and 慶應 are encoded to 17749 + >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"] + [30014, 26883, 26638, 27228, 25, 26650, 31732, 31679, 27809, 26638, 17749, 31592, 17749, 31593, 321, 1281] + + >>> # Both 慶応 and 慶應 are decoded to 慶応 + >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]) + '吾輩は猫である🐯。実は慶応(慶応)大学出身' + ``` + + Args: + vocab_file (`str`): + File containing the vocabulary. + emoji_file (`str`): + File containing the emoji. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding + bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + do_clean_text (`bool`, *optional*, defaults to `False`): + Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + emoji_file, + unk_token="<|endoftext|>", + pad_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + do_clean_text=False, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = GPTNeoXJapaneseokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + if not os.path.isfile(emoji_file): + raise ValueError( + f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google" + " pretrained model use `tokenizer = GPTNeoXJapaneseokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.do_clean_text = do_clean_text + self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file) + self.subword_tokenizer = SubWordJapaneseTokenizer( + vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji + ) + super().__init__( + unk_token=unk_token, + pad_token=pad_token, + bos_token=bos_token, + eos_token=eos_token, + do_clean_text=do_clean_text, + **kwargs, + ) + + @property + def vocab_size(self): + # self.vocab contains support for character fluctuation unique to Japanese, and has a large number of vocab + return len(self.raw_vocab) + + def get_vocab(self): + return dict(self.raw_vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.subword_tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = "".join(tokens).strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + emoji_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"] + ) + else: + vocab_file = ( + (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"] + ) + emoji_file = ( + (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"] + ) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token_index, token in self.ids_to_tokens.items(): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(",".join(token) + "\n") + index += 1 + with open(emoji_file, "w", encoding="utf-8") as writer: + json.dump(self.emoji, writer) + return vocab_file, emoji_file + + +class SubWordJapaneseTokenizer: + """ + https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the + original repository. 
+ + MIT License + + Copyright (c) 2020 tanreinama + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial portions of + the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO + THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + + def __init__(self, vocab, ids_to_tokens, emoji): + self.vocab = vocab # same as swe + self.ids_to_tokens = ids_to_tokens # same as bpe + self.emoji = emoji + self.maxlen = np.max([len(w) for w in self.vocab]) + self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)") + self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*") + self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}") + self.content_repatter4 = re.compile( + r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*" + ) + self.content_repatter5 = re.compile( + r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*" + ) + # The original version of this regex displays catastrophic backtracking behaviour. We avoid this using + # possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly + # different regex that should generally have the same behaviour in most non-pathological cases. 
+        if sys.version_info >= (3, 11):
+            self.content_repatter6 = re.compile(
+                r"(?:\d,\d{3}|[\d億])*+"
+                r"(?:\d,\d{3}|[\d万])*+"
+                r"(?:\d,\d{3}|[\d千])*+"
+                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
+                r"(?:\(税込\)|\(税抜\)|\+tax)*"
+            )
+        else:
+            self.content_repatter6 = re.compile(
+                r"(?:\d,\d{3}|[\d億万千])*"
+                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
+                r"(?:\(税込\)|\(税抜\)|\+tax)*"
+            )
+        keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
+        blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
+        self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, "<BLOCK>"))
+
+    def __len__(self):
+        return len(self.ids_to_tokens)
+
+    def clean_text(self, content):
+        content = self.content_repatter1.sub("<URL>", content)
+        content = self.content_repatter2.sub("<EMAIL>", content)
+        content = self.content_repatter3.sub("<TEL>", content)
+        content = self.content_repatter4.sub("<DATE>", content)
+        content = self.content_repatter5.sub("<DATE>", content)
+        content = self.content_repatter6.sub("<PRICE>", content)
+        content = content.translate(self.content_trans1)
+        while "<BLOCK><BLOCK>" in content:
+            content = content.replace("<BLOCK><BLOCK>", "<BLOCK>")
+        return content
+
+    def tokenize(self, text, clean=False):
+        text = text.replace(" ", "<SP>")
+        text = text.replace("　", "<SP>")
+        text = text.replace("\r\n", "<BR>")
+        text = text.replace("\n", "<BR>")
+        text = text.replace("\r", "<BR>")
+        text = text.replace("\t", "<TAB>")
+        text = text.replace("—", "ー")
+        text = text.replace("−", "ー")
+        for k, v in self.emoji["emoji"].items():
+            if k in text:
+                text = text.replace(k, v)
+        if clean:
+            text = self.clean_text(text)
+
+        def check_simbol(x):
+            e = x.encode()
+            if len(x) == 1 and len(e) == 2:
+                c = (int(e[0]) << 8) + int(e[1])
+                if (
+                    (c >= 0xC2A1 and c <= 0xC2BF)
+                    or (c >= 0xC780 and c <= 0xC783)
+                    or (c >= 0xCAB9 and c <= 0xCBBF)
+                    or (c >= 0xCC80 and c <= 0xCDA2)
+                ):
+                    return True
+            return False
+
+        def checku2e(x):
+            e = x.encode()
+            if len(x) == 1 and len(e) == 3:
+                c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2])
+                if c >= 0xE28080 and c <= 0xE2B07F:
+                    return True
+            return False
+
+        pos = 0
+        result = []
+        while pos < len(text):
+            end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3
+            candidates = []  # (token_id, token, pos)
+            for e in range(end, pos, -1):
+                wd = text[pos:e]
+                if wd in self.vocab:
+                    if wd[0] == "<" and len(wd) > 2:
+                        candidates = [(self.vocab[wd], wd, e)]
+                        break
+                    else:
+                        candidates.append((self.vocab[wd], wd, e))
+            if len(candidates) > 0:
+                # the candidate with the smallest token_id is adopted
+                _, wd, e = min(candidates, key=lambda x: x[0])
+                result.append(wd)
+                pos = e
+            else:
+                end = pos + 1
+                wd = text[pos:end]
+                if check_simbol(wd):
+                    result.append("<KIGOU>")
+                elif checku2e(wd):
+                    result.append("<U2000U2BFF>")
+                else:
+                    for i in wd.encode("utf-8"):
+                        result.append("<|byte%d|>" % i)
+                pos = end
+        return result
+
+    def convert_id_to_token(self, index, breakline="\n"):
+        words = []
+        byte_tokens = []
+        word = self.ids_to_tokens[index][0]
+        if word[:6] == "<|byte" and word[-2:] == "|>":
+            byte_tokens.append(int(word[6:-2]))
+        else:
+            if len(byte_tokens) > 0:
+                words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
+                byte_tokens = []
+            if word[:7] == "<|emoji" and word[-2:] == "|>":
+                words.append(self.emoji["emoji_inv"][word])
+            elif word == "<SP>":
+                words.append(" ")
+            elif word == "<BR>":
+                words.append(breakline)
+            elif word == "<TAB>":
+                words.append("\t")
+            elif word == "<BLOCK>":
+                words.append("▀")
+            elif word == "<KIGOU>":
+                words.append("ǀ")
+            elif word == "<U2000U2BFF>":
+                words.append("‖")
+            else:
+                words.append(word)
+        if len(byte_tokens) > 0:
+            words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
+        text = "".join(words)
+        return text
+
+
+__all__ = ["GPTNeoXJapaneseTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e12e75ef8f46a09244a1a0541bec8097ba3a94
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_gpt_oss import *
+    from .modeling_gpt_oss import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6459e9a7fd4a936033f06b95fd7fc4c1381fe31b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPT-OSS model configuration"""
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class GptOssConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`GptOssModel`]. Instantiating a
+    configuration with the defaults will yield a configuration similar to that of the GPT-OSS
+    [openai/gpt-oss-120b](https://huggingface.co/openai/gpt-oss-120b) architecture.
+
+    """
+
+    model_type = "gpt_oss"
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.sinks": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
+        "layers.*.mlp.router": "ep_router",
+        "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+        "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
+        "layers.*.mlp.experts.down_proj": "grouped_gemm",
+        "layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
+    }
+
+    def __init__(
+        self,
+        num_hidden_layers: int = 36,
+        num_local_experts: int = 128,
+        vocab_size: int = 201088,
+        hidden_size: int = 2880,
+        intermediate_size: int = 2880,
+        head_dim: int = 64,
+        num_attention_heads: int = 64,
+        num_key_value_heads: int = 8,
+        sliding_window: int = 128,
+        rope_theta: float = 150000.0,
+        tie_word_embeddings=False,
+        hidden_act: str = "silu",
+        initializer_range: float = 0.02,
+        max_position_embeddings=131072,
+        rms_norm_eps: float = 1e-5,
+        rope_scaling={
+            "rope_type": "yarn",
+            "factor": 32.0,
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
+            "truncate": False,
+            "original_max_position_embeddings": 4096,
+        },
+        attention_dropout: float = 0.0,
+        num_experts_per_tok=4,
+        router_aux_loss_coef: float = 0.9,
+        output_router_logits=False,
+        use_cache=True,
+        layer_types=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_local_experts = num_local_experts
+        self.sliding_window = sliding_window
+        self.num_experts_per_tok = num_experts_per_tok
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        self.attention_bias = True
+        self.max_position_embeddings = max_position_embeddings
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.output_router_logits = output_router_logits
+        self.use_cache = use_cache
+
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
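+        # e.g. a legacy config carrying `"rope_scaling": {"type": "yarn", ...}` is upgraded in place to the
+        # newer `"rope_type"` key below, before `rope_config_validation` runs.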
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["GptOssConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..e433ebc53ab9e6b33f25d1c5688e3791926951d9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,725 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gpt_oss.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations.hub_kernels import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_gpt_oss import GptOssConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class GptOssRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GptOssRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) # main diff with Llama + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GptOssExperts(nn.Module): + def __init__(self, 
config):
+        super().__init__()
+        self.intermediate_size = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.hidden_size = config.hidden_size
+        self.expert_dim = self.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+        self.alpha = 1.702
+        self.limit = 7.0
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        """
+        When training, it is more efficient to just loop over the experts and compute the output for each expert
+        separately, as otherwise the memory would explode.
+
+        For inference, we can sacrifice some memory and compute the output for all experts at once by repeating
+        the inputs.
+
+        Args:
+            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+            router_indices (torch.Tensor): (batch_size * token_num, top_k)
+            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+        Returns:
+            torch.Tensor
+        """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[1]
+        if hidden_states.device.type == "cpu" or self.training:
+            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+            with torch.no_grad():
+                expert_mask = torch.nn.functional.one_hot(
+                    router_indices, num_classes=num_experts + 1
+                )  # masking is also a class
+                expert_mask = expert_mask.permute(2, 1, 0)
+                # we sum over the top_k and the sequence length to find which experts
+                # are hit this time around
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit[:]:
+                # expert_idx has only one element, so we can use it as a scalar index
+                expert_idx = expert_idx[0]
+                # skip the masking index
+                if expert_idx == num_experts:
+                    continue
+                with torch.no_grad():
+                    _, token_idx = torch.where(expert_mask[expert_idx])
+                current_state = hidden_states[token_idx]
+                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+                gate = gate.clamp(min=None, max=self.limit)
+                up = up.clamp(min=-self.limit, max=self.limit)
+                glu = gate * torch.sigmoid(gate * self.alpha)
+                gated_output = (up + 1) * glu
+                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
+        else:
+            hidden_states = hidden_states.repeat(num_experts, 1)
+            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=self.limit)
+            up = up.clamp(min=-self.limit, max=self.limit)
+            glu = gate * torch.sigmoid(gate * self.alpha)
+            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.sum(dim=0)
+        return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+        return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.router = GptOssTopKRouter(config)
+        self.experts = GptOssExperts(config)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(hidden_states)  # (num_tokens, num_experts)
+        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+        return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: GptOssConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = freqs
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
+    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    first_half, second_half = torch.chunk(x, 2, dim=-1)
+    first_ = first_half * cos - second_half * sin
+    second_ = second_half * cos + first_half * sin
+    return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = _apply_rotary_emb(q, cos, sin)
+    k_embed = _apply_rotary_emb(k, cos, sin)
+    return q_embed, k_embed
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+    combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+    # This was not in the original implementation and slightly affects results: subtracting the per-row max
+    # prevents overflow in BF16/FP16 when training with bsz > 1, so we clamp the max values here.
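+    # Subtracting the row-wise max is mathematically a no-op for softmax (exp(x - m) / sum(exp(x - m)) equals
+    # exp(x) / sum(exp(x))), so only the finite-precision rounding and overflow behaviour change.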
+ + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class GptOssAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GptOssDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + 
self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class GptOssPreTrainedModel(PreTrainedModel): + config: GptOssConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GptOssDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = False + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(GptOssTopKRouter, index=0), + "hidden_states": GptOssDecoderLayer, + "attentions": GptOssAttention, + } + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_flash_attention = False + _supports_flex_attention = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GptOssRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, GptOssExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, GptOssAttention): + module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) + + +@auto_docstring +class GptOssModel(GptOssPreTrainedModel): + _no_split_modules = ["GptOssDecoderLayer"] + + def __init__(self, config: GptOssConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + 
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = GptOssRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs()
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        # It may already have been prepared by e.g. `generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+            }
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+        hidden_states = self.norm(hidden_states)
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+def load_balancing_loss_func(
+    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    r"""
+    Computes the auxiliary load-balancing loss as in Switch Transformers, implemented in PyTorch.
+
+    See Switch Transformers (https://huggingface.co/papers/2101.03961) for more details. This function implements
+    the loss presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits:
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per token; can also be interpreted as the `top-k` routing
+            parameter.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in the forward function, of shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+@auto_docstring
+class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = GptOssModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+        self.num_experts = config.num_local_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeCausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, GptOssForCausalLM
+
+        >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: MoeModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_router_logits=output_router_logits,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits,
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None:
+                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure aux_loss is on the same device
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+
+class GptOssForSequenceClassification(GenericForSequenceClassification, GptOssPreTrainedModel):
+    pass
+
+
+class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrainedModel):
+    pass
+
+
+__all__ = [
+    "GptOssForCausalLM",
+    "GptOssForSequenceClassification",
+    "GptOssForTokenClassification",
+    "GptOssModel",
+    "GptOssPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py
new file mode 100644
index 0000000000000000000000000000000000000000..276333082f56797dd86206267b08b104a62839b8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -0,0 +1,472 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_outputs import (
+    MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import (
+    TransformersKwargs,
+    auto_docstring,
+    logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..llama.modeling_llama import (
+    LlamaDecoderLayer,
+    LlamaPreTrainedModel,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    repeat_kv,
+)
+from ..mixtral.modeling_mixtral import (
+    MixtralForCausalLM,
+    MixtralForSequenceClassification,
+    MixtralForTokenClassification,
+    MixtralModel,
+)
+from ..qwen2.modeling_qwen2 import Qwen2Attention
+from .configuration_gpt_oss import GptOssConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GptOssRMSNorm(LlamaRMSNorm):
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return (self.weight * hidden_states).to(input_dtype)  # main diff with Llama
+
+
+class GptOssExperts(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.intermediate_size = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.hidden_size = config.hidden_size
+        self.expert_dim = self.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+        self.alpha = 1.702
+        self.limit = 7.0
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        """
+        When training, it is more efficient to just loop over the experts and compute the output for each expert
+        separately, as otherwise the memory would explode.
+
+        For inference, we can sacrifice some memory and compute the output for all experts at once by repeating
+        the inputs.
+
+        Args:
+            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+            router_indices (torch.Tensor): (batch_size * token_num, top_k)
+            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+        Returns:
+            torch.Tensor
+        """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[1]
+        if hidden_states.device.type == "cpu" or self.training:
+            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+            with torch.no_grad():
+                expert_mask = torch.nn.functional.one_hot(
+                    router_indices, num_classes=num_experts + 1
+                )  # masking is also a class
+                expert_mask = expert_mask.permute(2, 1, 0)
+                # we sum over the top_k and the sequence length to find which experts
+                # are hit this time around
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit[:]:
+                # expert_idx has only one element, so we can use it as a scalar index
+                expert_idx = expert_idx[0]
+                # skip the masking index
+                if expert_idx == num_experts:
+                    continue
+                with torch.no_grad():
+                    _, token_idx = torch.where(expert_mask[expert_idx])
+                current_state = hidden_states[token_idx]
+                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+                gate = gate.clamp(min=None, max=self.limit)
+                up = up.clamp(min=-self.limit, max=self.limit)
+                glu = gate * torch.sigmoid(gate * self.alpha)
+                gated_output = (up + 1) * glu
+                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
+        else:
+            hidden_states = hidden_states.repeat(num_experts, 1)
+            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=self.limit)
+            up = up.clamp(min=-self.limit, max=self.limit)
+            glu = gate * torch.sigmoid(gate * self.alpha)
+            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.sum(dim=0)
+        return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+        return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.router = GptOssTopKRouter(config)
+        self.experts = GptOssExperts(config)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(hidden_states)  # (num_tokens, num_experts)
+        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+        return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(LlamaRotaryEmbedding):
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = freqs
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def _apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    first_half, second_half = torch.chunk(x, 2, dim=-1)
+    first_ = first_half * cos - second_half * sin
+    second_ = second_half * cos + first_half * sin
+    return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = _apply_rotary_emb(q, cos, sin)
+    k_embed = _apply_rotary_emb(k, cos, sin)
+    return q_embed, k_embed
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+    combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+    # This was not in the original implementation and slightly affects results: subtracting the per-row max
+    # prevents overflow in BF16/FP16 when training with bsz > 1, so we clamp the max values here.
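+    # Subtracting the row-wise max is mathematically a no-op for softmax (exp(x - m) / sum(exp(x - m)) equals
+    # exp(x) / sum(exp(x))), so only the finite-precision rounding and overflow behaviour change.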
+ + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class GptOssAttention(Qwen2Attention): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GptOssDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + return hidden_states + + +class GptOssPreTrainedModel(LlamaPreTrainedModel): + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_sdpa = False + _supports_flash_attention = False + _supports_flex_attention = False + _can_record_outputs = { + "router_logits": OutputRecorder(GptOssTopKRouter, index=0), + "hidden_states": GptOssDecoderLayer, + "attentions": GptOssAttention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GptOssRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, GptOssExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, GptOssAttention): + module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) + + +class GptOssModel(MixtralModel): + _no_split_modules = ["GptOssDecoderLayer"] + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], 
device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class GptOssForCausalLM(MixtralForCausalLM): + pass + + +class GptOssForSequenceClassification(MixtralForSequenceClassification): + pass + + +class GptOssForTokenClassification(MixtralForTokenClassification): + pass + + +__all__ = [ + "GptOssForCausalLM", + "GptOssForSequenceClassification", + "GptOssForTokenClassification", + "GptOssModel", + "GptOssPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e477eb1d2cc2ff0e34d214ed99ad3f80afe8ab0a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
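+# Lazy-import pattern shared by the model packages: type checkers see the re-exported tokenizer symbols
+# directly, while at runtime `_LazyModule` defers the submodule import until first attribute access.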
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_gpt_sw3 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3019acfd5bcc360502ea798b22df8fff71dd2b81
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
@@ -0,0 +1,301 @@
+"""The tokenizer used by the GPT-SW3 models."""
+
+import os
+import re
+import unicodedata
+from shutil import copyfile
+from typing import Any, Optional, Union
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_torch_available, logging
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+@requires(backends=("sentencepiece",))
+class GPTSw3Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a GPTSw3 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Example usage:
+    ```python
+    >>> from transformers import GPTSw3Tokenizer
+
+    >>> tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")
+    >>> tokenizer("Svenska är kul!")["input_ids"]
+    [1814, 377, 3617, 63504]
+    ```
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        do_lower_case (`bool`, *optional*, defaults to `False`):
+            Whether or not to lowercase the input when tokenizing.
+        remove_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+        keep_accents (`bool`, *optional*, defaults to `False`):
+            Whether or not to keep accents when tokenizing.
+        pad_token (`str`, *optional*):
+            The token used for padding, for example when batching sequences of different lengths. If not provided, will
+            default to '<pad>' or '<unk>' depending on model size.
+        unk_token (`str`, *optional*):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead. If not provided, will default to '<unk>'.
+        eos_token (`str`, *optional*):
+            The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token that can be used for a downstream task; it was not seen during pretraining.
+            If not provided, will default to '<s>' or '<|endoftext|>', depending on model size.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+        whitespaces (`set`):
+            The whitespaces that are replaced in the whitespace normalization in preprocessing.
+        non_printing_characters_re (`Pattern`):
+            The compiled regular expression to remove non-printing characters in preprocessing.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        do_lower_case=False,
+        remove_space=False,
+        keep_accents=False,
+        pad_token=None,
+        unk_token=None,
+        eos_token=None,
+        bos_token=None,
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        name_or_path = kwargs.get("name_or_path")
+        if name_or_path is None:
+            logger.warning(
+                "name_or_path not provided, will work for all GPTSw3 models except gpt-sw3-7b;"
+                " if you are testing the model, this can safely be ignored"
+            )
+            name_or_path = "None"
+
+        # Default definitions for our 2 tokenizer versions, with None-checks to enable proper testing
+        eos_token = "<|endoftext|>" if eos_token is None else eos_token
+        unk_token = "<unk>" if unk_token is None else unk_token
+        if "gpt-sw3-7b" in name_or_path:
+            pad_token = unk_token if pad_token is None else pad_token
+            bos_token = eos_token if bos_token is None else bos_token
+        else:
+            pad_token = "<pad>" if pad_token is None else pad_token
+            bos_token = "<s>" if bos_token is None else bos_token
+
+        self.do_lower_case = do_lower_case
+        self.remove_space = remove_space
+        self.keep_accents = keep_accents
+        self.vocab_file = vocab_file
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        # Used for whitespace normalization in input texts
+        # fmt: off
+        self.whitespaces = {" ", " ", " ", " ", " ", " ", " ", " ", " ", " ", "", "„"}
+        # fmt: on
+
+        # Regular expression to remove non-printing characters (e.g.
some unicode control chars) in preprocessing + self.non_printing_characters_re = re.compile( + f"[{''.join(map(chr, list(range(0, 9)) + list(range(11, 32)) + list(range(127, 160)) + [160, 173, 8203]))}]" + ) + + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + @property + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.vocab_size + def vocab_size(self) -> int: + return len(self.sp_model) + + def preprocess_text(self, text: str) -> str: + """ + Returns the preprocessed text. This procedure is identical to what was used when training the tokenizer. + """ + + # Remove non-printing characters + text = self.non_printing_characters_re.sub("", text) + + # Normalize whitespaces + text = "".join([char if char not in self.whitespaces else " " for char in text]) + + # NFC Unicode normalization + text = unicodedata.normalize("NFC", text) + return text + + def _tokenize(self, text: str, **kwargs) -> list[str]: + text = self.preprocess_text(text) + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id (int) using the vocab.""" + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (int) to a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + """Returns the input string, this function is overridden to remove the default clean up.""" + return out_string + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """Converts a sequence of tokens (strings) to a single string. 
Special tokens remain intact."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
+                if not prev_is_special:
+                    out_string += " "
+
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+
+        return out_string
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.get_vocab
+    def get_vocab(self) -> dict[str, int]:
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def encode_fast(
+        self, text: Union[str, list[str]], return_tensors: Union[str, bool] = False
+    ) -> Union[list[int], list[list[int]], "torch.Tensor"]:
+        """
+        Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
+        functionality but is often much faster.
+
+        Does NOT handle special tokens correctly; these can be manually added as ids afterwards.
+
+        Does NOT support padding; padding can be applied manually to the ids afterwards.
+
+        Use default HuggingFace tokenization methods for full functionality.
+
+        Args:
+            text (`str` or `list[str]`): One or several text(s) to convert to token ids.
+            return_tensors (`str` or `bool`): Returns PyTorch tensors if set to True or "pt".
+
+        Returns:
+            `list[int]`, `list[list[int]]`, or `torch.Tensor`: The encoded text(s) as token ids.
+        """
+
+        if isinstance(text, str):
+            text = self.preprocess_text(text)
+            token_ids = self.sp_model.encode(text)
+        else:
+            text = [self.preprocess_text(t) for t in text]
+            token_ids = self.sp_model.encode(text)
+
+        if return_tensors is True or return_tensors == "pt":
+            token_ids = torch.tensor(token_ids)
+
+        return token_ids
+
+    def decode_fast(self, token_ids: Union[int, list[int]]) -> str:
+        """
+        Decodes a token id or a sequence of token ids to text using the raw SP tokenizer. This has reduced
+        functionality but is often much faster.
+
+        Args:
+            token_ids (`int` or `list[int]`): Encoded token or text as token id(s).
+
+        Returns:
+            `str`: Decoded text
+        """
+
+        return self.sp_model.decode(token_ids)
+
+
+__all__ = ["GPTSw3Tokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granite_speech/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6122581855250911b0bf951a42e22d9c354240a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_granite_speech import *
+    from .feature_extraction_granite_speech import *
+    from .modeling_granite_speech import *
+    from .processing_granite_speech import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granite_speech/configuration_granite_speech.py b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/configuration_granite_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..fede07b7b7e820e78f44538313a85d39afc811d7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/configuration_granite_speech.py
@@ -0,0 +1,198 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Config class for Granite Speech."""
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class GraniteSpeechEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`GraniteSpeechCTCEncoder`]. It is used to instantiate
+    a Granite Speech audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Granite Speech
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        input_dim (`int`, *optional*, defaults to 160):
+            Dimension of the first hidden layer of the encoder.
+        num_layers (`int`, *optional*, defaults to 10):
+            Number of encoder blocks.
+        hidden_dim (`int`, *optional*, defaults to 1024):
+            The size of the intermediate layers in the conformer encoder.
+        feedforward_mult (`int`, *optional*, defaults to 4):
+            Multiplier for the up/down projections in the encoder's feedforward layers;
+            the projections will have intermediate dim of size `hidden_dim * feedforward_mult`.
+        num_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dim_head (`int`, *optional*, defaults to 128):
+            Dimension of attention heads for each attention layer in the Transformer encoder.
+        output_dim (`int`, *optional*, defaults to 42):
+            Intermediate dimension of the feedforward projections in the conformer
+            to be added to every other encoder block's output.
+        context_size (`int`, *optional*, defaults to 200):
+            Context size to be used in conformer attention.
+        max_pos_emb (`int`, *optional*, defaults to 512):
+            Max pos embeds to be used in attention (Shaw's relative positional encoding).
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for fully connected layers in the encoder.
+        conv_kernel_size (`int`, *optional*, defaults to 15):
+            Kernel size to be used for 1D convolution in each conformer block.
+        conv_expansion_factor (`int`, *optional*, defaults to 2):
+            Intermediate dimension to be used in conformer convolutions.
+
+    Example:
+
+    ```python
+    >>> from transformers import GraniteSpeechEncoderConfig, GraniteSpeechCTCEncoder
+
+    >>> # Initializing a GraniteSpeechEncoderConfig
+    >>> configuration = GraniteSpeechEncoderConfig()
+
+    >>> # Initializing a GraniteSpeechCTCEncoder (with random weights)
+    >>> model = GraniteSpeechCTCEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "granite_speech_encoder"
+
+    def __init__(
+        self,
+        input_dim=160,
+        num_layers=10,
+        hidden_dim=1024,
+        feedforward_mult=4,
+        num_heads=8,
+        dim_head=128,
+        output_dim=42,
+        context_size=200,
+        max_pos_emb=512,
+        dropout=0.1,
+        conv_kernel_size=15,
+        conv_expansion_factor=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.input_dim = input_dim
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+        self.feedforward_mult = feedforward_mult
+        self.num_heads = num_heads
+        self.dim_head = dim_head
+        self.output_dim = output_dim
+        self.context_size = context_size
+        self.dropout = dropout
+        self.conv_kernel_size = conv_kernel_size
+        self.conv_expansion_factor = conv_expansion_factor
+        self.max_pos_emb = max_pos_emb
+
+
+class GraniteSpeechConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`GraniteSpeechForConditionalGeneration`]. It is used to instantiate a
+    Granite Speech model according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `GraniteConfig`):
+            The config object or dictionary of the text backbone.
+        encoder_config (`GraniteSpeechEncoderConfig`, *optional*):
+            The config object or dictionary of the Granite Speech CTC Encoder.
+        projector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Blip2QFormerConfig`):
+            The config object or dictionary of the audio projector.
+        audio_token_index (`int`, *optional*, defaults to 49155):
+            The audio token index to encode the audio prompt.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        has_lora_adapter (`bool`, *optional*, defaults to `True`):
+            Indicates whether or not the model has a lora adapter that should only
+            be activated when processing audio inputs.
+        downsample_rate (`int`, *optional*, defaults to 5):
+            Downsample rate for the audio feature extractor.
+        window_size (`int`, *optional*, defaults to 15):
+            Window size for the audio feature projector.
+
+    Example:
+
+    ```python
+    >>> from transformers import GraniteSpeechConfig, GraniteSpeechForConditionalGeneration
+
+    >>> # Initializing a GraniteSpeechConfig
+    >>> configuration = GraniteSpeechConfig()
+
+    >>> # Initializing a GraniteSpeechForConditionalGeneration (with random weights)
+    >>> model = GraniteSpeechForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "granite_speech"
+    attribute_map = {
+        "audio_token_id": "audio_token_index",
+    }
+    sub_configs = {
+        "text_config": AutoConfig,
+        "encoder_config": GraniteSpeechEncoderConfig,
+        "projector_config": AutoConfig,
+    }
+
+    def __init__(
+        self,
+        text_config=None,
+        encoder_config=None,
+        projector_config=None,
+        audio_token_index=49155,
+        initializer_range=0.02,
+        has_lora_adapter=True,
+        downsample_rate=5,
+        window_size=15,
+        **kwargs,
+    ):
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config.get("model_type", "granite")
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["granite"]()
+
+        if isinstance(projector_config, dict):
+            projector_config["model_type"] = projector_config.get("model_type", "blip_2_qformer")
+            projector_config = CONFIG_MAPPING[projector_config["model_type"]](**projector_config)
+        elif projector_config is None:
+            projector_config = CONFIG_MAPPING["blip_2_qformer"]()
+
+        if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
+            encoder_config = {} if encoder_config is None else encoder_config
+            encoder_config = GraniteSpeechEncoderConfig(**encoder_config)
+
+        self.text_config = text_config
+        self.encoder_config = encoder_config
+        self.projector_config = projector_config
+        self.audio_token_index = audio_token_index
+        self.initializer_range = initializer_range
+        self.has_lora_adapter = has_lora_adapter
+        self.downsample_rate = downsample_rate
+        self.window_size = window_size
+        super().__init__(**kwargs)
+
+
+__all__ = ["GraniteSpeechEncoderConfig", "GraniteSpeechConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granite_speech/feature_extraction_granite_speech.py b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/feature_extraction_granite_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..7528fc7ea5bd9efa6ae322d7fd2e40b567855359
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/feature_extraction_granite_speech.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Granite Speech.""" + +import math +from collections.abc import Sequence +from typing import Optional + +import numpy as np + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...tokenization_utils_base import AudioInput +from ...utils import is_torch_available, is_torchaudio_available, logging +from ...utils.import_utils import requires_backends + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + +if is_torchaudio_available(): + import torchaudio + + +class GraniteSpeechFeatureExtractor(FeatureExtractionMixin): + model_input_names = ["input_features"] + + def __init__( + self, + sampling_rate: int = 16000, + n_fft: int = 512, + win_length: int = 400, + hop_length: int = 160, + n_mels: int = 80, + projector_window_size: int = 15, + projector_downsample_rate: int = 5, + **kwargs, + ): + super().__init__(**kwargs) + self.sampling_rate = sampling_rate + self.melspec_kwargs = { + "sample_rate": sampling_rate, + "n_fft": n_fft, + "win_length": win_length, + "hop_length": hop_length, + "n_mels": n_mels, + } + requires_backends(self, ["torchaudio"]) + self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs) + self.projector_window_size = projector_window_size + self.projector_downsample_rate = projector_downsample_rate + + def __call__( + self, + audios: AudioInput, + device: Optional[str] = "cpu", + ) -> BatchFeature: + requires_backends(self, ["torchaudio"]) + + speech_inputs = {} + batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios) + speech_inputs["input_features"] = self._extract_mel_spectrograms( + batched_audio, + device=device, + ) + audio_embed_sizes = self._get_num_audio_features(audio_lengths) + speech_inputs["audio_embed_sizes"] = audio_embed_sizes + # TODO (@alex-jw-brooks): Currently input_features_mask is not + # a great name, because input_features and input_features_mask + # have different shapes (before/after the projector). + # + # We should align this with other multimodal models, e.g,. llava + # and qwen2audio and refactor this to ensure input_feature_mask + # has the same dimensionality as input_features, or compute it in + # the model based on the audio embedding sizes (since we do not + # have an attention mask for the audio features to infer padding from). + speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor( + audio_embed_sizes + ).view(-1, 1) + return BatchFeature(data=speech_inputs) + + def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"): + """ + Compute the Mel features to be passed to the conformer encoder. 
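+
+        Roughly, as implemented below: waveform -> mel spectrogram -> log10
+        (clipped at 1e-10) -> clamp to within 8 of the per-sample max -> affine
+        rescale (divide by 4, add 1) -> stack consecutive frame pairs, halving
+        the time dimension and doubling the feature dimension to 2 * n_mels.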
+ """ + requires_backends(self, ["torchaudio"]) + if device is not None: + melspec = self.mel_filters.to(device) + audio = audio.to(device) + else: + melspec = self.mel_filters + + bsz = audio.shape[0] + with torch.no_grad(): + # Compute mel features + mel = melspec(audio.float()) + logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_() + mx = logmel.amax(dim=(-2, -1), keepdim=True) + logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) + # remove last frame if odd + if logmel.shape[1] % 2 == 1: + logmel = logmel[:, :-1] + + # stacking and skipping by 2 + audio = logmel.reshape(bsz, -1, 2 * logmel.shape[-1]) + + return audio + + def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int]: + """ + Gets the (variable length) number of features (i.e., projector output) for the sequences + being considered. + + Args: + audio_lengths (`Sequence[int]`): + Sequence of one or more raw audio lengths. + """ + hop_length = self.melspec_kwargs["hop_length"] + effective_window_size = self.projector_window_size // self.projector_downsample_rate + + projector_lengths = [] + for raw_length in audio_lengths: + # mel sequence length computation + mel_length = raw_length // hop_length + 1 + # encoder frame takes two mel features + encoder_length = mel_length // 2 + nblocks = math.ceil(encoder_length / self.projector_window_size) + # projector output length + projector_length = nblocks * effective_window_size + projector_lengths.append(projector_length) + + return projector_lengths + + def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence["torch.Tensor", Sequence[int]]: + """ + Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking. + + Args: + audios (`AudioInput`): + Audio sequence, numpy array, or torch tensor. + """ + requires_backends(self, ["torch"]) + + # Coerce to PyTorch tensors if we have numpy arrays, since + # currently we have a dependency on torch/torchaudio anyway + if isinstance(audios, np.ndarray): + audios = torch.from_numpy(audios) + elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray): + audios = [torch.from_numpy(arr) for arr in audios] + + if isinstance(audios, torch.Tensor): + if audios.ndim == 1: + audios = audios.unsqueeze(0) + if not torch.is_floating_point(audios): + raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") + + if audios.shape[0] > 1: + logger.warning("Audio samples are already collated; assuming they all have the same length") + lengths = [audios.shape[-1]] * audios.shape[0] + return audios, lengths + + elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor): + if not torch.is_floating_point(audios[0]): + raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") + lengths = [audio.shape[-1] for audio in audios] + audios = [audio.squeeze(0) for audio in audios] + audios = torch.nn.utils.rnn.pad_sequence(audios, batch_first=True, padding_value=0.0) + return audios, lengths + + raise TypeError("Invalid audio provided. 
+
+
+__all__ = ["GraniteSpeechFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granite_speech/modeling_granite_speech.py b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/modeling_granite_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e44c9781dec683cf4b12bcb9d55d32b3635fb53
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/modeling_granite_speech.py
@@ -0,0 +1,577 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_peft_available, logging
+from ..auto import AutoModel, AutoModelForCausalLM
+from .configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Granite Speech causal language model (or autoregressive) outputs.
+    """
+)
+class GraniteSpeechCausalLMOutputWithPast(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +### Projector +class GraniteSpeechEncoderProjector(nn.Module): + def __init__(self, config: GraniteSpeechConfig): + super().__init__() + self.hidden_size = config.projector_config.hidden_size + self.downsample_rate = config.downsample_rate + self.window_size = config.window_size + self.num_queries = config.window_size // config.downsample_rate + + self.query = nn.Parameter(torch.zeros(1, self.num_queries, config.projector_config.hidden_size)) + self.query.data.normal_(mean=0.0, std=1.0) + + # By default, this will be a blip_2_qformer config + self.qformer = AutoModel.from_config(config.projector_config) + self.linear = nn.Linear(config.projector_config.hidden_size, config.text_config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() + nblocks = math.ceil(seq_len / self.window_size) + pad = nblocks * self.window_size - seq_len + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim) + + query_output = self.qformer( + query_embeds=self.query, + encoder_hidden_states=hidden_states, + encoder_attention_mask=None, + return_dict=True, + ) + query_proj = self.linear( + query_output.last_hidden_state.view(batch_size, nblocks * self.window_size // self.downsample_rate, -1) + ) + return query_proj + + +### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +class GraniteSpeechConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + + def __init__(self, config: GraniteSpeechEncoderConfig): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.up_proj = nn.Linear(config.hidden_dim, config.hidden_dim * config.feedforward_mult) + self.silu = nn.SiLU() + self.dropout = nn.Dropout(config.dropout) + self.down_proj = nn.Linear(config.hidden_dim * config.feedforward_mult, config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states = self.up_proj(hidden_states) + hidden_states = self.dropout(self.silu(hidden_states)) + hidden_states = self.down_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional embeddings. + See the following [paper](https://huggingface.co/papers/1803.02155) for more details. 
+ """ + + def __init__(self, config: GraniteSpeechEncoderConfig): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head) + self.dropout = nn.Dropout(config.dropout) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError("Context size is either less than 0 or exceeds the max_pos_emb") + + def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, self.context_size - remainder)) + + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + + # shaw's relative positional embedding + rel_pos_emb = self.rel_pos_emb(attention_dists) + # alternative computation of `pos_attn` - for readability + # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) + # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale + # einsum implementation of pos_attn - gives x30 speedup over the alternative + # TODO (@avihu111) find a fast alternative to einsum + pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, self.context_size, dtype=bool, device=hidden_states.device) + mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=pos_attn, scale=self.scale + ) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + out = self.to_out(out[:, :num_features, :]) + return self.dropout(out) + + +class GraniteSpeechConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D pointwise convolution.""" + + def __init__(self, chan_in: int, chan_out: int, kernel_size: int): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). 
+ pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D convolutional layers.""" + + def __init__(self, config: GraniteSpeechEncoderConfig): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechConformerBlock(nn.Module): + """Conformer block, consisting largely of linear layers, attention, and convolutional layers.""" + + def __init__(self, config: GraniteSpeechEncoderConfig): + super().__init__() + self.ff1 = GraniteSpeechConformerFeedForward(config) + self.attn = GraniteSpeechConformerAttention(config) + self.conv = GraniteSpeechConformerConvModule(config) + self.ff2 = GraniteSpeechConformerFeedForward(config) + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states + hidden_states = self.attn(hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = self.conv(hidden_states) + hidden_states + hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states + hidden_states = self.post_norm(hidden_states) + return hidden_states + + +class GraniteSpeechCTCEncoder(nn.Module): + def __init__(self, config: GraniteSpeechEncoderConfig): + super().__init__() + self.config = config + + # Precompute clamped relative positional encoding distances + seq = torch.arange(config.context_size) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + attention_dists = torch.clamp(relpos_dist, -config.context_size, config.context_size) + config.max_pos_emb + self.register_buffer("attention_dists", attention_dists, persistent=False) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList([GraniteSpeechConformerBlock(config) for _ in range(config.num_layers)]) + + self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) + self.num_layers = config.num_layers + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.input_linear(hidden_states) + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, 
attention_dists=self.attention_dists) + + if idx == self.num_layers // 2: + hidden_states_mid = hidden_states.clone() + hidden_states_mid = self.out(hidden_states_mid) + hidden_states += self.out_mid(nn.Softmax(dim=-1)(hidden_states_mid)) + return hidden_states + + +@auto_docstring +class GraniteSpeechPreTrainedModel(PreTrainedModel): + config: GraniteSpeechConfig + + _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this + _supports_sdpa = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + std = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, GraniteSpeechEncoderProjector): + module.query.data.normal_() + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model. + """ +) +class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): + def __init__(self, config: GraniteSpeechConfig): + super().__init__(config) + # NOTE: It doesn't matter when we initialize from config, but we should be careful + # to make sure this does not pick up the adapter_config if in the future we use + # from_pretrained or something similar, since that should be set by the composite + # model; don't need to consider it twice + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] + + self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechEncoderProjector(config) + + if config.has_lora_adapter and not is_peft_available(): + logger.warning( + "Config indicates that a lora adapter should be present, but " + "peft is not installed; this will cause the model to perform " + "incorrectly when audio inputs are provided. Please install " + "peft and reload the model!" 
+            )
+
+        self.post_init()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
+
+    def get_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
+        """Get the audio features to be merged into the multimodal embeddings."""
+        encoder_embeds = self.encoder(input_features)
+        projected_embeds = self.projector(encoder_embeds)
+        return projected_embeds
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        input_features: Optional[torch.FloatTensor] = None,
+        input_features_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **lm_kwargs,
+    ) -> Union[tuple[torch.Tensor], GraniteSpeechCausalLMOutputWithPast]:
+        r"""
+        input_features_mask (`torch.Tensor`, *optional*):
+            Mask to be applied to audio features prior to scattering into the language embeddings.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+        # TODO (@alex-jw-brooks) add an example to this docstring once models are released
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if input_features is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            # Get the base embeddings; set all audio tokens to 0 index
+            # to avoid out of vocabulary issues with the LLM embedding.
+            # Audio features will be masked into is_audio_idx indices later.
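+            # Illustrative example: with audio_token_id = 49155 and
+            # input_ids = [[17, 49155, 49155, 29]], positions 1 and 2 are embedded
+            # with token id 0 here and later overwritten with projected audio features.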
+ is_audio_idx = input_ids == self.config.audio_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[is_audio_idx] = 0 + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if input_features is not None: + if input_features.dtype != self.dtype: + input_features = input_features.to(self.dtype) + # Get the audio features from the encoder / projector + audio_embeds = self.get_audio_features(input_features) + + # Merge the audio features into the LLM embeddings + inputs_embeds = self.get_merged_audio_embeddings( + input_ids=input_ids, + audio_features=audio_embeds, + input_features_mask=input_features_mask, + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return GraniteSpeechCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + input_features=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # If we're in cached decoding stage, input_features should be None because + # input ids do not contain special audio token anymore Otherwise we need + # input feature values to be passed to the model + if cache_position[0] == 0: + model_inputs["input_features"] = input_features + return model_inputs + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. 
+            audio_features (`torch.Tensor`):
+                Audio features to be masked into the language embeddings to form multimodal embeddings.
+            input_features_mask (`torch.Tensor`, *optional*, defaults to `None`):
+                Mask to be applied to audio features prior to scattering into the language embeddings.
+        """
+        is_audio_index = input_ids == self.config.audio_token_id
+        llm_input_ids = torch.where(is_audio_index, 0, input_ids)
+        inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids)  # [bsz, # features, hidden size]
+
+        # Mask the audio features into the text embeddings
+        special_audio_mask = is_audio_index.unsqueeze(-1)
+        audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        if input_features_mask is not None:
+            # Raise if any sample's audio token count mismatches its audio feature count
+            if torch.any(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)).item():
+                raise ValueError("Number of audio tokens does not match number of audio features")
+
+            audio_features = audio_features[input_features_mask]
+
+        inputs_embeds = inputs_embeds.masked_scatter(
+            special_audio_mask,
+            audio_features,
+        )
+        return inputs_embeds
+
+    def generate(self, *args, **kwargs) -> torch.LongTensor:
+        # This model is expected to have a lora adapter, which is only
+        # enabled when considering audio inputs. As such, we override generate
+        # to conditionally enable / disable the lora adapter based on whether
+        # or not any input features were provided.
+
+        input_features = kwargs.pop("input_features", None)
+        if is_peft_available() and self._hf_peft_config_loaded:
+            if input_features is not None:
+                self.enable_adapters()
+            else:
+                self.disable_adapters()
+        return super().generate(*args, input_features=input_features, **kwargs)
+
+    def save_pretrained(self, save_directory, *args, **kwargs):
+        # overwrite save_pretrained to first save the adapter if we have one
+        if is_peft_available() and self._hf_peft_config_loaded:
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)
+
+        # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
+        self._hf_peft_config_loaded = False
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available() and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }
+
+    def _get_adapter_name(self):
+        return list(self.peft_config.keys())[0]
+
+
+__all__ = [
+    "GraniteSpeechCTCEncoder",
+    "GraniteSpeechForConditionalGeneration",
+    "GraniteSpeechPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granite_speech/processing_granite_speech.py b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/processing_granite_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..84515d173c471198b987081198aeeed9415252c9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granite_speech/processing_granite_speech.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for Granite Speech."""
+
+from typing import Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging
+from ...utils.import_utils import requires_backends
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteSpeechProcessor(ProcessorMixin):
+    attributes = ["audio_processor", "tokenizer"]
+    audio_processor_class = "GraniteSpeechFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        audio_processor,
+        tokenizer,
+        audio_token="<|audio|>",
+        chat_template=None,
+    ):
+        self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
+        super().__init__(audio_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+        audio: Union["torch.Tensor", list["torch.Tensor"]] = None,
+        device: str = "cpu",
+        images=None,
+        videos=None,
+        **kwargs,
+    ) -> BatchFeature:
+        requires_backends(self, ["torch"])
+
+        text = self._get_validated_text(text)
+        prompt_strings = text
+
+        if audio is not None:
+            # NOTE - we intentionally avoid throwing for potentially misaligned
+            # text / audio inputs here because some inference engines will
+            # trigger the conditions due to the way they call multimodal
+            # processors, e.g., vLLM.
+            audio_inputs = self.audio_processor(audio, device=device)
+
+            # TODO (@alex-jw-brooks): we should add a util to get_num_audio_tokens
+            # from feature lengths and call it here, rather than returning it
+            # from the feature extractor.
+            audio_embed_sizes = audio_inputs.pop("audio_embed_sizes")
+
+            # Expand the audio placeholders to match the feature dims; this
+            # is similar to how many VLMs handle image tokens, e.g., llava next
+            prompt_strings = []
+            num_replaced = 0
+            for sample in text:
+                while self.audio_token in sample:
+                    sample = sample.replace(
+                        self.audio_token,
+                        "<placeholder>" * audio_embed_sizes[num_replaced],
+                        1,
+                    )
+                    num_replaced += 1
+                prompt_strings.append(sample)
+
+            prompt_strings = [sample.replace("<placeholder>", self.audio_token) for sample in prompt_strings]
+        else:
+            audio_inputs = {}
+
+        if "padding" not in kwargs:
+            kwargs["padding"] = True
+        text_inputs = self.tokenizer(prompt_strings, **kwargs)
+        return BatchFeature(data={**text_inputs, **audio_inputs})
+
+    def _get_validated_text(self, text: Union[str, list]) -> list[str]:
+        if isinstance(text, str):
+            return [text]
+        elif isinstance(text, list) and isinstance(text[0], str):
+            return text
+        raise TypeError("Invalid text provided! Text should be a string or list of strings.")
+
+
+__all__ = ["GraniteSpeechProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b4db962865ee962ab366b4f2e683002e120b2b8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_granitemoehybrid import *
+    from .modeling_granitemoehybrid import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f317a063b3acf701768b0dbe7b165b5f41e847e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py
@@ -0,0 +1,256 @@
+# coding=utf-8
+# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GraniteMoeHybrid model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteMoeHybridConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`GraniteMoeHybridModel`]. It is used to
+    instantiate a GraniteMoeHybrid model according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the GraniteMoeHybrid model.
Defines the number of different tokens that + can be represented by the `inputs_ids` passed when calling [`GraniteMoeHybridModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + Only relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier. + logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits. 
+ residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier. + attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier. + num_local_experts (`int`, *optional*, defaults to 8): total number of experts. + num_experts_per_tok (`int`, *optional*, defaults to 2): number of experts per token. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient + shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts. + position_embedding_type (`str`, *optional*): Positional embedding + type to be used; defaults to None. Allowed options: `[None, "rope"]` + layer_types (`List`, *optional*): list of strings to be used as layer types. + Allowed choices: "mamba", "attention". + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used. + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used. + mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba latent state space. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embedding dimension size. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size. + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) + of the mamba mixer block. 
+ ```python + >>> from transformers import GraniteMoeHybridModel, GraniteMoeHybridConfig + + >>> # Initializing a GraniteMoeHybrid config + >>> configuration = GraniteMoeHybridConfig() + + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "granitemoehybrid" + attribute_map = { + "layers_block_type": "layer_types", + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + embedding_multiplier=1.0, + logits_scaling=1.0, + residual_multiplier=1.0, + attention_multiplier=1.0, + num_local_experts=8, + num_experts_per_tok=2, + output_router_logits=False, + router_aux_loss_coef=0.001, + shared_intermediate_size=1024, + position_embedding_type=None, + layer_types=None, + mamba_n_heads=128, + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_head="auto", + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.embedding_multiplier = embedding_multiplier + self.logits_scaling = logits_scaling + self.residual_multiplier = residual_multiplier + self.attention_multiplier = attention_multiplier + self.attention_dropout = attention_dropout + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.shared_intermediate_size = shared_intermediate_size + self.position_embedding_type = position_embedding_type + + mamba_intermediate = mamba_expand * hidden_size + + if layer_types is not None and any(layer_type not in ["mamba", "attention"] for layer_type in layer_types): + raise ValueError("layer_types must be a list strings in [`mamba` `attention`]") + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") + + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + 
self.mamba_proj_bias = mamba_proj_bias + self.mamba_expand = mamba_expand + self.layer_types = layer_types + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + if self.position_embedding_type == "rope": + rope_config_validation(self) + + # overwrite the function to use in `HybridMambaAttentionDynamicCache` + @property + def layers_block_type(self): + return self.layer_types if self.layer_types else ["mamba"] * self.num_hidden_layers + + +__all__ = ["GraniteMoeHybridConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6059720b04b1befd7d6373a6d3409ccf97a0fa --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -0,0 +1,1841 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granitemoehybrid.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
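The configuration class in `configuration_granitemoehybrid.py` above validates the Mamba head dimensions in `__init__` and exposes `layers_block_type` for the hybrid cache. Below is a minimal usage sketch, not part of the vendored diff, assuming `GraniteMoeHybridConfig` is re-exported from the top-level `transformers` package (as its `__all__` and lazy-module setup suggest):

```python
from transformers import GraniteMoeHybridConfig

# With no explicit layer_types, the layers_block_type property above
# reports every layer as a "mamba" block.
config = GraniteMoeHybridConfig(num_hidden_layers=4)
print(config.layers_block_type)  # ['mamba', 'mamba', 'mamba', 'mamba']

# Mixing layer types; with mamba_d_head="auto", the head dimension is derived
# as (mamba_expand * hidden_size) // mamba_n_heads = (2 * 1024) // 64 = 32.
hybrid = GraniteMoeHybridConfig(
    hidden_size=1024,
    num_hidden_layers=4,
    layer_types=["mamba", "mamba", "attention", "mamba"],
    mamba_expand=2,
    mamba_n_heads=64,
)
print(hybrid.mamba_d_head)  # 32
```

Per the checks shown above, an inconsistent combination (for example a `mamba_n_heads` that does not divide `mamba_expand * hidden_size`, or a `layer_types` entry other than "mamba" or "attention") raises a `ValueError` at construction time.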
+from typing import Any, Callable, Optional, TypedDict, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from .configuration_granitemoehybrid import GraniteMoeHybridConfig + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoeHybrid +# no longer copied after attention refactors +class GraniteMoeHybridAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + + self.scaling = config.attention_multiplier + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings if position_embeddings is not None else (None, None) + if position_embeddings is not None: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class HybridMambaAttentionDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
+ """ + + is_compileable = False + + def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + device=device, + dtype=dtype, + ) + ] + self.ssm_states += [ + torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + device=device, + dtype=dtype, + ) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class GraniteMoeHybridMambaLayer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the hybrid cache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.norm = GraniteMoeHybridRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(torch.ones(self.num_heads)) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for GraniteMoeHybrid will be used when running the model on a GPU") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.IntTensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. 
Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + # storing the states + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=seq_idx, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=seq_idx, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. 
Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.IntTensor] = None, + **kwargs, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx) + if seq_idx is not None: + raise NotImplementedError( + "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`" + ) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class GraniteMoeHybridRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +class GraniteMoeHybridMLP(nn.Module): + """ + MLP layer for shared experts + + Args: + config: + Configuration object with model hyperparameters. 
+ """ + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False) + self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.input_linear(hidden_states) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + hidden_states = self.output_linear(hidden_states) + return hidden_states + + +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + +class GraniteMoeHybridRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteMoeHybridParallelExperts(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Initialize the GraniteMoeHybridParallelExperts module. + The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with + many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and + [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the + [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py) + used in vllm. + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def forward(self, inputs, expert_size): + """ + Forward pass of the GraniteMoeHybridParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. 
+ """ + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = torch.cat(output_list, dim=0) + return results + + +class GraniteMoeHybridTopKGating(nn.Module): + def __init__(self, input_size: int, num_experts: int, top_k: int): + """ + Initialize the top-k gating mechanism. + Args: + input_size (`int`): + Size of the input. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top experts to select. + """ + super().__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.top_k = top_k + + self.layer = nn.Linear(input_size, num_experts, bias=False) + + def forward(self, hidden_states): + # compute the top_k routing decision + logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] + top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] + top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] + + # compute number of input given to each expert + zeros = torch.zeros( + [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device + ) # [num_tokens, num_experts] + gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] + expert_size = gates.long().sum(0) # [num_experts,] + # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`) + # (and `DataDependentOutputException`) + expert_size = expert_size.tolist() + + # sort and group input tokens according to expert assignment + top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] + _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] + batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] + + # gather the gate values for grouped input tokens + top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] + batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] + + return index_sorted_experts, batch_index, batch_gates, expert_size, logits + + +class GraniteMoeHybridMoE(nn.Module): + """ + A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = GraniteMoeHybridParallelExperts( + config.num_local_experts, self.input_size, self.hidden_size * 2 + ) + self.output_linear = GraniteMoeHybridParallelExperts( + config.num_local_experts, self.hidden_size, self.input_size + ) + + self.router = GraniteMoeHybridTopKGating( + input_size=self.input_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + ) + + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. 
+ """ + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + + expert_inputs = layer_input[batch_index] + hidden_states = self.input_linear(expert_inputs, expert_size) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_size) + + expert_outputs = expert_outputs * batch_gates[:, None] + + zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + return layer_output, router_logits + + +class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + # Either attention or mamba will be initialized, depending on the layer type. + self.self_attn = None + if config.num_local_experts > 0: + self.block_sparse_moe = GraniteMoeHybridMoE(config) + self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + self.shared_mlp = GraniteMoeHybridMLP(config) + self.mamba = None + + if config.layers_block_type[layer_idx] == "mamba": + self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx) + else: + self.self_attn = GraniteMoeHybridAttention(config, layer_idx) + self.layer_type = config.layers_block_type[layer_idx] + + # Accept 0 experts: skip MoE if num_local_experts == 0 + self.has_experts = getattr(config, "num_local_experts", 0) > 0 + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + output_router_logits: Optional[bool] = False, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_values (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.mamba is not None: + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_position=cache_position, + cache_params=past_key_values, + attention_mask=attention_mask, + **kwargs, + ) + # No attention weights for state space layers + self_attn_weights = None + else: + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.has_experts: + moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + else: + hidden_states = self.shared_mlp(hidden_states) + router_logits = None + + hidden_states = residual + hidden_states * self.residual_multiplier + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@auto_docstring +class GraniteMoeHybridPreTrainedModel(PreTrainedModel): + config: GraniteMoeHybridConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GraniteMoeHybridDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _is_stateful = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteMoeHybridParallelExperts): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, GraniteMoeHybridMambaLayer): + module.dt_bias.data.fill_(1.0) + module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) + module.D.data.fill_(1.0) + elif isinstance(module, GraniteMoeHybridRMSNormGated): + module.weight.data.fill_(1.0) + + +class GraniteMoeHybridRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GraniteMoeHybridConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" 
+ self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel): + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.embedding_multiplier = config.embedding_multiplier + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.position_embedding_type = config.position_embedding_type + self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if self.position_embedding_type == "rope" else None + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if 
self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + ## overwritten because `HybridMambaAttentionDynamicCache` is needed + if use_cache and past_key_values is None: + logger.warning_once( + "GraniteMoeHybrid requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. " + "Because one was not provided, no cache will be returned." + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # embed positions + hidden_states = inputs_embeds + + position_embeddings = None + # create position embeddings to be shared across the decoder layers + if self.rotary_emb is not None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + output_router_logits=output_router_logits, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + if layer_outputs[-1] is not None: + # append router logits only of expert layers. 
Regular MLP layers return `None` as the router logits + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. 
+ num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + device_index = routing_weights.device.index if routing_weights.device.index is not None else 0 + rank = routing_weights.shape[1] * int(device_index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) + return overall_loss * num_experts + + +class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.model = GraniteMoeHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + 
labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM + + >>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm/PowerMoE-3b") + >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Only compute necessary logits + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits / self.config.logits_scaling + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Flatten the tokens + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + 
logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or cache_position[-1] >= input_ids.shape[1] # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif use_cache: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cache_position": cache_position, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modular_granitemoehybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..55ad2b43c1bbceb07dc5a6945b6a24d80ded4615 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -0,0 +1,395 @@ +# coding=utf-8 +# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +from torch import nn + +from ...cache_utils import Cache +from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ..bamba.configuration_bamba import BambaConfig +from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache +from ..granitemoeshared.modeling_granitemoeshared import ( + GraniteFlashAttentionKwargs, + GraniteMoeSharedAttention, + GraniteMoeSharedDecoderLayer, + GraniteMoeSharedForCausalLM, + GraniteMoeSharedMLP, + GraniteMoeSharedModel, + GraniteMoeSharedPreTrainedModel, +) +from .configuration_granitemoehybrid import GraniteMoeHybridConfig + + +logger = logging.get_logger(__name__) + + +class GraniteMoeHybridAttention(GraniteMoeSharedAttention): + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__(config, layer_idx) + + +class GraniteMoeHybridMambaLayer(BambaMixer): + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__(BambaConfig(config), layer_idx) + + +class GraniteMoeHybridRMSNormGated(BambaRMSNormGated): + def __init__(self, hidden_size, eps=1e-6): + super().__init__(hidden_size, eps) + + +class GraniteMoeHybridMLP(GraniteMoeSharedMLP): + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + + +class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer): + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.shared_mlp = GraniteMoeHybridMLP(config) + # Either attention or mamba will be initialized, depending on the layer type. + self.self_attn = None + self.mamba = None + + if config.layers_block_type[layer_idx] == "mamba": + self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx) + else: + self.self_attn = GraniteMoeHybridAttention(config, layer_idx) + self.layer_type = config.layers_block_type[layer_idx] + + # Accept 0 experts: skip MoE if num_local_experts == 0 + self.has_experts = getattr(config, "num_local_experts", 0) > 0 + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + output_router_logits: Optional[bool] = False, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_values (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.mamba is not None: + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_position=cache_position, + cache_params=past_key_values, + attention_mask=attention_mask, + **kwargs, + ) + # No attention weights for state space layers + self_attn_weights = None + else: + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.has_experts: + moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + else: + hidden_states = self.shared_mlp(hidden_states) + router_logits = None + + hidden_states = residual + hidden_states * self.residual_multiplier + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel): + config: GraniteMoeHybridConfig + _no_split_modules = ["GraniteMoeHybridDecoderLayer"] + _is_stateful = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteMoeHybridMambaLayer): + module.dt_bias.data.fill_(1.0) + module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) + module.D.data.fill_(1.0) + elif isinstance(module, GraniteMoeHybridRMSNormGated): + module.weight.data.fill_(1.0) + + +class GraniteMoeHybridModel(GraniteMoeSharedModel): + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + ## overwritten because `HybridMambaAttentionDynamicCache` is needed + if use_cache and past_key_values is None: + logger.warning_once( + "GraniteMoeHybrid requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. " + "Because one was not provided, no cache will be returned." + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # embed positions + hidden_states = inputs_embeds + + position_embeddings = None + # create position embeddings to be shared across the decoder layers + if self.rotary_emb is not None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + output_router_logits=output_router_logits, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + if layer_outputs[-1] is not None: + # append router logits only of expert layers. 
Regular MLP layers return `None` as the router logits + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.model = GraniteMoeHybridModel(config) + # Initialize weights and apply final processing + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or cache_position[-1] >= input_ids.shape[1] # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif use_cache: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cache_position": cache_position, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). 
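+        # (keys already present in `model_inputs`, e.g. `attention_mask` or the sliced `input_ids`,
+        # are kept as-is and not overridden)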
+ for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/helium/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/helium/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73d0966e5c1c489944c4b539c0cc06384c985c87 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/helium/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_helium import * + from .modeling_helium import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/helium/configuration_helium.py b/venv/lib/python3.13/site-packages/transformers/models/helium/configuration_helium.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4d8d88750bc5e7becd8cd20a2ef8db1d5936b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/helium/configuration_helium.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig + + +class HeliumConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`HeliumModel`]. It is used to instantiate an Helium + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Helium 2b model. + e.g. [kyutai/helium-2b](https://huggingface.co/kyutai/helium-2b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 48000): + Vocabulary size of the Helium model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`HeliumModel`] + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 7040): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 20): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-08): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 100000.0): + The base period of the RoPE embeddings. + pad_token_id (`int`, *optional*, defaults to 3): + Padding token id. + eos_token_id (`int` | `list`, *optional*, defaults to 2): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 
+ ```python + >>> from transformers import HeliumModel, HeliumConfig + >>> # Initializing a Helium 2b style configuration + >>> configuration = HeliumConfig() + >>> # Initializing a model from the Helium 2b style configuration + >>> model = HeliumModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "helium" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=48000, + hidden_size=2560, + intermediate_size=7040, + num_hidden_layers=24, + num_attention_heads=20, + num_key_value_heads=20, + head_dim=128, + hidden_act="silu", + attention_dropout=0.0, + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-8, + use_cache=True, + tie_word_embeddings=False, + rope_theta=100000.0, + pad_token_id=3, + eos_token_id=2, + bos_token_id=1, + attention_bias=False, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["HeliumConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/helium/modeling_helium.py b/venv/lib/python3.13/site-packages/transformers/models/helium/modeling_helium.py new file mode 100644 index 0000000000000000000000000000000000000000..c0103cefa9b08fc7af627359817c0b6191d4452b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/helium/modeling_helium.py @@ -0,0 +1,497 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/helium/modular_helium.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_helium.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_helium import HeliumConfig + + +class HeliumRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class HeliumRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: HeliumConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class HeliumMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed, k_embed + + +class HeliumAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = 1 / math.sqrt(self.head_dim) + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + 
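+        # The selected backend (eager, sdpa, flash_attention_2, flex_attention, ...) is called with
+        # a shared signature; non-eager backends may return `None` for the attention weights.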
attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class HeliumDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx) + + self.mlp = HeliumMLP(config) + self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class HeliumPreTrainedModel(PreTrainedModel): + config: HeliumConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["HeliumDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": HeliumDecoderLayer, + "attentions": HeliumAttention, + } + + +@auto_docstring +class HeliumModel(HeliumPreTrainedModel): + def __init__(self, config: HeliumConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = HeliumRotaryEmbedding(config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: 
Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = HeliumModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, HeliumForCausalLM + + >>> model = HeliumForCausalLM.from_pretrained("google/helium-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/helium-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class HeliumForSequenceClassification(GenericForSequenceClassification, HeliumPreTrainedModel): + pass + + +class HeliumForTokenClassification(GenericForTokenClassification, HeliumPreTrainedModel): + pass + + +__all__ = [ + "HeliumPreTrainedModel", + "HeliumModel", + "HeliumForCausalLM", + "HeliumForSequenceClassification", + "HeliumForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/helium/modular_helium.py b/venv/lib/python3.13/site-packages/transformers/models/helium/modular_helium.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2538d438f9126980e7347aa98896b52a77f6eb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/helium/modular_helium.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Optional + +import torch +import torch.nn as nn + +from ...utils import logging +from ..gemma.modeling_gemma import GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification +from ..granite.modeling_granite import GraniteAttention +from ..llama.modeling_llama import LlamaDecoderLayer, LlamaMLP, LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding +from .configuration_helium import HeliumConfig + + +logger = logging.get_logger(__name__) + + +class HeliumRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class HeliumRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class HeliumMLP(LlamaMLP): + pass + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed, k_embed + + +class HeliumAttention(GraniteAttention): + def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.scaling = 1 / math.sqrt(self.head_dim) + + +class HeliumDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.mlp = HeliumMLP(config) + self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class HeliumPreTrainedModel(LlamaPreTrainedModel): + pass + + +class HeliumModel(HeliumPreTrainedModel, LlamaModel): + def __init__(self, config: HeliumConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = HeliumRotaryEmbedding(config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + +class HeliumForCausalLM(GemmaForCausalLM): + pass + + +class HeliumForSequenceClassification(GemmaForSequenceClassification): + pass + + +class HeliumForTokenClassification(GemmaForTokenClassification): + pass + + +__all__ = [ + "HeliumPreTrainedModel", + "HeliumModel", + "HeliumForCausalLM", + "HeliumForSequenceClassification", + "HeliumForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/herbert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/herbert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d0794a06e8cfbe4178a72e2b09d5292ddbc4fb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/herbert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_herbert import *
+    from .tokenization_herbert_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert.py b/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c6bacc87fc68bf5a4d998236e878fdf602a8f0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert.py
@@ -0,0 +1,617 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Language Team Authors, Allegro.pl, Facebook Inc. and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import re
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+    strings)
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
+def replace_unicode_punct(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+    """
+    text = text.replace("，", ",")
+    text = re.sub(r"。\s*", ". ", text)
+    text = text.replace("、", ",")
+    text = text.replace("”", '"')
+    text = text.replace("“", '"')
+    text = text.replace("∶", ":")
+    text = text.replace("：", ":")
+    text = text.replace("？", "?")
+    text = text.replace("《", '"')
+    text = text.replace("》", '"')
+    text = text.replace("）", ")")
+    text = text.replace("！", "!")
+    text = text.replace("（", "(")
+    text = text.replace("；", ";")
+    text = text.replace("１", "1")
+    text = text.replace("」", '"')
+    text = text.replace("「", '"')
+    text = text.replace("０", "0")
+    text = text.replace("３", "3")
+    text = text.replace("２", "2")
+    text = text.replace("５", "5")
+    text = text.replace("６", "6")
+    text = text.replace("９", "9")
+    text = text.replace("７", "7")
+    text = text.replace("８", "8")
+    text = text.replace("４", "4")
+    text = re.sub(r"．\s*", ". ", text)
+    text = text.replace("～", "~")
+    text = text.replace("’", "'")
+    text = text.replace("…", "...")
+    text = text.replace("━", "-")
+    text = text.replace("〈", "<")
+    text = text.replace("〉", ">")
+    text = text.replace("【", "[")
+    text = text.replace("】", "]")
+    text = text.replace("％", "%")
+    return text
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
+def remove_non_printing_char(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+    """
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat.startswith("C"):
+            continue
+        output.append(char)
+    return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+    """
+    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+    Args:
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+        do_split_on_punc (`bool`, *optional*, defaults to `True`):
+            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+            the full context of the words, such as contractions.
+    """
+
+    def __init__(
+        self,
+        do_lower_case=True,
+        never_split=None,
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        do_split_on_punc=True,
+    ):
+        if never_split is None:
+            never_split = []
+        self.do_lower_case = do_lower_case
+        self.never_split = set(never_split)
+        self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents
+        self.do_split_on_punc = do_split_on_punc
+
+    def tokenize(self, text, never_split=None):
+        """
+        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+        Args:
+            never_split (`List[str]`, *optional*)
+                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
+        """
+        # union() returns a new set by concatenating the two sets.
+        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class HerbertTokenizer(PreTrainedTokenizer): + """ + Construct a BPE tokenizer for HerBERT. + + Peculiarities: + + - uses BERT's pre-tokenizer: BaseTokenizer splits tokens on spaces, and also on punctuation. Each occurrence of a + punctuation character will be treated separately. + + - Such pretokenized input is BPE subtokenized + + This tokenizer inherits from [`XLMTokenizer`] which contains most of the methods. 
Users should refer to the
+    superclass for more information regarding methods.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        bos_token="<s>",
+        do_lowercase_and_remove_accent=False,
+        additional_special_tokens=[
+            "<special0>",
+            "<special1>",
+            "<special2>",
+            "<special3>",
+            "<special4>",
+            "<special5>",
+            "<special6>",
+            "<special7>",
+            "<special8>",
+            "<special9>",
+        ],
+        lang2id=None,
+        id2lang=None,
+        **kwargs,
+    ):
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use HerbertTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
+        # cache of sm.MosesPunctNormalizer instance
+        self.cache_moses_punct_normalizer = {}
+        # cache of sm.MosesTokenizer instance
+        self.cache_moses_tokenizer = {}
+        self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
+        # True for current supported model (v1.2.0), False for XLM-17 & 100
+        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
+        self.lang2id = lang2id
+        self.id2lang = id2lang
+        if lang2id is not None and id2lang is not None:
+            assert len(lang2id) == len(id2lang)
+
+        self.ja_word_tokenizer = None
+        self.zh_word_tokenizer = None
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            merges = merges_handle.read().split("\n")[:-1]
+        merges = [tuple(merge.split()[:2]) for merge in merges]
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {}
+
+        super().__init__(
+            unk_token=unk_token,
+            bos_token=bos_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            additional_special_tokens=additional_special_tokens,
+            lang2id=lang2id,
+            id2lang=id2lang,
+            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
+            tokenizer_file=None,
+            **kwargs,
+        )
+
+        self.bert_pre_tokenizer = BasicTokenizer(
+            do_lower_case=False,
+            never_split=self.all_special_tokens,
+            tokenize_chinese_chars=False,
+            strip_accents=False,
+        )
+
+    @property
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case
+    def do_lower_case(self):
+        return self.do_lowercase_and_remove_accent
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm
+    def moses_punct_norm(self, text, lang):
+        if lang not in self.cache_moses_punct_normalizer:
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
+            self.cache_moses_punct_normalizer[lang] = punct_normalizer
+        else:
+            punct_normalizer = self.cache_moses_punct_normalizer[lang]
+        return punct_normalizer.normalize(text)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize
+    def moses_tokenize(self, text, lang):
+        if lang not in self.cache_moses_tokenizer:
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
+            self.cache_moses_tokenizer[lang] = moses_tokenizer
+        else:
+            moses_tokenizer = self.cache_moses_tokenizer[lang]
+        return moses_tokenizer.tokenize(text, return_str=False, escape=False)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline
+    def moses_pipeline(self, text, lang):
+        text = replace_unicode_punct(text)
+        text = self.moses_punct_norm(text, lang)
+        text = remove_non_printing_char(text)
+        return text
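`moses_pipeline` chains three cleanups: the module-level `replace_unicode_punct`, a cached sacremoses `MosesPunctNormalizer`, and `remove_non_printing_char`. An illustrative run of the first two steps (editorial sketch; the demo function stands in for the module-level one, and the exact sacremoses output may vary by version):

```python
import sacremoses  # the external dependency imported in __init__

def replace_unicode_punct_demo(text):
    # stands in for replace_unicode_punct above: fold fullwidth forms to ASCII
    for src, dst in (("，", ","), ("１", "1"), ("！", "!")):
        text = text.replace(src, dst)
    return text

text = replace_unicode_punct_demo("Cześć，  świecie！ Mamy １ token.")
text = sacremoses.MosesPunctNormalizer(lang="pl").normalize(text)  # quote/space cleanup
print(text)  # roughly: 'Cześć, świecie! Mamy 1 token.'
```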
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize
+    def ja_tokenize(self, text):
+        if self.ja_word_tokenizer is None:
+            try:
+                import Mykytea
+
+                self.ja_word_tokenizer = Mykytea.Mykytea(
+                    f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
+                )
+            except (AttributeError, ImportError):
+                logger.error(
+                    "Make sure you install KyTea (https://github.com/neubig/kytea) and its Python wrapper"
+                    " (https://github.com/chezou/Mykytea-python) with the following steps"
+                )
+                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
+                logger.error("2. autoreconf -i")
+                logger.error("3. ./configure --prefix=$HOME/local")
+                logger.error("4. make && make install")
+                logger.error("5. pip install kytea")
+                raise
+        return list(self.ja_word_tokenizer.getWS(text))
+
+    @property
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
+    def bpe(self, token):
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        if token in self.cache:
+            return self.cache[token]
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        if word == "\n  </w>":
+            word = "\n</w>"
+        self.cache[token] = word
+        return word
+
+    def _tokenize(self, text):
+        pre_tokens = self.bert_pre_tokenizer.tokenize(text)
+
+        split_tokens = []
+        for token in pre_tokens:
+            if token:
+                split_tokens.extend(list(self.bpe(token).split(" ")))
+
+        return split_tokens
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index, self.unk_token)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        out_string = "".join(tokens).replace("</w>", " ").strip()
+        return out_string
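The `</w>` end-of-word marker restored above is what lets `convert_tokens_to_string` later turn joined subwords back into space-separated words. A toy re-implementation of the same merge loop (editorial sketch with a made-up merge table; the file's version uses an index-scanning inner loop but produces the same merges):

```python
def get_pairs(word):
    return {(a, b) for a, b in zip(word, word[1:])}

def toy_bpe(token, ranks):
    word = tuple(token[:-1]) + (token[-1] + "</w>",)  # mark the word boundary
    pairs = get_pairs(word)
    while pairs:
        bigram = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if bigram not in ranks:  # no learned merge applies anymore
            break
        first, second = bigram
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
        pairs = get_pairs(word)
    return " ".join(word)

print(toy_bpe("low", {("l", "o"): 0, ("lo", "w</w>"): 1}))  # 'low</w>'
```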
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. An XLM sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+
+        """
+        bos = [self.bos_token_id]
+        sep = [self.sep_token_id]
+
+        if token_ids_1 is None:
+            return bos + token_ids_0 + sep
+        return bos + token_ids_0 + sep + token_ids_1 + sep
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1]
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
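`__getstate__` nulls out `self.sm` because module objects cannot be pickled, and `__setstate__` (next) re-imports sacremoses when the tokenizer is restored. A minimal self-contained demo of the same pattern (names are illustrative; `json` stands in for the unpicklable dependency):

```python
import pickle

class Holder:
    def __init__(self):
        import json as dep  # stand-in for sacremoses
        self.dep = dep

    def __getstate__(self):
        state = self.__dict__.copy()
        state["dep"] = None  # drop the module handle so pickling succeeds
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        import json as dep  # re-acquire the dependency on load
        self.dep = dep

h = pickle.loads(pickle.dumps(Holder()))
print(h.dep.dumps({"ok": True}))  # dependency restored after unpickling
```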
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
+
+__all__ = ["HerbertTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert_fast.py b/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdc24e3c6a6e20a847043379f898fead1beb819f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/herbert/tokenization_herbert_fast.py
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Language Team Authors, Allegro.pl, Facebook Inc. and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_herbert import HerbertTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class HerbertTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "Fast" BPE tokenizer for HerBERT (backed by HuggingFace's *tokenizers* library).
+
+    Peculiarities:
+
+    - uses BERT's pre-tokenizer: BertPreTokenizer splits tokens on spaces, and also on punctuation. Each occurrence of
+      a punctuation character will be treated separately.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the methods. Users should refer to the
+    superclass for more information regarding methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = HerbertTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
+            **kwargs,
+        )
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. An HerBERT, like BERT sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """ + + cls = [self.cls_token_id] + sep = [self.sep_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["HerbertTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4adb66825445f25a1c34bd7b3b86e60eed7be85f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_idefics import * + from .image_processing_idefics import * + from .modeling_idefics import * + from .modeling_tf_idefics import * + from .processing_idefics import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/configuration_idefics.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/configuration_idefics.py new file mode 100644 index 0000000000000000000000000000000000000000..e8320b98725d0b02c85783464efe11dbb0e8f0a8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/configuration_idefics.py @@ -0,0 +1,325 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Idefics model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class IdeficsVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`) + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + intermediate_size (`int`, *optional*, defaults to 5120): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of image channels. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). 
+ """ + + model_type = "idefics_vision" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + embed_dim=768, + image_size=224, + intermediate_size=5120, + patch_size=14, + num_hidden_layers=32, + num_attention_heads=16, + num_channels=3, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + self.embed_dim = embed_dim + self.image_size = image_size + self.intermediate_size = intermediate_size + self.patch_size = patch_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + + super().__init__(**kwargs) + + +class IdeficsPerceiverConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_resampler (`bool`, *optional*, defaults to `False`): + Whether or not to use the resampler + resampler_n_latents (`int`, *optional*, defaults to 64): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + resampler_depth (`int`, *optional*, defaults to 6): + Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + resampler_n_heads (`int`, *optional*, defaults to 16): + Number of heads in each Transformer block (for multi-headed self-attention). + resampler_head_dim (`int`, *optional*, defaults to 96): + Dimensionality of each head projection in the Transformer block. + qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`): + Whether or not to use qk layer norms in perceiver + """ + + model_type = "idefics_perciever" + + def __init__( + self, + use_resampler=False, + resampler_n_latents=64, + resampler_depth=6, + resampler_n_heads=16, + resampler_head_dim=96, + qk_layer_norms_perceiver=False, + **kwargs, + ): + self.use_resampler = use_resampler + self.resampler_n_latents = resampler_n_latents + self.resampler_depth = resampler_depth + self.resampler_n_heads = resampler_n_heads + self.resampler_head_dim = resampler_head_dim + self.qk_layer_norms_perceiver = qk_layer_norms_perceiver + + super().__init__(**kwargs) + + +class IdeficsConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an + Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Idefics-9B. + + e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+
+    Args:
+        additional_vocab_size (`int`, *optional*, defaults to 0):
+            Additional vocabulary size of the model, typically for the special "<image>" token. Additional vocab tokens
+            are always trainable whereas regular vocab tokens can be frozen or not.
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`~IdeficsModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
+            Initialization type for the alphas.
+        alphas_initializer_range (`float`, *optional*, defaults to 0.0):
+            The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
+            Attention.
+        alpha_type (`str`, *optional*, defaults to `"float"`):
+            Whether the gating alphas should be vectors or single floats.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0)
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1)
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2)
+            End of stream token id.
+        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        cross_layer_interval (`int`, *optional*, default to 1)
+            Interval for cross attention (from text to image) layers.
+ qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k + freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers + freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`): + Exceptions to freezing text layers when `freeze_text_layers` is `True` + freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head + freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers + freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`): + Exceptions to freezing vision layers when `freeze_vision_layers` is `True` + use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler + vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict + perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict + + Example: + + ```python + >>> from transformers import IdeficsModel, IdeficsConfig + + >>> # Initializing a Idefics idefics-9b style configuration + >>> configuration = IdeficsConfig() + + >>> # Initializing a model from the idefics-9b style configuration + >>> model = IdeficsModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "idefics" + sub_configs = {"perceiver_config": IdeficsPerceiverConfig, "vision_config": IdeficsVisionConfig} + + def __init__( + self, + vocab_size=32000, + additional_vocab_size=0, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + dropout=0.0, + hidden_act="silu", + initializer_range=0.02, + alpha_initializer="zeros", + alphas_initializer_range=0.0, + alpha_type="float", + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + cross_layer_interval=1, + qk_layer_norms=False, + freeze_text_layers=True, + freeze_text_module_exceptions=[], + freeze_lm_head=False, + freeze_vision_layers=True, + freeze_vision_module_exceptions=[], + use_resampler=False, + vision_config=None, + perceiver_config=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.additional_vocab_size = additional_vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.alpha_initializer = alpha_initializer + self.alphas_initializer_range = alphas_initializer_range + self.alpha_type = alpha_type + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.cross_layer_interval = cross_layer_interval + self.qk_layer_norms = qk_layer_norms + self.freeze_vision_layers = freeze_vision_layers + + self.freeze_text_layers = freeze_text_layers + self.freeze_text_module_exceptions = freeze_text_module_exceptions + self.freeze_vision_module_exceptions = freeze_vision_module_exceptions + self.freeze_lm_head = freeze_lm_head + + self.use_resampler = use_resampler + + if perceiver_config is None: + self.perceiver_config = IdeficsPerceiverConfig() + elif isinstance(perceiver_config, dict): + self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config) + elif isinstance(perceiver_config, IdeficsPerceiverConfig): + self.perceiver_config = perceiver_config + + if vision_config is None: + self.vision_config = IdeficsVisionConfig() + elif isinstance(vision_config, 
dict): + self.vision_config = IdeficsVisionConfig(**vision_config) + elif isinstance(vision_config, IdeficsVisionConfig): + self.vision_config = vision_config + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since + # PretrainedConfig.from_dict first instantiates the class with the config dict and only then + # updates the config object with `kwargs` from from_pretrained, so during the instantiation + # of this object many attributes have default values and haven't yet been overridden. + # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run. + + +__all__ = ["IdeficsConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/image_processing_idefics.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/image_processing_idefics.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9085331cdeda9eeb3b94a4df00fd72f05c7776 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/image_processing_idefics.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Idefics.""" + +from typing import Callable, Optional, Union + +from PIL import Image + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + make_flat_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_torch_available + + +IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073] +IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711] + + +def convert_to_rgb(image): + # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background + # for transparent images. The call to `alpha_composite` handles this case + if image.mode == "RGB": + return image + + image_rgba = image.convert("RGBA") + background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) + alpha_composite = Image.alpha_composite(background, image_rgba) + alpha_composite = alpha_composite.convert("RGB") + return alpha_composite + + +class IdeficsImageProcessor(BaseImageProcessor): + r""" + Constructs a Idefics image processor. + + Args: + image_size (`int`, *optional*, defaults to 224): + Resize to image size + image_mean (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. 
+ image_std (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + image_num_channels (`int`, *optional*, defaults to 3): + Number of image channels. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + image_size: int = 224, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + image_num_channels: Optional[int] = 3, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.image_size = image_size + self.image_num_channels = image_num_channels + self.image_mean = image_mean if image_mean is not None else IDEFICS_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IDEFICS_STANDARD_STD + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + + def preprocess( + self, + images: ImageInput, + image_num_channels: Optional[int] = 3, + image_size: Optional[dict[str, int]] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + transform: Optional[Callable] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + **kwargs, + ) -> TensorType: + """ + Preprocess a batch of images. + + Args: + images (`ImageInput`): + A list of images to preprocess. + image_size (`int`, *optional*, defaults to `self.image_size`): + Resize to image size + image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`): + Number of image channels. + image_mean (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can + be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` + method. Can be overridden by the `image_std` parameter in the `preprocess` method. + transform (`Callable`, *optional*, defaults to `None`): + A custom transform function that accepts a single image can be passed for training. For example, + `torchvision.Compose` can be used to compose multiple transforms. If `None` - an inference mode is + assumed - and then a preset of inference-specific transforms will be applied to the images + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. 
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_idefics.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_idefics.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0a983bb2f719b87fa3b752b4b50a0da0350f39 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_idefics.py @@ -0,0 +1,1310 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics model.""" + +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_idefics import IdeficsConfig +from .perceiver import IdeficsPerceiverResampler +from .vision import IdeficsVisionEmbeddings, IdeficsVisionTransformer + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding). + """ +) +class IdeficsBaseModelOutputWithPast(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Idefics causal language model (or autoregressive) outputs. + """ +) +class IdeficsCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +def expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=False, + attention_mask=None, + encoder_outputs=None, + **model_kwargs, +): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + model_kwargs["pixel_values"] = model_kwargs.get("pixel_values") + model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings") + model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings") + model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask") + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + + if model_kwargs["image_attention_mask"] is not None: + model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select( + 0, expanded_return_idx + ) + + if model_kwargs["pixel_values"] is not None: + model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx) + + elif model_kwargs["image_encoder_embeddings"] is not None: + model_kwargs["image_encoder_embeddings"] = model_kwargs["image_encoder_embeddings"].index_select( + 0, expanded_return_idx + ) + + elif model_kwargs["perceiver_embeddings"] is not None: + model_kwargs["perceiver_embeddings"] = model_kwargs["perceiver_embeddings"].index_select( + 0, expanded_return_idx + ) + + return input_ids, model_kwargs + + +def freeze_model(model, module_exceptions=[]): + mapping = { + "LayerNorm": nn.LayerNorm, + "Linear": nn.Linear, + "Embedding": nn.Embedding, + } + module_exceptions_mapped = [mapping[m] for m in module_exceptions] + for module in model.modules(): + if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped): + module.requires_grad_(True) # Explicitly setting it to true to avoid any mistakes + else: + module.requires_grad_(False) + return model + + +class IdeficsDecoupledEmbedding(nn.Embedding): + # Derived from 
https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze: Optional[bool] = False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + Args: + num_embeddings (`int`): + Size of the dictionary of embeddings + num_additional_embeddings (`int`): + Number of additional embeddings. Only useful when `partially_freeze=True`. + embedding_dim (`int`): + The size of each embedding vector + partially_freeze (`bool`, *optional*, defaults to `False`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. + padding_idx (`int`, *optional*): + The padding index (needs to be less than num_embeddings) + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, + `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + We have 2 embeddings with different indices: the pretrained `self.weight` and + `self.additional_embedding.weight`, which is being trained. + + In order to look up the input ids, we: + 1. find the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. handle the 1st embedding: overwrite the indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + Note: for the 1st embedding lookup we could have looked up only the low indices and skipped the padding, but + then we would have to create a new tensor and populate it with 2 tensors spread out across various indices - + i.e. not a simple concat. The complex case hasn't been benchmarked; given that seqlens are usually relatively + short, it's probably not faster (or not by much), but it might be a good idea to + measure.
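+
+ A minimal sketch of this lookup (hypothetical sizes; not part of the original docstring):
+
+ ```python
+ embed = IdeficsDecoupledEmbedding(num_embeddings=10, num_additional_embeddings=2, embedding_dim=4)
+ input_ids = torch.tensor([[1, 9, 10, 11]])  # ids 10 and 11 land in the always-trainable table
+ out = embed(input_ids)  # shape [1, 4, 4]; rows for ids 10/11 come from additional_embedding
+ ```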
+
+ """ + if self.num_additional_embeddings == 0: + return F.embedding(input_ids, self.weight) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for a successful lookup replace the out-of-range input_ids with 0; these results will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return f"num_embeddings={self.num_embeddings}, num_additional_embeddings={self.num_additional_embeddings}, embedding_dim={self.embedding_dim}, partially_freeze={self.partially_freeze}" + + +class IdeficsDecoupledLinear(nn.Linear): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. If + `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + out_additional_features: int = 0, + bias: bool = True, + partially_freeze: bool = True, + device=None, + dtype=None, + ) -> None: + """ + Args: + out_additional_features (`int`, *optional*, defaults to 0): + Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`. + partially_freeze (`bool`, *optional*, defaults to `True`): + If `True`, the regular `weight` will be frozen and the extra parameters (if any) will be trainable. If + `False`, defaults to the regular behavior of `nn.Linear`.
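+
+ A minimal sketch (hypothetical sizes; not part of the original docstring):
+
+ ```python
+ lin = IdeficsDecoupledLinear(in_features=8, out_features=16, out_additional_features=2)
+ y = lin(torch.randn(1, 8))  # shape [1, 18]: 16 frozen dims + 2 always-trainable dims
+ lin.weight.requires_grad  # False, since partially_freeze defaults to True
+ lin.additional_fc.weight.requires_grad  # True
+ ```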
+ """ + super().__init__(in_features, out_features, bias, device, dtype) + self.out_additional_features = out_additional_features + self.partially_freeze = partially_freeze + + self.in_features = in_features + self.out_features = out_features + + if partially_freeze: + self.weight.requires_grad_(False) + if bias: + self.bias.requires_grad_(False) + + if out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=in_features, + out_features=out_additional_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = F.linear(input, self.weight, self.bias) + + if self.out_additional_features > 0: + additional_features = self.additional_fc(input) + output = torch.cat((output, additional_features), -1) + + return output + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return f"in_features={self.in_features}, out_features={self.out_features}, out_additional_features={self.out_additional_features}, bias={self.bias is not None}, partially_freeze={self.partially_freeze}" + + +# this was adapted from LlamaRMSNorm +class IdeficsRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + IdeficsRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# this was adapted from LlamaRotaryEmbedding +class IdeficsEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
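+ # The cos/sin tables are precomputed once up to `max_position_embeddings`; if a longer
+ # sequence arrives at run time, `forward` calls `_set_cos_sin_cache` again to regrow them.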
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
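+
+ A minimal sketch of these helpers working together (hypothetical shapes):
+
+ ```python
+ rope = IdeficsEmbedding(dim=8, max_position_embeddings=32)
+ q = torch.randn(1, 2, 4, 8)  # [batch, heads, seq_len, head_dim]
+ k = torch.randn(1, 2, 4, 8)
+ cos, sin = rope(q, seq_len=4)
+ position_ids = torch.arange(4).unsqueeze(0)
+ q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)  # same shapes as q and k
+ ```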
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# this was adapted from LlamaMLP +class IdeficsMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# this was adapted from LlamaAttention +class IdeficsAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout: float = 0.0, + is_cross_attention: bool = False, + config: Optional[PretrainedConfig] = None, + qk_layer_norms: bool = False, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.dropout = dropout + self.is_causal = True + self.scaling = self.head_dim**-0.5 + + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." 
+ ) + + self.is_cross_attention = is_cross_attention + + if not hasattr(nn.functional, "scaled_dot_product_attention"): + raise ValueError("this model requires pytorch 2.0 or higher") + + if self.is_cross_attention: + kv_input_dim = ( + self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim + ) + self.q_proj = nn.Linear( + self.hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear(kv_input_dim, num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear( + kv_input_dim, + num_heads * self.head_dim, + bias=False, + ) + else: + self.q_proj = nn.Linear( + self.hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear( + num_heads * self.head_dim, + hidden_size, + bias=False, + ) + self.rotary_emb = IdeficsEmbedding(self.head_dim) + + self.qk_layer_norms = qk_layer_norms + if self.qk_layer_norms: + self.q_layer_norm = IdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_layer_norm = IdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = self.is_cross_attention or key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if not is_cross_attention: + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + else: + _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = ( + self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_values is not None: + kv_seq_len += cache_position[0] + + if not is_cross_attention: + cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len)) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.qk_layer_norms: + query_states = self.q_layer_norm(query_states) + key_states = self.k_layer_norm(key_states) + + attention_interface: Callable = 
eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +# this was adapted from LlamaDecoderLayer +class IdeficsDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = IdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.dropout, + config=config, + layer_idx=layer_idx, + ) + self.mlp = IdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = config.dropout + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + return hidden_states + + +class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): + def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.cross_attn = IdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + is_cross_attention=True, + dropout=config.dropout, + config=config, + qk_layer_norms=config.qk_layer_norms, + layer_idx=layer_idx, + ) + self.mlp = IdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config.dropout + + self.act_cross_attn = nn.Tanh() + self.act_dense = nn.Tanh() + + if config.alpha_initializer == "zeros": + if config.alpha_type == "vector": + self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + self.alpha_dense = 
nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + elif config.alpha_type == "float": + self.alpha_cross_attn = nn.Parameter(torch.zeros(1)) + self.alpha_dense = nn.Parameter(torch.zeros(1)) + else: + raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})") + + elif config.alpha_initializer == "ones": + if config.alpha_type == "vector": + self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size)) + self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.hidden_size)) + elif config.alpha_type == "float": + self.alpha_cross_attn = nn.Parameter(torch.ones(1)) + self.alpha_dense = nn.Parameter(torch.ones(1)) + else: + raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})") + + elif config.alpha_initializer in {"normal", "gaussian", "random"}: + if config.alpha_type == "vector": + self.alpha_cross_attn = nn.Parameter( + torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size)) + ) + self.alpha_dense = nn.Parameter( + torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size)) + ) + elif config.alpha_type == "float": + self.alpha_cross_attn = nn.Parameter( + torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,)) + ) + self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))) + else: + raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})") + + else: + raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!") + + if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): + raise ValueError("Alpha parameters not initialized correctly!") + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_hidden_states: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + cross_attention_gate: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + r""" + image_hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)` + image_attention_mask (`torch.FloatTensor`, *optional*): + Image attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + cross_attention_gate (`torch.FloatTensor`, *optional*): + Gate of size `(batch, seq_len)` used to zero out cross-attention output for tokens attending to no images. + """ + if image_hidden_states is None: + raise ValueError( + "`image_hidden_states` is required for the Idefics cross attention module; these are the visual" + " features to be conditioned on." + ) + + if cross_attention_gate is None: + raise ValueError( + "`cross_attention_gate` is required for the Idefics cross attention module to zero out the cross-attention hidden_states for tokens attending to no images."
+ ) + + if past_key_values is not None: + raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.cross_attn( + hidden_states=hidden_states, + key_value_states=image_hidden_states, + attention_mask=image_attention_mask, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + # Fill in zeros for cross_attention hidden_states of tokens attending to no images + hidden_states = hidden_states.masked_fill((cross_attention_gate == 0)[:, :, None], 0.0) + hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) + hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states + + return hidden_states + + +@auto_docstring +class IdeficsPreTrainedModel(PreTrainedModel): + config: IdeficsConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] + _supports_sdpa = True + + _supports_flash_attn = False # only eager/sdpa creation is supported + _can_compile_fullgraph = False # IDEFICS cannot compile due to dynamic control flow when checking inputs + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": IdeficsDecoderLayer, + "attentions": OutputRecorder(IdeficsAttention, index=1, layer_name="self_attn"), + } + + def _init_weights(self, module): + # important: this ported version of Idefics isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the m4 code + # base should be used for training from scratch and it contains the correct code. + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, IdeficsRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, IdeficsVisionEmbeddings): + module.class_embedding.data.normal_() + elif isinstance(module, IdeficsGatedCrossAttentionLayer): + if self.config.alpha_initializer == "zeros": + module.alpha_cross_attn.data.zero_() + module.alpha_dense.data.zero_() + elif self.config.alpha_initializer == "ones": + module.alpha_cross_attn.data.fill_(1.0) + module.alpha_dense.data.fill_(1.0) + elif self.config.alpha_initializer in {"normal", "gaussian", "random"}: + module.alpha_cross_attn.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) + module.alpha_dense.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) + elif isinstance(module, IdeficsPerceiverResampler): + module.latents.data.normal_() + + +@auto_docstring +class IdeficsModel(IdeficsPreTrainedModel): + """ + Transformer decoder consisting of `config.num_hidden_layers` layers. 
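+ A gated cross-attention block is interleaved before every `config.cross_layer_interval`-th decoder layer;
+ a sketch of the indexing (hypothetical values):
+
+ ```python
+ num_hidden_layers, cross_layer_interval = 8, 4
+ [i for i in range(num_hidden_layers) if i % cross_layer_interval == 0]  # [0, 4]
+ # layer 0 is preceded by gated_cross_attn_layers[0], layer 4 by gated_cross_attn_layers[1]
+ ```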
Each layer is a [`IdeficsDecoderLayer`] + + Args: + config: IdeficsConfig + """ + + def __init__(self, config: IdeficsConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = IdeficsDecoupledEmbedding( + num_embeddings=config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_text_layers, + padding_idx=self.padding_idx, + ) + + self.image_size = config.vision_config.image_size + self.vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + self.vision_config._attn_implementation = config._attn_implementation + self.vision_model = IdeficsVisionTransformer(config.vision_config) + + # Perceiver Resampler + if config.use_resampler: + perceiver_config = config.perceiver_config + self.perceiver_resampler = IdeficsPerceiverResampler( + config, + config.vision_config.embed_dim, + perceiver_config.resampler_depth, + perceiver_config.resampler_n_heads, + perceiver_config.resampler_head_dim, + perceiver_config.resampler_n_latents, + ) + + self.layers = nn.ModuleList( + [IdeficsDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + + self.cross_layer_interval = config.cross_layer_interval + num_cross_layers = config.num_hidden_layers // self.cross_layer_interval + self.gated_cross_attn_layers = nn.ModuleList( + [IdeficsGatedCrossAttentionLayer(config, layer_idx=i) for i in range(num_cross_layers)] + ) + self.gradient_checkpointing = False + + self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + self.freeze_relevant_params(config) + + def freeze_relevant_params(self, config=None): + if config is None: + config = self.config + + if config.freeze_text_layers: + self.freeze_text_layers(config.freeze_text_module_exceptions) + + if config.freeze_vision_layers: + freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) + + def freeze_text_layers(self, module_exceptions=[]): + for module in [self.layers, self.norm]: + freeze_model(module, module_exceptions=module_exceptions) + + def freeze_vision_layers(self, module_exceptions=[]): + freeze_model(self.vision_model, module_exceptions=module_exceptions) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_encoder_embeddings: Optional[torch.FloatTensor] = None, + perceiver_embeddings: Optional[torch.FloatTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, IdeficsBaseModelOutputWithPast]: + r""" + image_encoder_embeddings (`torch.FloatTensor`, *optional*): + The output of the image encoder. + perceiver_embeddings (`torch.FloatTensor`, *optional*): + The output of the perceiver resampler. + image_attention_mask (`torch.LongTensor`, *optional*): + The attention mask for the image encoder. 
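+
+ Exactly one of `pixel_values`, `image_encoder_embeddings` and `perceiver_embeddings` may be non-`None`;
+ a sketch of the `pixel_values` shape handling done in `forward` (hypothetical sizes):
+
+ ```python
+ pixel_values = torch.randn(2, 3, 3, 224, 224)  # [batch, num_images, channels, height, width]
+ batch_size, num_images = pixel_values.shape[:2]
+ flat = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])  # fed to the vision encoder
+ ```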
+ """ + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + batch_size, seq_length, _ = inputs_embeds.shape + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_key_values_length + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[:, -seq_length:] + elif position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if sum(x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]) != 2: + raise ValueError( + "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None." + ) + + elif pixel_values is not None: + pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility + batch_size, num_images = pixel_values.shape[:2] + pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:]) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ).last_hidden_state + + elif image_encoder_embeddings is not None: + batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size() + image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=device) + image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size) + + if self.config.use_resampler: + if perceiver_embeddings is None: + perceiver_embeddings = self.perceiver_resampler(image_hidden_states) + image_seq_len, image_hidden_size = perceiver_embeddings.size(1), perceiver_embeddings.size(2) + else: + batch_size, num_images, image_seq_len, image_hidden_size = perceiver_embeddings.size() + image_hidden_states = perceiver_embeddings + elif perceiver_embeddings is None: + image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2) + else: + raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True") + + image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size) + # # Hack to use the model in full language modeling mode + # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device) + # Make image_attention_mask compatible with hidden states + text_seq_len = image_attention_mask.size(1) + image_attention_mask = image_attention_mask.unsqueeze(-1) + image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len) + image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len) + + if image_hidden_states is not None: + image_batch_size, image_sequence_length, _ = image_hidden_states.size() + image_hidden_shape = (image_batch_size, 
image_sequence_length) + if image_attention_mask is None: + image_attention_mask = torch.ones(image_hidden_shape, device=device) + image_attention_mask = self.invert_attention_mask(image_attention_mask) + else: + image_attention_mask = None + + # cross_attention_gate: + # For any tokens attending to no images, the hidden_states coming out of the cross-attention should be zeroed-out. + # `image_attention_mask` has shape [bsz, 1, num_images, hidden_size] with elements equal to either 0.0 or a very negative number. + # If any of the elements are 0.0, then the token is attending to at least one image and the gate value is 1. Otherwise the gate value is 0. + # `cross_attention_gate` has shape [bsz, seq_len] with elements equal to either 0.0 or 1.0. + cross_attention_gate = ((((image_attention_mask == 0.0).any(dim=-1)).to(dtype=self.dtype)).squeeze(dim=1)).to( + device + ) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + # TODO(ls): Add cross attention values to respective lists + if idx % self.cross_layer_interval == 0: + cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval] + hidden_states = cross_attn_block( + hidden_states, + causal_mask, + image_hidden_states, + image_attention_mask=image_attention_mask, + cross_attention_gate=cross_attention_gate, + past_key_values=None, # not implemented + **kwargs, + ) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size) + + return IdeficsBaseModelOutputWithPast( + last_hidden_state=hidden_states, + image_hidden_states=image_hidden_states, + past_key_values=past_key_values, + ) + + +class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config, vision_model=None): + super().__init__(config) + self.model = IdeficsModel(config) + + self.lm_head = IdeficsDecoupledLinear( + in_features=config.hidden_size, + out_features=config.vocab_size, + out_additional_features=config.additional_vocab_size, + bias=False, + partially_freeze=config.freeze_lm_head, + ) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of + IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. 
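+
+ A sketch of the resulting sharing when `config.tie_word_embeddings` is `True` (hypothetical):
+
+ ```python
+ out_emb = model.get_output_embeddings()  # the IdeficsDecoupledLinear lm_head
+ in_emb = model.get_input_embeddings()  # the IdeficsDecoupledEmbedding
+ out_emb.weight is in_emb.weight  # True
+ # and with additional_vocab_size > 0, additional_fc.weight is tied to additional_embedding.weight
+ ```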
+ """ + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings.weight = input_embeddings.weight + if input_embeddings.num_additional_embeddings > 0: + assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings + output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight + + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + if hasattr(output_embeddings, "out_additional_features") and hasattr( + input_embeddings, "num_additional_embeddings" + ): + output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_encoder_embeddings: Optional[torch.FloatTensor] = None, + perceiver_embeddings: Optional[torch.FloatTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, IdeficsCausalLMOutputWithPast]: + r""" + image_encoder_embeddings (`torch.FloatTensor`, *optional*): + The output of the image encoder. + perceiver_embeddings (`torch.FloatTensor`, *optional*): + The output of the perceiver resampler. + image_attention_mask (`torch.LongTensor`, *optional*): + The attention mask for the image encoder. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoProcessor, IdeficsForVisionText2Text + + >>> model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b") + >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b") + + >>> dogs_image_url_1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg" + >>> dogs_image_url_2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image2.jpeg" + + >>> prompts = [ + ... [ + ... "User:", + ... dogs_image_url_1, + ... "Describe this image.\nAssistant: An image of two dogs.\n", + ... "User:", + ... dogs_image_url_2, + ... "Describe this image.\nAssistant:", + ... ] + ... 
] + >>> inputs = processor(prompts, return_tensors="pt") + >>> generate_ids = model.generate(**inputs, max_new_tokens=6) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True) + ```""" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_encoder_embeddings=image_encoder_embeddings, + perceiver_embeddings=perceiver_embeddings, + image_attention_mask=image_attention_mask, + use_cache=use_cache, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return IdeficsCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + past_key_values=None, + cache_position=None, + pixel_values=None, + image_hidden_states=None, + image_attention_mask=None, + use_cache=None, + **kwargs, + ): + # Overwritten -- custom processing based on `config.use_resampler` + + images_kwargs = {} + if image_hidden_states is not None: + if self.config.use_resampler: + images_kwargs["perceiver_embeddings"] = image_hidden_states + else: + images_kwargs["image_encoder_embeddings"] = image_hidden_states + else: + images_kwargs["pixel_values"] = pixel_values + images_kwargs["interpolate_pos_encoding"] = kwargs.pop("interpolate_pos_encoding", False) + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + image_attention_mask=image_attention_mask, + **images_kwargs, + **kwargs, + ) + + if image_attention_mask is not None and inputs_embeds is None: + seq_length = model_inputs["input_ids"].shape[1] + model_inputs["image_attention_mask"] = image_attention_mask[:, -seq_length:] + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: dict[str, Any], + is_encoder_decoder: bool = False, + **kwargs, + ) -> dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder, + **kwargs, + ) + + if "image_attention_mask" in model_kwargs: + image_attention_mask = model_kwargs["image_attention_mask"] + last_mask = image_attention_mask[:, -1, :].unsqueeze(1) + if model_kwargs.get("use_cache", True): + model_kwargs["image_attention_mask"] = last_mask + else: + model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1) + + # Get the precomputed image_hidden_states + model_kwargs["image_hidden_states"] = outputs.image_hidden_states + return model_kwargs + + +__all__ = ["IdeficsForVisionText2Text", "IdeficsModel", "IdeficsPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_tf_idefics.py 
b/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_tf_idefics.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8e75be28f89eb70f61e874b52c458a78e41de7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/modeling_tf_idefics.py @@ -0,0 +1,1778 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 Idefics model.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import tensorflow as tf + +from ... import TFPreTrainedModel +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ModelOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + keras_serializable, + shape_list, + unpack_inputs, +) +from ...tf_utils import invert_attention_mask, scaled_dot_product_attention +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_idefics import IdeficsConfig +from .perceiver_tf import TFIdeficsPerceiverResampler +from .vision_tf import TFIdeficsVisionTransformer + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "IdeficsConfig" + + +@dataclass +class TFIdeficsBaseModelOutputWithPast(ModelOutput): + """ + Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. 
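+
+ A sketch of this legacy cache layout (hypothetical sizes):
+
+ ```python
+ # one (key, value) pair per layer, each of shape [batch_size, num_heads, seq_len, head_dim]
+ past_key_values = tuple((tf.zeros((1, 4, 5, 8)), tf.zeros((1, 4, 5, 8))) for _ in range(2))
+ ```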
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(tf.Tensor)`, *optional*): + Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + last_hidden_state: tf.Tensor | None = None + past_key_values: tuple[tuple[tf.Tensor]] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + image_hidden_states: tuple[tf.Tensor] | None = None + + +@dataclass +class TFIdeficsCausalLMOutputWithPast(ModelOutput): + """ + Base class for Idefics causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(tf.Tensor)`, *optional*): + Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + image_hidden_states: tuple[tf.Tensor] | None = None + + +def expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=False, + attention_mask=None, + encoder_outputs=None, + **model_kwargs, +): + expanded_return_idx = tf.reshape(tf.repeat(tf.range(tf.shape(input_ids)[0]), expand_size), [-1]) + input_ids = tf.gather(input_ids, expanded_return_idx) + model_kwargs["pixel_values"] = model_kwargs.get("pixel_values") + model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings") + model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings") + model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask") + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx) + + if model_kwargs["image_attention_mask"] is not None: + model_kwargs["image_attention_mask"] = tf.gather(model_kwargs["image_attention_mask"], expanded_return_idx) + + if model_kwargs["pixel_values"] is not None: + model_kwargs["pixel_values"] = tf.gather(model_kwargs["pixel_values"], expanded_return_idx) + + elif model_kwargs["image_encoder_embeddings"] is not None: + model_kwargs["image_encoder_embeddings"] = tf.gather( + model_kwargs["image_encoder_embeddings"], expanded_return_idx + ) + + elif model_kwargs["perceiver_embeddings"] is not None: + model_kwargs["perceiver_embeddings"] = tf.gather(model_kwargs["perceiver_embeddings"], expanded_return_idx) + + return input_ids, model_kwargs + + +def update_model_kwargs_for_generation(outputs, model_kwargs): + # must have this key set to at least None + if "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + else: + model_kwargs["past_key_values"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1) + + # update attention masks + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = tf.concat( + [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1 + ) + if "image_attention_mask" in model_kwargs: + image_attention_mask = model_kwargs["image_attention_mask"] + last_mask = image_attention_mask[:, -1:, ...] 
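+ # with a key/value cache only the newly generated token attends to the images on the next step, so keep just the last row of the image attention mask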
+ model_kwargs["image_attention_mask"] = last_mask + + # Get the precomputed image_hidden_states + model_kwargs["image_hidden_states"] = outputs.image_hidden_states + + return model_kwargs + + +def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids") + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1:] + + attention_mask = kwargs.get("attention_mask") + position_ids = kwargs.get("position_ids") + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + if past_key_values is not None: + position_ids = position_ids[:, -1:] + + pixel_values = kwargs.get("pixel_values") + image_encoder_embeddings = kwargs.get("image_encoder_embeddings") + perceiver_embeddings = kwargs.get("perceiver_embeddings") + image_attention_mask = kwargs.get("image_attention_mask") + interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "pixel_values": pixel_values, + "image_encoder_embeddings": image_encoder_embeddings, + "perceiver_embeddings": perceiver_embeddings, + "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, + } + + +def freeze_model(model, module_exceptions=[]): + mapping = { + "LayerNorm": tf.keras.layers.LayerNormalization, + "Dense": tf.keras.layers.Dense, + "Embedding": tf.keras.layers.Embedding, + } + module_exceptions_mapped = [mapping[m] for m in module_exceptions] + if not hasattr(model, "layers"): + model.trainable = False # It is just a layer + return model + for layer in model.layers: + if module_exceptions and any(isinstance(layer, t) for t in module_exceptions_mapped): + layer.trainable = True # Explicitly setting it to true to avoid any mistakes + else: + layer.trainable = False + return model + + +class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding): + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze: bool | None = False, + dtype=None, + **kwargs, + ) -> None: + """ + Args: + num_embeddings (`int`): + Size of the dictionary of embeddings + num_additional_embeddings (`int`): + Number of additional embeddings. Only useful when you `partially_freeze=True`. + embedding_dim (`int`): + The size of each embedding vector + partially_freeze: (`bool`, *optional*, defaults to `False`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. 
+ + Note: there are a lot of other parameters to initialize a standard `tf.keras.layers.Embedding` such as `mask_zero`, + `input_length` or `embeddings_initializer`. We are not supporting these. + """ + super().__init__( + input_dim=num_embeddings, + output_dim=embedding_dim, + dtype=dtype, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.trainable = False + + if self.num_additional_embeddings > 0: + self.additional_embedding = tf.keras.layers.Embedding( + input_dim=self.num_additional_embeddings, + output_dim=embedding_dim, + dtype=dtype, + name="additional_embedding", + ) + + def call(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return super().call(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = tf.identity(input_ids) + additional_vocab_indices = tf.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = tf.gather_nd(input_ids, additional_vocab_indices) + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids = tf.tensor_scatter_nd_update( + input_ids, + additional_vocab_indices, + # tensor filled with 0, having the same length as additional_vocab_indices + tf.zeros(tf.shape(additional_vocab_indices)[0], dtype=input_ids.dtype), + ) + full_vector = super().call(input_ids) + + # overwrite the records with high indices + full_vector = tf.tensor_scatter_nd_update(full_vector, additional_vocab_indices, additional_embeddings) + + return full_vector + + def extra_repr(self) -> str: + return f"num_embeddings={self.num_embeddings}, num_additional_embeddings={self.num_additional_embeddings}, embedding_dim={self.output_dim}, partially_freeze={self.partially_freeze}" + + +class TFIdeficsDecoupledLinear(tf.keras.layers.Layer): + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. 
If + `out_additional_features=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Dense`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + out_additional_features: int = 0, + bias: bool = True, + partially_freeze: bool = True, + **kwargs, + ) -> None: + """ + out_additional_features: int. Number of additional trainable dimensions. Only makes sense when + `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra + parameters (if any) will be trainable. If False, default to the regular behavior of tf.keras.layers.Dense. + """ + super().__init__(**kwargs) + self.out_additional_features = out_additional_features + self.partially_freeze = partially_freeze + + self.in_features = in_features + self.out_features = out_features + self.use_bias = bias + + if out_additional_features > 0: + self.additional_fc = tf.keras.layers.Dense( + units=out_additional_features, use_bias=bias, name="additional_fc" + ) + + def call(self, inputs: tf.Tensor) -> tf.Tensor: + output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True) + if self.bias is not None: + output = tf.nn.bias_add(output, self.bias) + + if self.out_additional_features > 0: + additional_features = self.additional_fc(inputs) + output = tf.concat([output, additional_features], axis=-1) + + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "in_features": self.in_features, + "out_features": self.out_features, + "out_additional_features": self.out_additional_features, + "bias": self.bias is not None, + "partially_freeze": self.partially_freeze, + } + ) + return config + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return f"in_features={self.in_features}, out_features={self.out_features}, out_additional_features={self.out_additional_features}, bias={self.bias is not None}, partially_freeze={self.partially_freeze}" + + @classmethod + def from_config(cls, config): + return cls(**config) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.weight = self.add_weight( + shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight" + ) + if self.use_bias: + self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias") + else: + self.bias = None + if getattr(self, "additional_fc", None) is not None: + with tf.name_scope(self.additional_fc.name): + self.additional_fc.build(self.in_features) + + +def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0): + """ + Make causal mask used for bi-directional self-attention, supporting both static and dynamic shapes. 
+ """ + bsz, tgt_len = input_ids_shape + + # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask) + mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min) + mask_cond = tf.range(tgt_len) + mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) + + if bsz is None: + # When batch size is dynamic, expand and tile + # so we can compile a functional model + mask = tf.expand_dims(mask, 0) + mask = tf.expand_dims(mask, 0) # shape: (1, 1, tgt_len, tgt_len + past_key_values_length) + mask = tf.tile(mask, [bsz, 1, 1, 1]) + else: + # When batch size is static, directly use broadcast_to + mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) + + return mask + + +def _expand_mask(mask, dtype, tgt_len=None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = shape_list(mask) + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1) + expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len]) + + inverted_mask = 1.0 - tf.cast(expanded_mask, dtype) + + return tf.where( + tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask + ) + + +class TFIdeficsRMSNorm(tf.keras.layers.Layer): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + """ + TFIdeficsRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def build(self, input_shape): + if self.built: + return + self.built = True + self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones") + + super().build(input_shape) + + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(tf.cast(hidden_states, tf.float32)), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [tf.float16, tf.bfloat16]: + hidden_states = tf.cast(hidden_states, self.weight.dtype) + + return self.weight * hidden_states + + +class TFIdeficsEmbedding(tf.keras.layers.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = tf.constant( + 1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) + ) + + def _compute_cos_sin(self, seq_len): + t = tf.range(seq_len, dtype=self.inv_freq.dtype) + freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication + emb = tf.concat((freqs, freqs), axis=-1) + + return tf.cos(emb), tf.sin(emb) + + def call(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len is None: + seq_len = shape_list(x)[2] + return self._compute_cos_sin(seq_len=seq_len) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return tf.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = tf.gather(cos, position_ids) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = tf.gather(sin, position_ids) + 
cos = tf.expand_dims(cos, 1) + sin = tf.expand_dims(sin, 1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class TFIdeficsMLP(tf.keras.layers.Layer): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + **kwargs, + ): + super().__init__(**kwargs) + self.gate_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj") + self.down_proj = tf.keras.layers.Dense(hidden_size, use_bias=False, name="down_proj") + self.up_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj") + self.act_fn = get_tf_activation(hidden_act) + self.intermediate_size = intermediate_size + self.hidden_size = hidden_size + + def call(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "gate_proj", None) is not None: + with tf.name_scope(self.gate_proj.name): + self.gate_proj.build(self.hidden_size) + if getattr(self, "down_proj", None) is not None: + with tf.name_scope(self.down_proj.name): + self.down_proj.build(self.intermediate_size) + if getattr(self, "up_proj", None) is not None: + with tf.name_scope(self.up_proj.name): + self.up_proj.build(self.hidden_size) + + +class TFIdeficsAttention(tf.keras.layers.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout: float = 0.0, + is_cross_attention: bool = False, + config: IdeficsConfig = None, + qk_layer_norms: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.dropout = dropout + self.config = config + self.is_causal = True + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." 
+ ) + + self.is_cross_attention = is_cross_attention + + self.q_proj = tf.keras.layers.Dense( + num_heads * self.head_dim, + use_bias=False, + name="q_proj", + ) + self.k_proj = tf.keras.layers.Dense( + num_heads * self.head_dim, + use_bias=False, + name="k_proj", + ) + self.v_proj = tf.keras.layers.Dense( + num_heads * self.head_dim, + use_bias=False, + name="v_proj", + ) + self.o_proj = tf.keras.layers.Dense( + hidden_size, + use_bias=False, + name="o_proj", + ) + self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb") + + self.qk_layer_norms = qk_layer_norms + if self.qk_layer_norms: + self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm") + self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + past_key_value: tuple[tf.Tensor] | None = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> tuple[tf.Tensor, tf.Tensor | None, tuple[tf.Tensor] | None]: + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = self.is_cross_attention or key_value_states is not None + + bsz, q_len, _ = shape_list(hidden_states) + + query_states = self._shape(self.q_proj(hidden_states), q_len, bsz) + if not is_cross_attention: + key_states = self._shape(self.k_proj(hidden_states), q_len, bsz) + value_states = self._shape(self.v_proj(hidden_states), q_len, bsz) + else: + _, kv_len, _ = shape_list(key_value_states) # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz) + value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz) + + kv_seq_len = shape_list(key_states)[-2] + if past_key_value is not None: + kv_seq_len += shape_list(past_key_value[0])[-2] + if not is_cross_attention: + # Below is to allow symbolic tensors compilation; use an element-wise maximum when the length is dynamic + if tf.is_tensor(kv_seq_len): + seq_len = tf.maximum(kv_seq_len, q_len) + else: + seq_len = max(kv_seq_len, q_len) + cos, sin = self.rotary_emb(value_states, seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.qk_layer_norms: + query_states = self.q_layer_norm(query_states) + key_states = self.k_layer_norm(key_states) + + tf.debugging.assert_equal( + tf.shape(attention_mask), + [bsz, 1, q_len, kv_seq_len], + message=f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}", + ) + + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + tf.debugging.assert_equal( + tf.shape(attn_output), + [bsz, self.num_heads, q_len, self.head_dim], + message=f"Attention weights should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}", + ) + + attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size)) + + attn_output = self.o_proj(attn_output) + + attn_weights = None + if output_attentions: + logger.warning_once( + "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" + ) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.is_cross_attention: + kv_input_dim = ( + self.hidden_size + if not hasattr(self.config.vision_config, "embed_dim") + else self.config.vision_config.embed_dim + ) + else: + kv_input_dim = self.hidden_size + if getattr(self, "o_proj", None) is not None: + with tf.name_scope(self.o_proj.name): + self.o_proj.build(self.num_heads * self.head_dim) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build(self.hidden_size) + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build(kv_input_dim) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build(kv_input_dim) + if getattr(self, "rotary_emb", None) is not None: + with tf.name_scope(self.rotary_emb.name): + self.rotary_emb.build(None) + + +class TFIdeficsDecoderLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsConfig, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + self.self_attn = TFIdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.dropout, + config=config, + name="self_attn", + ) + self.mlp = TFIdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + name="mlp", + ) + self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") + self.post_attention_layernorm = TFIdeficsRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" + ) + self.dropout = config.dropout + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + past_key_value: tuple[tf.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + training=False, + ) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor] | None]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "input_layernorm", None) is not None: + with tf.name_scope(self.input_layernorm.name): + self.input_layernorm.build(None) + if getattr(self, "post_attention_layernorm", None) is not None: + with tf.name_scope(self.post_attention_layernorm.name): + self.post_attention_layernorm.build(None) + + +class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsConfig, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + self.cross_attn = TFIdeficsAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + is_cross_attention=True, + dropout=config.dropout, + config=config, + qk_layer_norms=config.qk_layer_norms, + name="cross_attn", + ) + self.mlp = TFIdeficsMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + name="mlp", + ) + self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") + self.post_attention_layernorm = TFIdeficsRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" + ) + self.config = config.dropout + + self.act_cross_attn = tf.keras.activations.tanh + self.act_dense = tf.keras.activations.tanh + + self.alpha_initializer = config.alpha_initializer + self.alpha_type = config.alpha_type + self.alphas_initializer_range = config.alphas_initializer_range + + def build(self, input_shape): + if self.built: + return + self.built = True + if self.alpha_initializer == "zeros": + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_dense" + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), initializer="zeros", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight(shape=(1,), initializer="zeros", trainable=True, name="alpha_dense") + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + elif 
self.alpha_initializer == "ones": + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_dense" + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), initializer="ones", trainable=True, name="alpha_cross_attn" + ) + self.alpha_dense = self.add_weight(shape=(1,), initializer="ones", trainable=True, name="alpha_dense") + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + elif self.alpha_initializer in {"normal", "gaussian", "random"}: + if self.alpha_type == "vector": + self.alpha_cross_attn = self.add_weight( + shape=(1, 1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_cross_attn", + ) + self.alpha_dense = self.add_weight( + shape=(1, 1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_dense", + ) + elif self.alpha_type == "float": + self.alpha_cross_attn = self.add_weight( + shape=(1,), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_cross_attn", + ) + self.alpha_dense = self.add_weight( + shape=(1,), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), + trainable=True, + name="alpha_dense", + ) + else: + raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") + + else: + raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!") + + if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): + raise ValueError("Alpha parameters not initialized correctly!") + with tf.name_scope(self.cross_attn.name): + self.cross_attn.build(None) + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + with tf.name_scope(self.input_layernorm.name): + self.input_layernorm.build(None) + with tf.name_scope(self.post_attention_layernorm.name): + self.post_attention_layernorm.build(None) + super().build(input_shape) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + image_hidden_states: tf.Tensor | None = None, + image_attention_mask: tf.Tensor | None = None, + cross_attention_gate: tf.Tensor | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + past_key_value: tuple[tf.Tensor] | None = None, + ) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor] | None]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`).
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + """ + if image_hidden_states is None: + raise ValueError( + "`image_hidden_states` (the visual features to condition on) is required for the Idefics" + " cross attention module." + ) + + if cross_attention_gate is None: + raise ValueError( + "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images." + ) + + if past_key_value is not None: + raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Cross Attention + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=image_hidden_states, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = tf.nn.dropout(hidden_states, rate=self.config) + mask = tf.cast(cross_attention_gate == 0, dtype=hidden_states.dtype) + # Expand dimensions of mask to match hidden_states + mask = tf.expand_dims(mask, -1) + hidden_states = tf.where( + tf.broadcast_to(mask, tf.shape(hidden_states)) == 1, tf.zeros_like(hidden_states), hidden_states + ) + # when there are no images the model is used in pure language mode + # gate = 0 if no_images else 1 + hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = tf.nn.dropout(hidden_states, rate=self.config) + hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a TensorFlow [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) subclass. + Use it as a regular TensorFlow Layer and refer to the TensorFlow documentation for all matters related to general usage + and behavior. + + Parameters: + config ([`IdeficsConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class TFIdeficsPreTrainedModel(TFPreTrainedModel): + config_class = IdeficsConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TFIdeficsDecoderLayer", "TFIdeficsGatedCrossAttentionLayer"] + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it.
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +@keras_serializable +class TFIdeficsMainLayer(tf.keras.layers.Layer): + """ + Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`] + + Args: + config: IdeficsConfig + """ + + config_class = IdeficsConfig + + def __init__(self, config: IdeficsConfig, add_pooling_year: bool = True, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = TFIdeficsDecoupledEmbedding( + num_embeddings=config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_text_layers, + name="embed_tokens", + ) + + self.image_size = config.vision_config.image_size + self.vision_config = config.vision_config + self.vision_model = TFIdeficsVisionTransformer(config.vision_config, name="vision_model") + + # Perceiver Resampler + if config.use_resampler: + perceiver_config = config.perceiver_config + self.perceiver_resampler = TFIdeficsPerceiverResampler( + config, + config.vision_config.embed_dim, + perceiver_config.resampler_depth, + perceiver_config.resampler_n_heads, + perceiver_config.resampler_head_dim, + perceiver_config.resampler_n_latents, + name="perceiver_resampler", + ) + + self.decoder_layers = [ + TFIdeficsDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) + ] + + self.cross_layer_interval = config.cross_layer_interval + num_cross_layers = config.num_hidden_layers // self.cross_layer_interval + self.gated_cross_attn_layers = [ + TFIdeficsGatedCrossAttentionLayer(config, name=f"gated_cross_attn_layers.{i}") + for i in range(num_cross_layers) + ] + self.gradient_checkpointing = False + + self.norm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm") + + self.gradient_checkpointing = False + self.freeze_relevant_params(config) + + def freeze_relevant_params(self, config=None): + if config is None: + config = self.config + + if config.freeze_text_layers: + self.freeze_text_layers(config.freeze_text_module_exceptions) + + if config.freeze_vision_layers: + freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) + + def freeze_text_layers(self, module_exceptions=[]): + for module in [self.decoder_layers, self.norm]: + freeze_model(module, module_exceptions=module_exceptions) + + def freeze_vision_layers(self, module_exceptions=[]): + freeze_model(self.vision_model, module_exceptions=module_exceptions) + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + # if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @unpack_inputs + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def 
call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + past_key_values: list[tf.Tensor] | None = None, + inputs_embeds: tf.Tensor | None = None, + pixel_values: tf.Tensor | None = None, + image_encoder_embeddings: tf.Tensor | None = None, + perceiver_embeddings: tf.Tensor | None = None, + image_attention_mask: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool | None = False, + return_dict: bool | None = None, + training: bool | None = None, + ) -> TFIdeficsBaseModelOutputWithPast | tuple[tf.Tensor]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = shape_list(input_ids) + elif inputs_embeds is not None: + batch_size, seq_length, _ = shape_list(inputs_embeds) + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = shape_list(past_key_values[0][0])[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int32), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + elif position_ids is None: + position_ids = tf.range(past_key_values_length, seq_length + past_key_values_length, dtype=tf.int32) + position_ids = tf.expand_dims(position_ids, 0) + + no_images = False + if ( + sum((int(pixel_values is None), int(image_encoder_embeddings is None), int(perceiver_embeddings is None))) + != 2 + ): + raise ValueError( + "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None." 
+ ) + + elif pixel_values is not None: + no_images = tf.reduce_sum(tf.cast(pixel_values, dtype=tf.int32)) == 0 + pixel_values = tf.cast(pixel_values, dtype=self.dtype) # fp16 compatibility + # Below hack is because when cross-loading pytorch weights, there is an + # initial forward pass with dummy input and code below is here to handle that + if len(pixel_values.shape) == 4: + batch_size = shape_list(pixel_values)[0] + num_images = shape_list(pixel_values)[0] + # pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[1:]]) + elif len(pixel_values.shape) == 5: + batch_size, num_images = shape_list(pixel_values)[:2] + pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[2:]]) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ).last_hidden_state + + elif image_encoder_embeddings is not None: + batch_size, num_images, image_seq_len, image_hidden_size = shape_list(image_encoder_embeddings) + image_hidden_states = tf.cast(image_encoder_embeddings, dtype=self.dtype) + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size * num_images, image_seq_len, image_hidden_size) + ) + + if self.config.use_resampler: + if perceiver_embeddings is None: + perceiver_embeddings = self.perceiver_resampler(image_hidden_states) + image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)[1:3] + else: + batch_size, num_images, image_seq_len, image_hidden_size = shape_list(perceiver_embeddings) + image_hidden_states = perceiver_embeddings + elif perceiver_embeddings is None: + image_seq_len, image_hidden_size = shape_list(image_hidden_states)[1:3] + else: + raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True") + + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size, num_images * image_seq_len, image_hidden_size) + ) + # # Hack to use the model in full language modeling mode + # image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) + + # this is to account for the dummy inputs + if pixel_values is not None and len(pixel_values.shape) == 4 and image_attention_mask is None: + image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) + + text_seq_len = shape_list(image_attention_mask)[1] + image_attention_mask = tf.expand_dims(image_attention_mask, -1) + image_attention_mask = tf.repeat(image_attention_mask, repeats=image_seq_len) + image_attention_mask = tf.reshape(image_attention_mask, (batch_size, text_seq_len, num_images * image_seq_len)) + + if image_hidden_states is not None: + image_batch_size, image_sequence_length, _ = shape_list(image_hidden_states) + image_hidden_shape = (image_batch_size, image_sequence_length) + if image_attention_mask is None: + image_attention_mask = tf.ones(image_hidden_shape, dtype=tf.int32) + image_attention_mask = invert_attention_mask(image_attention_mask) + else: + image_attention_mask = None + + cross_attention_gate = tf.squeeze( + tf.cast(tf.reduce_any(image_attention_mask == 0, axis=-1), dtype=self.dtype), axis=1 + ) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds 
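+ # each decoder layer goes through `vblock` below: every `cross_layer_interval`-th layer is preceded by a gated cross-attention block that lets text tokens attend to the image hidden states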
+ + if self.gradient_checkpointing and training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.decoder_layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + def vblock( + main_block, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + cross_attention_gate, + output_attentions, + use_cache, + layer_idx, + cross_layer_interval, + gated_cross_attn_layers, + ): + # TODO(ls): Add cross attention values to respective lists + if layer_idx % cross_layer_interval == 0: + xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] + outputs = xblock( + hidden_states, + attention_mask=attention_mask, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + cross_attention_gate=cross_attention_gate, + output_attentions=output_attentions, + use_cache=use_cache, + past_key_value=None, # not implemented + ) + hidden_states = outputs[0] + + layer_outputs = main_block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + return layer_outputs + + if self.gradient_checkpointing and training: + past_key_value = None + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # `tf.recompute_grad` wraps the function; the wrapper is then called with positional arguments matching `vblock`'s signature + layer_outputs = tf.recompute_grad(vblock)( + decoder_layer, + hidden_states, + attention_mask, + position_ids, + past_key_value, + image_hidden_states, + image_attention_mask, + cross_attention_gate, + output_attentions, + use_cache, + idx, + self.cross_layer_interval, + self.gated_cross_attn_layers, + ) + else: + layer_outputs = vblock( + decoder_layer, + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + image_hidden_states=image_hidden_states, + image_attention_mask=image_attention_mask, + cross_attention_gate=cross_attention_gate, + output_attentions=output_attentions, + use_cache=use_cache, + layer_idx=idx, + cross_layer_interval=self.cross_layer_interval, + gated_cross_attn_layers=self.gated_cross_attn_layers, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + image_hidden_states = tf.reshape( + image_hidden_states, (batch_size, num_images, image_seq_len, image_hidden_size) + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states] + if v is not None + ) + return TFIdeficsBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + image_hidden_states=image_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_tokens", None) is not None: + with tf.name_scope(self.embed_tokens.name): + self.embed_tokens.build(None) + if getattr(self, "vision_model", None) is not None: + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build(None) + if getattr(self, "perceiver_resampler", None) is not None: + with tf.name_scope(self.perceiver_resampler.name): + self.perceiver_resampler.build(None) + if getattr(self, "decoder_layers", None) is not None: + for layer in self.decoder_layers: + with tf.name_scope(layer.name): + layer.build(None) + if getattr(self, "gated_cross_attn_layers", None) is not None: + for layer in self.gated_cross_attn_layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFIdeficsModel(TFIdeficsPreTrainedModel): + def __init__(self, config: IdeficsConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFIdeficsMainLayer(config, name="model") + + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + past_key_values: list[tf.Tensor] | None = None, + inputs_embeds: tf.Tensor | None = None, + pixel_values: tf.Tensor | None = None, + image_encoder_embeddings: tf.Tensor | None = None, + perceiver_embeddings: tf.Tensor | None = None, + image_attention_mask: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool | None = False, + return_dict: bool | None = None, + training: bool | None = None, + ) ->
TFIdeficsBaseModelOutputWithPast | tuple[tf.Tensor]: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_encoder_embeddings=image_encoder_embeddings, + perceiver_embeddings=perceiver_embeddings, + image_attention_mask=image_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + config_class = IdeficsConfig + + def __init__(self, config, vision_model=None, **kwargs): + super().__init__(config, **kwargs) + self.model = TFIdeficsMainLayer(config, name="model") + self.lm_head = TFIdeficsDecoupledLinear( + config.hidden_size, + config.vocab_size, + config.additional_vocab_size, + bias=False, + partially_freeze=config.freeze_lm_head, + name="lm_head", + ) + + def tie_weights(self): + """ + Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of + IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. + """ + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings.weight = input_embeddings.weight + if input_embeddings.num_additional_embeddings > 0: + assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings + output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight + + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + if hasattr(output_embeddings, "out_additional_features") and hasattr( + input_embeddings, "num_additional_embeddings" + ): + output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings + + @unpack_inputs + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFIdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + past_key_values: list[tf.Tensor] | None = None, + inputs_embeds: tf.Tensor | None = None, + pixel_values: tf.Tensor | None = None, + image_encoder_embeddings: tf.Tensor | None = None, + perceiver_embeddings: tf.Tensor | None = None, + image_attention_mask: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool | None = False, + return_dict: bool | None = None, + training=False, + ) -> TFIdeficsCausalLMOutputWithPast | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >> from transformers import AutoTokenizer, TFIdeficsForVisionText2Text
+
+        >> model = TFIdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")
+        >> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b")
+
+        >> prompt = "Hey, are you conscious? Can you talk to me?"
+        >> inputs = tokenizer(prompt, return_tensors="tf")
+
+        >> # Generate
+        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            image_encoder_embeddings=image_encoder_embeddings,
+            perceiver_embeddings=perceiver_embeddings,
+            image_attention_mask=image_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[..., 1:]
+                shift_logits = logits[..., :-1, :][shift_attention_mask != 0]
+                shift_labels = labels[..., 1:][shift_attention_mask != 0]
+            else:
+                shift_logits = logits[..., :-1, :]
+                shift_labels = labels[..., 1:]
+            # Flatten the tokens
+            loss = self.hf_compute_loss(
+                labels=tf.reshape(shift_labels, [-1]), logits=tf.reshape(shift_logits, [-1, shift_logits.shape[-1]])
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return TFIdeficsCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        image_hidden_states = kwargs.pop("image_hidden_states", None)
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                kwargs["perceiver_embeddings"] = image_hidden_states
+            else:
+                kwargs["image_encoder_embeddings"] = image_hidden_states
+            kwargs["pixel_values"] = None
+        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
+        unwanted_kwargs = ["token_type_ids"]
+        for kwarg in unwanted_kwargs:
+            inputs.pop(kwarg, None)
+        return inputs
+
+    @staticmethod
+    def _expand_inputs_for_generation(
+        *args,
+        **model_kwargs,
+    ):
+        return expand_inputs_for_generation(*args, **model_kwargs)
+
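The loss computed in `call` above uses the standard causal-LM shift: the logits at position `t` are scored against the label at position `t + 1`, and any target whose *shifted* attention-mask entry is 0 is dropped before flattening. A minimal standalone sketch of that shift, with toy shapes and values, and with `tf.keras.losses.SparseCategoricalCrossentropy` standing in for the model's `hf_compute_loss` helper (illustrative only, not part of this diff):

```python
import tensorflow as tf

# Toy inputs: batch=1, seq_len=4, vocab=5; the last position is padding.
vocab_size = 5
logits = tf.random.normal((1, 4, vocab_size))  # (batch, seq, vocab)
labels = tf.constant([[2, 3, 1, 0]])           # (batch, seq)
attention_mask = tf.constant([[1, 1, 1, 0]])   # 0 marks padding

# Shift so that tokens < n predict n, then drop masked targets.
mask = tf.not_equal(attention_mask[:, 1:], 0)            # (batch, seq - 1)
shift_logits = tf.boolean_mask(logits[:, :-1, :], mask)  # (n_kept, vocab)
shift_labels = tf.boolean_mask(labels[:, 1:], mask)      # (n_kept,)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(shift_labels, shift_logits)  # scalar mean over kept positions
```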
+ @staticmethod + def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder): + return update_model_kwargs_for_generation(outputs, model_kwargs) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),) + return reordered_past + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +__all__ = ["TFIdeficsForVisionText2Text", "TFIdeficsModel", "TFIdeficsPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c7372c6e724cf8d0600f58bb7263405e8d608a79 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver.py @@ -0,0 +1,189 @@ +# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. +# +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" + +Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially +time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note +that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to +prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that +to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. 
+ +References: + - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model + - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch + +""" + +from typing import Optional + +import torch +import torch.nn as nn + +from .configuration_idefics import IdeficsConfig + + +class IdeficsPerceiverResampler(nn.Module): + def __init__( + self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int + ) -> None: + """ + Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or + MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then + returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed + to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. + Could be e.g., VIT embed_dim, ResNet pool dim, and so on. + + Args: + config (`IdeficsConfig`): config object + embed_dim (`int`): The size of each embedding vector + depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). + head_dim (`int`): Dimensionality of each head projection in the Transformer block. + n_latents (`int`): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + + """ + super().__init__() + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents + self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver + + # Create Latents for Perceiver + self.latents = nn.Parameter(torch.randn(self.n_latents, self.embed_dim), requires_grad=True) + + self.intermediate_dim = ( + self.embed_dim * 4 + if not hasattr(config.vision_config, "embed_dim") + else config.vision_config.embed_dim * 4 + ) + # Create Transformer Blocks + self.blocks = nn.ModuleList( + [ + nn.ModuleList( + [ + IdeficsPerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms), + IdeficsMLP(self.intermediate_dim, config), + ] + ) + for _ in range(depth) + ] + ) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, context: torch.Tensor) -> torch.Tensor: + """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" + # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) + latents = self.latents.repeat(context.shape[0], 1, 1) + + # Feed through Perceiver Attention blocks... 
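+        # Each entry in `self.blocks` is an (attention, feed-forward) pair: the
+        # latents cross-attend to the concatenated (context, latents) sequence,
+        # and both sub-layers are applied with residual connections.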
+ for attn, ff in self.blocks: + latents = attn(context, latents) + latents + latents = ff(latents) + latents + + return self.layer_norm(latents) + + +class IdeficsPerceiverAttention(nn.Module): + def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None: + """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" + super().__init__() + self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim + self.qk_layer_norms = qk_layer_norms + # Normalization & Scaling + self.context_layer_norm = nn.LayerNorm(self.embed_dim) + self.latents_layer_norm = nn.LayerNorm(self.embed_dim) + if self.qk_layer_norms: + self.q_layer_norm = nn.LayerNorm(self.head_dim) + self.k_layer_norm = nn.LayerNorm(self.head_dim) + + self.qk_scale = self.head_dim**-0.5 + + # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). + self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False) + + self.output_proj = nn.Linear(self.n_heads * self.head_dim, embed_dim, bias=False) + + def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! + + Args: + context (`torch.Tensor`): + Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. + latents (`torch.Tensor`): + Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. + + Returns: + `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross + from context. + """ + context = self.context_layer_norm(context) + latents = self.latents_layer_norm(latents) + batch_size, seq_length, embed_dim = context.shape[:3] + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! + # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + q = self.q_proj(latents) + k = self.k_proj(torch.cat([context, latents], dim=-2)) + v = self.v_proj(torch.cat([context, latents], dim=-2)) + + # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) + # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] + # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) + q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)] + + if self.qk_layer_norms: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + + scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) + stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach()) + attn = stabilized_scores.softmax(dim=-1) + + # Attend & project back to output... + resampled = torch.einsum("... i j, ... j d -> ... 
i d", attn, v) + # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads) + return self.output_proj(resampled.transpose(1, 2).flatten(-2)) + + +class IdeficsMLP(nn.Module): + def __init__(self, intermediate_size, config: IdeficsConfig): + """Simple MLP block with intermediate_size and embedding size""" + super().__init__() + self.embed_dim = config.vision_config.embed_dim + self.ln = nn.LayerNorm(self.embed_dim) + self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False) + self.act = nn.ReLU() + self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False) + + def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.ln(hidden_states) + hidden_states = self.fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver_tf.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..a4de96b68e780828ff8106de800ed8b7d3d1469f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/perceiver_tf.py @@ -0,0 +1,195 @@ +# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. +# +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" + +Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially +time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note +that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to +prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that +to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. 
+ +References: + - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model + - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch + +""" + +from typing import Optional + +import tensorflow as tf + +from ...modeling_tf_utils import shape_list +from .configuration_idefics import IdeficsConfig + + +class TFIdeficsPerceiverResampler(tf.keras.layers.Layer): + def __init__( + self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs + ) -> None: + """ + Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or + MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then + returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed + to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. + Could be e.g., VIT embed_dim, ResNet pool dim, and so on. + + Args: + config (`IdeficsConfig`): config object + embed_dim (`int`): The size of each embedding vector + depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). + n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). + head_dim (`int`): Dimensionality of each head projection in the Transformer block. + n_latents (`int`): + Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). + + """ + super().__init__(**kwargs) + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents + self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver + + self.intermediate_dim = ( + self.embed_dim * 4 + if not hasattr(config.vision_config, "embed_dim") + else config.vision_config.embed_dim * 4 + ) + # Create Transformer Blocks + self.blocks = [] + for i in range(depth): + self.blocks.append( + [ + TFIdeficsPerceiverAttention( + self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0" + ), + TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"), + ] + ) + + self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + def build(self, input_shape): + # Create Latents for Perceiver + self.latents = self.add_weight( + shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents" + ) + super().build(input_shape) + + def call(self, context: tf.Tensor) -> tf.Tensor: + """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" + # tf.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) + latents = tf.expand_dims(self.latents, axis=0) + latents = tf.tile(latents, [tf.shape(context)[0], 1, 1]) + # Feed through Perceiver Attention blocks... 
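+        # Mirrors the PyTorch resampler: each (attention, feed-forward) pair is
+        # applied to the latents with residual connections.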
+ for attn, ff in self.blocks: + latents = attn(context, latents) + latents + latents = ff(latents) + latents + return self.layer_norm(latents) + + +class TFIdeficsPerceiverAttention(tf.keras.layers.Layer): + def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None: + """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" + super().__init__(**kwargs) + self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim + self.qk_layer_norms = qk_layer_norms + # Normalization & Scaling + self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm") + self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm") + if self.qk_layer_norms: + self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm") + self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm") + + self.qk_scale = self.head_dim**-0.5 + + # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). + self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj") + self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj") + + self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj") + + def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor: + """ + Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! + + Args: + context (`tf.Tensor`): + Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. + latents (`tf.Tensor`): + Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. + + Returns: + `tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross + from context. + """ + context = self.context_layer_norm(context) + latents = self.latents_layer_norm(latents) + batch_size, seq_length, embed_dim = shape_list(context) + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! + # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + q = self.q_proj(latents) + k = self.k_proj(tf.concat([context, latents], axis=-2)) + v = self.v_proj(tf.concat([context, latents], axis=-2)) + + # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) + # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] + q, k, v = [ + tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3]) + for x in (q, k, v) + ] + + if self.qk_layer_norms: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + + scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) + stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True) + attn = tf.nn.softmax(stabilized_scores, axis=-1) + + # Attend & project back to output... + resampled = tf.einsum("... i j, ... j d -> ... 
i d", attn, v) + return self.output_proj( + tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim)) + ) + + +class TFIdeficsMLP(tf.keras.layers.Layer): + def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs): + """Simple MLP block with intermediate_size and embedding size""" + super().__init__(**kwargs) + self.embed_dim = config.vision_config.embed_dim + self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln") + self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc") + self.act = tf.keras.layers.ReLU(name="act") + self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj") + + def call(self, hidden_states: Optional[tuple[tf.Tensor]]) -> tf.Tensor: + hidden_states = self.ln(hidden_states) + hidden_states = self.fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/processing_idefics.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/processing_idefics.py new file mode 100644 index 0000000000000000000000000000000000000000..59c8078ad84adf69463c4850f03999f9eac2db57 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/processing_idefics.py @@ -0,0 +1,524 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for IDEFICS. 
+""" + +from typing import Callable, Optional, Union +from urllib.parse import urlparse + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + TextKwargs, + Unpack, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_tf_available, is_torch_available +from ...utils.deprecation import deprecate_kwarg + + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + +IMAGE_TOKEN = "" + + +class IdeficsImagesKwargs(ImagesKwargs, total=False): + transform: Optional[Callable] + image_size: Optional[dict[str, int]] + image_mean: Optional[Union[float, list[float]]] + image_std: Optional[Union[float, list[float]]] + + +class IdeficsTextKwargs(TextKwargs, total=False): + add_eos_token: Optional[bool] + add_end_of_utterance_token: Optional[bool] + + +class IdeficsProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: IdeficsTextKwargs + images_kwargs: IdeficsImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": False, + "padding": "longest", + "add_eos_token": False, + }, + "images_kwargs": {}, + "common_kwargs": {"return_tensors": "pt"}, + } + + +# copied from m4.training.packing +def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1): + # Set elements >= num_classes to -1 + if num_classes != -1: + if return_tensors == "pt": + incremental_mask[incremental_mask >= num_classes] = -1 + elif return_tensors == "tf": + incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask) + + # Create mask for negative values + if return_tensors == "pt": + negatives = incremental_mask == -1 + incremental_mask[negatives] = 0 + attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) + attn_mask[negatives, :] = 0 + elif return_tensors == "tf": + negatives = tf.equal(incremental_mask, -1) + incremental_mask = tf.where(negatives, 0, incremental_mask) + attn_mask = tf.one_hot(incremental_mask, depth=num_classes) + # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1] + negatives_expanded = tf.expand_dims(negatives, -1) + attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask) + + return attn_mask + + +# copied from m4.training.packing +def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors): + if return_tensors == "pt": + return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer) + elif return_tensors == "tf": + return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer) + + +def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer): + image_attention_mask = torch.full_like(input_ids, fill_value=-1) + next_image_attention_mask = torch.full_like(input_ids, fill_value=-1) + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + eod_token_id = tokenizer.eos_token_id + for batch_idx in range(input_ids.size(0)): + count = -1 + seen_eod = False + for idx, token_id in enumerate(input_ids[batch_idx]): + if token_id == image_token_id: + count += 1 + image_attention_mask[batch_idx][idx] = count + seen_eod = False + else: + image_attention_mask[batch_idx][idx] = count + + if seen_eod: + image_attention_mask[batch_idx][idx] = -1 + + if token_id == eod_token_id: + seen_eod = True + + for batch_idx in range(input_ids.size(0)): + count = -1 + seen_eod = False + for idx in 
range(input_ids[batch_idx].size(0) - 1, -1, -1): + token_id = input_ids[batch_idx][idx] + if token_id == image_token_id: + count += 1 + next_image_attention_mask[batch_idx][idx] = count + seen_eod = False + else: + next_image_attention_mask[batch_idx][idx] = count + + if token_id == eod_token_id: + seen_eod = True + + if seen_eod: + next_image_attention_mask[batch_idx][idx] = -1 + + non_negative_indices = next_image_attention_mask[batch_idx] != -1 + next_image_attention_mask[batch_idx][non_negative_indices] -= count + next_image_attention_mask[batch_idx][non_negative_indices] *= -1 + + return image_attention_mask, next_image_attention_mask + + +def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer): + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + eod_token_id = tokenizer.eos_token_id + batch_size = tf.shape(input_ids)[0] + image_attention_mask = tf.fill(tf.shape(input_ids), -1) + next_image_attention_mask = tf.fill(tf.shape(input_ids), -1) + + for batch_idx in range(batch_size): + count = -1 + seen_eod = False + seq_length = tf.shape(input_ids)[1] + + for idx in range(seq_length - 1, -1, -1): + token_id = input_ids[batch_idx, idx].numpy() + if token_id == image_token_id: + count += 1 + indices = [[batch_idx, idx]] + updates = [count] + image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates) + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + elif token_id == eod_token_id and not seen_eod: + seen_eod = True + count = 0 + indices = [[batch_idx, idx]] + updates = [count] + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + if seen_eod and token_id != eod_token_id: + indices = [[batch_idx, idx]] + updates = [-1] + next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) + return image_attention_mask, next_image_attention_mask + + +def is_url(string): + """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately + invalidated the url""" + if " " in string: + return False + result = urlparse(string) + return all([result.scheme, result.netloc]) + + +class IdeficsProcessor(ProcessorMixin): + r""" + Constructs a IDEFICS processor which wraps a LLama tokenizer and IDEFICS image processor into a single processor. + + [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See + the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information. + + Args: + image_processor (`IdeficsImageProcessor`): + An instance of [`IdeficsImageProcessor`]. The image processor is a required input. + tokenizer (`LlamaTokenizerFast`): + An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. 
+
+        image_size (`int`, *optional*, defaults to 224):
+            Image size (assuming a square image)
+        add_end_of_utterance_token (`str`, *optional*):
+            The string representation of the token marking the end of an utterance
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "IdeficsImageProcessor"
+    tokenizer_class = "LlamaTokenizerFast"
+
+    def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
+        super().__init__(image_processor, tokenizer)
+        self.current_processor = self.image_processor
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if hasattr(tokenizer, "image_token")
+            else tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        )
+
+        self.default_image_dims = (
+            self.image_processor.image_num_channels,
+            self.image_processor.image_size,
+            self.image_processor.image_size,
+        )
+
+        self.tokenizer_was_trained_with_end_of_utterance_token = (
+            "<end_of_utterance>" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
+        )
+
+    @deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
+    def __call__(
+        self,
+        images: Union[ImageInput, list[ImageInput], str, list[str], list[list[str]]] = None,
+        text: Union[
+            TextInput,
+            PreTokenizedInput,
+            list[TextInput],
+            list[PreTokenizedInput],
+            list[list[TextInput]],
+            list[list[PreTokenizedInput]],
+        ] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[IdeficsProcessorKwargs],
+    ) -> BatchFeature:
+        """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
+        the model was trained on and prepares the image pixel values for the model to process.
+
+        Args:
+            images (`Union[ImageInput, list[ImageInput], str, list[str], list[list[str]]]`):
+                either a single image or a batched list of images - can be passed in when text contains only text prompts,
+                in order to use the image-text-to-text behavior.
+            text (`Union[list[TextInput], list[list[TextInput]]]`):
+                either a single prompt or a batched list of prompts - see the detailed description immediately after
+                the end of the arguments doc section.
+            return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
+                The type of tensors to return. Can be one of:
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+
+        Returns:
+            a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
+            directly passed to `model.generate`
+
+        Detailed explanation:
+
+        Each entry in `text` is either a text to be passed as is or an image that will be processed.
+
+        An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
+
+        When the processor encounters an image it'll inject a
+        `<fake_token_around_image><image><fake_token_around_image>` entry into the prompt.
+
+        Example:
+
+        ```python
+        checkpoint = "HuggingFaceM4/idefics-9b"
+        processor = AutoProcessor.from_pretrained(checkpoint)
+        url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
+        img = processor.image_processor.fetch_images([url])[0]
+
+        prompts = [
+            "User:",
+            img,
+            "Describe this image.\nAssistant: An image of two kittens in grass.\n",
+            "User:",
+            "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
+            "Describe this image.\nAssistant:",
+        ]
+
+        inputs = processor(text=prompts, return_tensors="pt")
+        generated_ids = model.generate(**inputs, max_length=100)
+        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        ```
+
+        In this example the `prompts` will be converted into:
+
+        ```
+        <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+        Assistant: An image of two kittens in grass.
+        User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+        Assistant:
+        ```
+
+        and the two images will be massaged using the [`IdeficsImageProcessor.__call__`] method and placed inside the
+        `pixel_values` dict entry of the return value.
+
+        This example also shows that images can be passed as objects or as text urls: the first image is passed as
+        an object and the second one as a url.
+
+        To do training, do:
+
+        ```python
+        image_transform = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=self.image_mean, std=self.image_std),
+            ]
+        )
+        inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
+        ```
+
+        In order to help debug prompt generation, enable `debug=True`, which will show you what's happening.
+
+        """
+        if images is None and text is None:
+            raise ValueError("You need to specify either `text` or `images` and `text`.")
+
+        if images is None:
+            # assuming the user wants to use the old behavior with prompts as the only argument
+            prompts = text
+        elif text is not None:
+            # Assuming image-text-to-text behavior:
+            # Check if batched images are provided
+            if not isinstance(images, (list, tuple)):
+                images = [images]
+            if isinstance(text, str):
+                text = [text]
+            # Check if batched images and text are in the correct format
+            if isinstance(text, (list, tuple)) and len(text) != len(images):
+                raise ValueError(
+                    "When providing both images and text arguments, the number of text prompts should be the same as the number of images. "
+                    "If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
+                )
+            # Check that only text is present in the prompts
+            if not all(isinstance(i, str) for i in text):
+                raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
+            if isinstance(images[0], (list, tuple)):
+                # if nested images, un-nest each sublist and create `prompts`
+                prompts = [[sample, *image_list] for image_list, sample in zip(images, text)]
+            else:
+                prompts = list(zip(images, text))
+
+        output_kwargs = self._merge_kwargs(
+            IdeficsProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
+        add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
+
+        # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
+        if add_end_of_utterance_token is None:
+            add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
+        # turn non-batched prompts into batched
+        if not any(isinstance(i, (list, tuple)) for i in prompts):
+            prompts = [prompts]
+
+        fake_token = "<fake_token_around_image>"
+        image_token = "<image>"
+        end_of_utterance_token = "<end_of_utterance>"
+
+        def image_tokens(last_was_image):
+            if last_was_image:
+                return image_token + fake_token
+            else:
+                return fake_token + image_token + fake_token
+
+        all_prompts = []
+        all_images = []
+        for sample in prompts:
+            # the model was trained on samples starting with <s>
+            full_text = f"{self.tokenizer.bos_token}"
+
+            # an image can either be an image object in the item or the url, everything else is a verbatim prompt text
+            image_objects = []
+            last_was_image = False
+            last_was_text = False
+            for i, item in enumerate(sample):
+                if i > 0:
+                    last_was_text = bool(not last_was_image)
+
+                if isinstance(item, str):
+                    item = item.strip(" ")
+                    if is_url(item):
+                        image = self.image_processor.fetch_images(item)
+                        full_text += image_tokens(last_was_image)
+                        image_objects.append(image)
+                        last_was_image = True
+                    else:
+                        # add end_of_utterance_token between consecutive text prompts (but not after the last one)
+ if add_end_of_utterance_token and last_was_text: + full_text += end_of_utterance_token + full_text += item + last_was_image = False + else: + # must be an image obj + full_text += image_tokens(last_was_image) + image_objects.append(item) + last_was_image = True + + if add_eos_token: + full_text += self.tokenizer.eos_token + + image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"]) + + all_prompts.append(full_text) + all_images.append(image_objects) + + # For BC + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt") + text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"]) + all_texts = text_encoding["input_ids"] + all_attention_masks = text_encoding["attention_mask"] + + # max_num_images has to be at least 1 even when there are no images + max_num_images = max(len(x) for x in all_images) + max_num_images = max(1, max_num_images) + + at_least_one_image = sum(len(x) for x in all_images) > 0 + output_input_ids = [] + output_images = [] + output_attention_masks = [] + + for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images): + padded_input_ids = text_single + image_count = padded_input_ids.count(self.image_token_id) + local_max_num_images = min(image_count, max_num_images) + + current_images = extracted_images[:local_max_num_images] + + if len(current_images) > 0: + if return_tensors == "pt": + padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) + padded_image_tensor[: current_images.size(0)] = current_images + elif return_tensors == "tf": + # Assuming current_images is a TensorFlow tensor + # Get the shape of current_images, excluding the first dimension + image_shape = tf.shape(current_images)[1:] + # Create a shape for the padded_image_tensor + padded_shape = tf.concat([[max_num_images], image_shape], axis=0) + # Create the padded_image_tensor of zeros + padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype) + # Get the number of images (assuming current_images has shape [num_images, height, width, channels]) + num_images = tf.shape(current_images)[0] + # Update the padded_image_tensor with the values from current_images + indices = tf.reshape(tf.range(num_images), (-1, 1)) + updates = current_images + padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates) + else: + if return_tensors == "pt": + padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) + elif return_tensors == "tf": + padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims)) + + output_images.append(padded_image_tensor) + if return_tensors == "pt": + output_input_ids.append(torch.tensor(padded_input_ids)) + output_attention_masks.append(torch.tensor(attention_mask)) + elif return_tensors == "tf": + output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32)) + output_attention_masks.append(attention_mask) + + if return_tensors == "pt": + output_input_ids = torch.stack(output_input_ids) + output_images = torch.stack(output_images) + output_attention_masks = torch.stack(output_attention_masks) + elif return_tensors == "tf": + output_input_ids = tf.stack(output_input_ids) + output_images = tf.stack(output_images) + output_attention_masks = tf.stack(output_attention_masks) + + if at_least_one_image: + image_attention_mask, _ = image_attention_mask_for_packed_input_ids( + output_input_ids, self.tokenizer, return_tensors + ) + image_attention_mask = 
incremental_to_binary_attention_mask( + image_attention_mask, return_tensors, num_classes=max_num_images + ) + else: + # in full language mode we set the image mask to all-0s + if return_tensors == "pt": + image_attention_mask = torch.zeros( + output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool + ) + elif return_tensors == "tf": + image_attention_mask = tf.zeros( + (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool + ) + return BatchFeature( + data={ + "input_ids": output_input_ids, + "attention_mask": output_attention_masks, + "pixel_values": output_images, + "image_attention_mask": image_attention_mask, + } + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(tokenizer_input_names + image_processor_input_names + ["image_attention_mask"]) + + +__all__ = ["IdeficsProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/vision.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/vision.py new file mode 100644 index 0000000000000000000000000000000000000000..f6143064835eb2cd7c05e97bf4c4ab07e29d33bb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/vision.py @@ -0,0 +1,476 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import ( + ModelOutput, + can_return_tuple, + logging, +) +from .configuration_idefics import IdeficsVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +class IdeficsVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings +class IdeficsVisionEmbeddings(nn.Module): + def __init__(self, config: IdeficsVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + # Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82 + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + pos_embed = self.position_embedding(self.position_ids) + num_positions = pos_embed.shape[1] - 1 + if num_patches == num_positions and height == width: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + + embed_dim = embeddings.shape[-1] + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 + sqrt_num_positions = math.sqrt(num_positions) + patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16 + if fp32_upcasting: + logger.warning_once( + "Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate " + "is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead." + ) + patch_pos_embed = patch_pos_embed.to(torch.float) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions), + mode="bicubic", + align_corners=False, + ) + if fp32_upcasting: + patch_pos_embed = patch_pos_embed.to(torch.bfloat16) + if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]: + raise ValueError( + f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " + f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if not interpolate_pos_encoding: + if height != self.image_size or width != self.image_size: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size}*{self.image_size}). 
You should try to set `interpolate_pos_encoding=True`" + ) + + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class IdeficsVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: IdeficsVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask + else: + self.is_causal = causal_attention_mask is not None + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision +class IdeficsVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision +class IdeficsVisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IdeficsVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = IdeficsVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = IdeficsVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> 
tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->IdeficsVision +class IdeficsVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`IdeficsVisionEncoderLayer`]. + + Args: + config: IdeficsVisionConfig + """ + + def __init__(self, config: IdeficsVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
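For orientation, both encoder layer classes in this diff follow the pre-LayerNorm pattern: normalization is applied before each sub-layer and the residual is added afterwards. A minimal sketch with toy sizes; the dimensions and the `nn.MultiheadAttention` stand-in are assumptions for illustration, not the attention module defined above.

```python
import torch
import torch.nn as nn

# Pre-LayerNorm residual block, mirroring the structure of IdeficsVisionEncoderLayer:
# norm -> sub-layer -> add residual, twice per layer.
dim = 16
ln1, ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)  # stand-in for self_attn
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

x = torch.randn(2, 10, dim)                       # (batch, seq_len, embed_dim)
h = ln1(x)
x = x + attn(h, h, h, need_weights=False)[0]      # attention sub-layer + residual
x = x + mlp(ln2(x))                               # MLP sub-layer + residual
print(x.shape)                                    # torch.Size([2, 10, 16])
```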
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer +class IdeficsVisionTransformer(nn.Module): + def __init__(self, config: IdeficsVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = IdeficsVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = IdeficsVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics/vision_tf.py b/venv/lib/python3.13/site-packages/transformers/models/idefics/vision_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8cf94022189e8c59fc84812817641a3d255a62 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/idefics/vision_tf.py @@ -0,0 +1,572 @@ +# 
coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling +from ...modeling_tf_utils import TFPreTrainedModel, shape_list +from ...tf_utils import flatten +from ...utils import ModelOutput, logging +from .configuration_idefics import IdeficsVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +class TFIdeficsVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[tf.Tensor] = None + last_hidden_state: Optional[tf.Tensor] = None + hidden_states: Optional[tuple[tf.Tensor]] = None + attentions: Optional[tuple[tf.Tensor]] = None + + +class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = tf.keras.layers.Conv2D( + filters=self.embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + use_bias=False, + padding="valid", + data_format="channels_last", + name="patch_embedding", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = tf.keras.layers.Embedding( + self.num_positions, self.embed_dim, name="position_embedding" + ) + # self.position_ids = tf.range(self.num_positions)[tf.newaxis, :] + + def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: + num_patches = shape_list(embeddings)[1] - 1 + pos_embed = self.position_embedding(self.position_ids) + num_positions = shape_list(pos_embed)[1] - 1 + if num_patches == num_positions and height == width: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + + embed_dim = shape_list(embeddings)[-1] + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size + num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 + sqrt_num_positions = math.sqrt(float(num_positions)) + patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)) + + scale_height = num_h_patches / sqrt_num_positions + scale_width = num_w_patches / sqrt_num_positions + original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32) + original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32) + # Apply scaling + new_height = tf.cast(original_height * scale_height, tf.int32) + new_width = tf.cast(original_width * scale_width, tf.int32) + + patch_pos_embed = tf.image.resize( + patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC + ) + + if ( + int(num_h_patches) != shape_list(patch_pos_embed)[-3] + or int(num_w_patches) != shape_list(patch_pos_embed)[-2] + ): + raise ValueError( + f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " + f"shape of position embedding ({shape_list(patch_pos_embed)[-2], shape_list(patch_pos_embed)[-1]})" + ) + patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim)) + return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1) + + def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor: + # Input `pixel_values` is NCHW format which doesn't run on CPU so first thing we do is + # transpose it to change it to NHWC. 
We don't care to transpose it back because + # the Conv2D layer is only hit once for each query + + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + batch_size, height, width, num_channels = shape_list(pixel_values) + if not interpolate_pos_encoding: + if height != self.image_size or width != self.image_size: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`" + ) + + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + # Change the 2D spatial dimensions to a single temporal dimension. + # shape = (batch_size, num_patches, out_channels=embed_dim) + patch_embeds = flatten(patch_embeds, 1, 2) + + class_embeds = tf.broadcast_to( + self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim] + ) + embeddings = tf.concat([class_embeds, patch_embeds], axis=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :] + self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding") + if getattr(self, "patch_embedding", None) is not None: + with tf.name_scope(self.patch_embedding.name): + self.patch_embedding.build([None, None, None, self.config.num_channels]) + if getattr(self, "position_embedding", None) is not None: + with tf.name_scope(self.position_embedding.name): + self.position_embedding.build(None) + + +class TFIdeficsVisionAttention(tf.keras.layers.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj") + self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj") + self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + causal_attention_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[tf.Tensor, Optional[tf.Tensor], Optional[tuple[tf.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + tf.shape(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}", + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(causal_attention_mask)}" + ) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + if attention_mask is not None: + if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}" + ) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len)) + else: + attn_weights_reshaped = None + + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + + attn_output = tf.linalg.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + tf.shape(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}", + ) + + attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)) + attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3]) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build((self.embed_dim, self.embed_dim)) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build((self.embed_dim, self.embed_dim)) + + +class TFIdeficsVisionMLP(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.activation_fn = get_tf_activation(config.hidden_act) + self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1") + self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build(self.config.hidden_size) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build(self.config.intermediate_size) + + +class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.hidden_size + self.self_attn = TFIdeficsVisionAttention(config, name="self_attn") + self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.mlp = TFIdeficsVisionMLP(config, name="mlp") + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + causal_attention_mask: tf.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[tf.Tensor]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
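To make the additive-mask convention mentioned here concrete (padding marked by very large negative values added to the logits before the softmax), a toy example; it is written with PyTorch for brevity, but the arithmetic is identical in the TF code above.

```python
import torch
import torch.nn.functional as F

# 0.0 keeps a position, a large negative value effectively removes it after the softmax.
scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
mask = torch.tensor([[0.0, 0.0, -1e9, -1e9]])  # last two positions are padding

probs = F.softmax(scores + mask, dim=-1)
print(probs)  # the two padded positions receive ~0 probability
```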
+ `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, self.embed_dim]) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, self.embed_dim]) + + +class TFIdeficsVisionEncoder(tf.keras.layers.Layer): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`TFIdeficsVisionEncoderLayer`]. + + Args: + config: IdeficsVisionConfig + """ + + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layers = [ + TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) + ] + self.gradient_checkpointing = False + + def call( + self, + inputs_embeds, + attention_mask: Optional[tf.Tensor] = None, + causal_attention_mask: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = None, + ) -> Union[tuple, TFBaseModelOutput]: + r""" + Args: + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = tf.recompute_grad( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFIdeficsVisionTransformer(TFPreTrainedModel): + def __init__(self, config: IdeficsVisionConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.embed_dim = config.hidden_size + + self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings") + self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm") + self.encoder = TFIdeficsVisionEncoder(config, name="encoder") + self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") + + # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[tuple, TFBaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + last_hidden_state = 
encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "pre_layrnorm", None) is not None: + with tf.name_scope(self.pre_layrnorm.name): + self.pre_layrnorm.build([None, None, self.embed_dim]) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "post_layernorm", None) is not None: + with tf.name_scope(self.post_layernorm.name): + self.post_layernorm.build([None, self.embed_dim]) diff --git a/venv/lib/python3.13/site-packages/transformers/models/ijepa/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/ijepa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8923af1de116219405577646ae2dcedee5602ccc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ijepa/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ijepa import * + from .modeling_ijepa import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/ijepa/configuration_ijepa.py b/venv/lib/python3.13/site-packages/transformers/models/ijepa/configuration_ijepa.py new file mode 100644 index 0000000000000000000000000000000000000000..5f528adad0d55711f40ea21f58e1e0196822f449 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ijepa/configuration_ijepa.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""I-JEPA model configuration""" + +from ...configuration_utils import PretrainedConfig + + +class IJepaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the I-JEPA + [facebook/ijepa_vith14_1k](https://huggingface.co/facebook/ijepa_vith14_1k) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + pooler_output_size (`int`, *optional*): + Dimensionality of the pooler layer. If None, defaults to `hidden_size`. + pooler_act (`str`, *optional*, defaults to `"tanh"`): + The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and + Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are + supported for Tensorflow. 
+ + Example: + + ```python + >>> from transformers import IJepaConfig, IJepaModel + + >>> # Initializing a IJEPA ijepa-base-patch16-224 style configuration + >>> configuration = IJepaConfig() + + >>> # Initializing a model (with random weights) from the ijepa-base-patch16-224 style configuration + >>> model = IJepaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ijepa" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + pooler_output_size=None, + pooler_act="tanh", + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size + self.pooler_act = pooler_act + + +__all__ = ["IJepaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ijepa/modeling_ijepa.py b/venv/lib/python3.13/site-packages/transformers/models/ijepa/modeling_ijepa.py new file mode 100644 index 0000000000000000000000000000000000000000..b26a78e49ab0e251cd5bdee256a01b8d9da61451 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ijepa/modeling_ijepa.py @@ -0,0 +1,540 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ijepa.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import collections.abc +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import TransformersKwargs, auto_docstring, torch_int +from ...utils.generic import can_return_tuple, check_model_inputs +from .configuration_ijepa import IJepaConfig + + +class IJepaPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
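A toy sketch of that mapping, with assumed sizes (224-pixel RGB input, 16-pixel patches, hidden size 768); the projection is just a strided convolution whose output is flattened into a token sequence.

```python
import torch
import torch.nn as nn

# Patch embedding as a strided convolution, as described above.
patch_size, hidden_size = 16, 768
projection = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)

pixel_values = torch.randn(1, 3, 224, 224)                      # (batch, channels, H, W)
embeddings = projection(pixel_values).flatten(2).transpose(1, 2)
print(embeddings.shape)  # torch.Size([1, 196, 768]); 196 = (224 // 16) ** 2
```

The sequence length is simply `(height // patch_size) * (width // patch_size)`; I-JEPA adds no CLS token on top of these patch embeddings.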
+ """ + + def __init__(self, config: IJepaConfig): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class IJepaEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: + super().__init__() + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = IJepaPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + patch_pos_embed = self.position_embeddings + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return patch_pos_embed + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class IJepaSelfAttention(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def forward( + self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size + + key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) + query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + return context_layer, attention_probs + + +class IJepaSelfOutput(nn.Module): + """ + The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
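The eager attention helper defined above is plain scaled dot-product attention, softmax(QKᵀ · scaling)V. A shape-level sketch with toy tensors (all sizes are assumptions):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
scaling = head_dim ** -0.5

attn_weights = torch.matmul(q, k.transpose(-1, -2)) * scaling            # (batch, heads, seq, seq)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v).transpose(1, 2)              # (batch, seq, heads, head_dim)
print(attn_output.shape)  # torch.Size([2, 16, 4, 8])
```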
+ """ + + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class IJepaAttention(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.attention = IJepaSelfAttention(config) + self.output = IJepaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: set[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + self_attn_output, _ = self.attention(hidden_states, head_mask) + output = self.output(self_attn_output, hidden_states) + return output + + +class IJepaIntermediate(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class IJepaOutput(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class IJepaLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: IJepaConfig): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = IJepaAttention(config) + self.intermediate = IJepaIntermediate(config) + self.output = IJepaOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states_norm = self.layernorm_before(hidden_states) + attention_output = self.attention(hidden_states_norm, head_mask) + + # first residual connection + 
hidden_states = attention_output + hidden_states + + # in IJepa, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + return layer_output + + +@auto_docstring +class IJepaPreTrainedModel(PreTrainedModel): + config: IJepaConfig + base_model_prefix = "ijepa" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": IJepaLayer, + "attentions": IJepaSelfAttention, + } + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, IJepaEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + if module.mask_token is not None: + module.mask_token.data.zero_() + + +class IJepaEncoder(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.config = config + self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput: + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module(hidden_states, layer_head_mask) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class IJepaPooler(nn.Module): + def __init__(self, config: IJepaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.pooler_output_size) + self.activation = ACT2FN[config.pooler_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class IJepaModel(IJepaPreTrainedModel): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. 
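When a mask token is in use, `IJepaEmbeddings.forward` above swaps masked patch embeddings for the learned token with simple arithmetic. A toy sketch of that replacement (sizes are assumptions):

```python
import torch

batch, seq, dim = 1, 4, 8
embeddings = torch.randn(batch, seq, dim)
mask_token = torch.zeros(1, 1, dim)                       # learned parameter in the model

bool_masked_pos = torch.tensor([[True, False, True, False]])
mask = bool_masked_pos.unsqueeze(-1).type_as(embeddings)  # (batch, seq, 1), 1.0 where masked
embeddings = embeddings * (1.0 - mask) + mask_token.expand(batch, seq, -1) * mask
print(embeddings[0, 0].abs().sum())  # ~0: the first patch was replaced by the zero-initialized token
```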
+ """ + super().__init__(config) + self.config = config + self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = IJepaEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = IJepaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> IJepaPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask) + + sequence_output = encoder_outputs.last_hidden_state + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output) + + +@auto_docstring( + custom_intro=""" + IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states) + e.g. for ImageNet. + + + + Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
+ + + """ +) +class IJepaForImageClassification(IJepaPreTrainedModel): + def __init__(self, config: IJepaConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.ijepa = IJepaModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPooling = self.ijepa( + pixel_values, + head_mask=head_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + sequence_output = outputs.last_hidden_state + logits = self.classifier(sequence_output.mean(dim=1)) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config, **kwargs) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/ijepa/modular_ijepa.py b/venv/lib/python3.13/site-packages/transformers/models/ijepa/modular_ijepa.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8e6e152f3c115db4a1712895239467ea45e7df --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/ijepa/modular_ijepa.py @@ -0,0 +1,186 @@ +from typing import Optional, Union + +import torch +import torch.nn as nn + +from transformers.models.ijepa.configuration_ijepa import IJepaConfig + +from ...modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, torch_int +from ..vit.modeling_vit import ViTEmbeddings, ViTForImageClassification, ViTModel, ViTPreTrainedModel + + +class IJepaEmbeddings(ViTEmbeddings): + def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None: + super().__init__(config, use_mask_token) + # Remove cls_token from IJepaEmbeddings, as it is not used in the model + del self.cls_token + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + patch_pos_embed = self.position_embeddings + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return patch_pos_embed + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +@auto_docstring +class IJepaPreTrainedModel(ViTPreTrainedModel): + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, IJepaEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + if module.mask_token is not None: + module.mask_token.data.zero_() + + +class IJepaModel(IJepaPreTrainedModel, ViTModel): + def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. 
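The `interpolate_pos_encoding` method above treats the `(1, num_positions, dim)` table as a square grid, resizes it bicubically to the new patch grid, and flattens it back, which is what lets a checkpoint trained at one resolution run at another. A standalone sketch of the same resize with assumed sizes (a 16x16 grid stretched to 24x24, e.g. 224 -> 336 pixels at patch size 14):

```python
import torch
import torch.nn as nn

pos_embed = torch.randn(1, 256, 768)                           # (1, num_positions, dim)
dim = pos_embed.shape[-1]

grid = pos_embed.reshape(1, 16, 16, dim).permute(0, 3, 1, 2)   # (1, dim, 16, 16)
grid = nn.functional.interpolate(grid, size=(24, 24), mode="bicubic", align_corners=False)
resized = grid.permute(0, 2, 3, 1).view(1, -1, dim)            # (1, 576, dim)
print(resized.shape)
```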
+ """ + super().__init__(config) + self.config = config + self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token) + + +@auto_docstring( + custom_intro=""" + IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states) + e.g. for ImageNet. + + + + Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """ +) +class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification): + def __init__(self, config: IJepaConfig): + super().__init__(config) + self.ijepa = IJepaModel(config, add_pooling_layer=False) + self.post_init() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPooling = self.ijepa( + pixel_values, + head_mask=head_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + sequence_output = outputs.last_hidden_state + logits = self.classifier(sequence_output.mean(dim=1)) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config, **kwargs) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "IJepaPreTrainedModel", + "IJepaModel", + "IJepaForImageClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/informer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/informer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7a901eab235ecd14d2127a9eb78022c720c0f3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/informer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_informer import * + from .modeling_informer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/informer/configuration_informer.py b/venv/lib/python3.13/site-packages/transformers/models/informer/configuration_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..f62417358c82ce0412336f86f2c9d32d1c019de9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/informer/configuration_informer.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Informer model configuration""" + +from typing import Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class InformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`InformerModel`]. It is used to instantiate an + Informer model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Informer + [huggingface/informer-tourism-monthly](https://huggingface.co/huggingface/informer-tourism-monthly) architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + prediction_length (`int`): + The prediction length for the decoder. In other words, the prediction horizon of the model. This value is + typically dictated by the dataset and we recommend to set it appropriately. + context_length (`int`, *optional*, defaults to `prediction_length`): + The context length for the encoder. If `None`, the context length will be the same as the + `prediction_length`. + distribution_output (`string`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model. Could be either "student_t", "normal" or "negative_binomial". + loss (`string`, *optional*, defaults to `"nll"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood (nll) - which currently is the only supported one. + input_size (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. + scaling (`string` or `bool`, *optional* defaults to `"mean"`): + Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the + scaler is set to "mean". 
+ lags_sequence (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 5, 6, 7]`): + The lags of the input time series as covariates often dictated by the frequency of the data. Default is + `[1, 2, 3, 4, 5, 6, 7]` but we recommend to change it based on the dataset appropriately. + num_time_features (`int`, *optional*, defaults to 0): + The number of time features in the input time series. + num_dynamic_real_features (`int`, *optional*, defaults to 0): + The number of dynamic real valued features. + num_static_categorical_features (`int`, *optional*, defaults to 0): + The number of static categorical features. + num_static_real_features (`int`, *optional*, defaults to 0): + The number of static real valued features. + cardinality (`list[int]`, *optional*): + The cardinality (number of different values) for each of the static categorical features. Should be a list + of integers, having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + embedding_dimension (`list[int]`, *optional*): + The dimension of the embedding for each of the static categorical features. Should be a list of integers, + having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + d_model (`int`, *optional*, defaults to 64): + Dimensionality of the transformer layers. + encoder_layers (`int`, *optional*, defaults to 2): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 2): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + decoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and decoder. If string, `"gelu"` and + `"relu"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the encoder, and decoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each encoder layer. + decoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each decoder layer. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability used between the two layers of the feed-forward networks. + num_parallel_samples (`int`, *optional*, defaults to 100): + The number of samples to generate in parallel for each time step of inference. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal weight initialization distribution. + use_cache (`bool`, *optional*, defaults to `True`): + Whether to use the past key/values attentions (if applicable to the model) to speed up decoding. + attention_type (`str`, *optional*, defaults to "prob"): + Attention used in encoder. 
This can be set to "prob" (Informer's ProbAttention) or "full" (vanilla + transformer's canonical self-attention). + sampling_factor (`int`, *optional*, defaults to 5): + ProbSparse sampling factor (only makes affect when `attention_type`="prob"). It is used to control the + reduced query matrix (Q_reduce) input length. + distil (`bool`, *optional*, defaults to `True`): + Whether to use distilling in encoder. + + Example: + + ```python + >>> from transformers import InformerConfig, InformerModel + + >>> # Initializing an Informer configuration with 12 time steps for prediction + >>> configuration = InformerConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = InformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "informer" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "num_hidden_layers": "encoder_layers", + "initializer_range": "init_std", + } + + def __init__( + self, + prediction_length: Optional[int] = None, + context_length: Optional[int] = None, + distribution_output: str = "student_t", + loss: str = "nll", + input_size: int = 1, + lags_sequence: Optional[list[int]] = None, + scaling: Optional[Union[str, bool]] = "mean", + num_dynamic_real_features: int = 0, + num_static_real_features: int = 0, + num_static_categorical_features: int = 0, + num_time_features: int = 0, + cardinality: Optional[list[int]] = None, + embedding_dimension: Optional[list[int]] = None, + d_model: int = 64, + encoder_ffn_dim: int = 32, + decoder_ffn_dim: int = 32, + encoder_attention_heads: int = 2, + decoder_attention_heads: int = 2, + encoder_layers: int = 2, + decoder_layers: int = 2, + is_encoder_decoder: bool = True, + activation_function: str = "gelu", + dropout: float = 0.05, + encoder_layerdrop: float = 0.1, + decoder_layerdrop: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + num_parallel_samples: int = 100, + init_std: float = 0.02, + use_cache=True, + # Informer arguments + attention_type: str = "prob", + sampling_factor: int = 5, + distil: bool = True, + **kwargs, + ): + # time series specific configuration + self.prediction_length = prediction_length + self.context_length = context_length or prediction_length + self.distribution_output = distribution_output + self.loss = loss + self.input_size = input_size + self.num_time_features = num_time_features + self.lags_sequence = lags_sequence if lags_sequence is not None else [1, 2, 3, 4, 5, 6, 7] + self.scaling = scaling + self.num_dynamic_real_features = num_dynamic_real_features + self.num_static_real_features = num_static_real_features + self.num_static_categorical_features = num_static_categorical_features + + # set cardinality + if cardinality and num_static_categorical_features > 0: + if len(cardinality) != num_static_categorical_features: + raise ValueError( + "The cardinality should be a list of the same length as `num_static_categorical_features`" + ) + self.cardinality = cardinality + else: + self.cardinality = [0] + + # set embedding_dimension + if embedding_dimension and num_static_categorical_features > 0: + if len(embedding_dimension) != num_static_categorical_features: + raise ValueError( + "The embedding dimension should be a list of the same length as `num_static_categorical_features`" + ) + self.embedding_dimension = embedding_dimension + else: + self.embedding_dimension = [min(50, (cat + 1) // 
2) for cat in self.cardinality] + + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.feature_size = input_size * len(self.lags_sequence) + self._number_of_features + self.d_model = d_model + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + + self.activation_function = activation_function + self.init_std = init_std + + self.use_cache = use_cache + + # Informer + self.attention_type = attention_type + self.sampling_factor = sampling_factor + self.distil = distil + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def _number_of_features(self) -> int: + return ( + sum(self.embedding_dimension) + + self.num_dynamic_real_features + + self.num_time_features + + self.num_static_real_features + + self.input_size * 2 # the log1p(abs(loc)) and log(scale) features + ) + + +__all__ = ["InformerConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/informer/modeling_informer.py b/venv/lib/python3.13/site-packages/transformers/models/informer/modeling_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fc4964a9aea30e23019dbfcfb602c6b4946558 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/informer/modeling_informer.py @@ -0,0 +1,2160 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/informer/modular_informer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_informer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
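For reference, the `feature_size` set in `InformerConfig.__init__` above is `input_size * len(lags_sequence)` plus `_number_of_features` (categorical embeddings, dynamic/time/static real features, and the two loc/scale features). A small worked example with illustrative, assumed feature counts:

```python
input_size = 1                      # univariate target
lags_sequence = [1, 2, 3, 4, 5, 6, 7]
embedding_dimension = [3]           # one static categorical feature
num_dynamic_real_features = 0
num_time_features = 2
num_static_real_features = 0

number_of_features = (
    sum(embedding_dimension)
    + num_dynamic_real_features
    + num_time_features
    + num_static_real_features
    + input_size * 2                # the log1p(abs(loc)) and log(scale) features
)
feature_size = input_size * len(lags_sequence) + number_of_features
print(feature_size)                 # 7 + 7 = 14
```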
+ +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + SampleTSPredictionOutput, + Seq2SeqTSModelOutput, + Seq2SeqTSPredictionOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class InformerFeatureEmbedder(nn.Module): + """ + Embed a sequence of categorical features. + + Args: + cardinalities (`list[int]`): + List of cardinalities of the categorical features. + embedding_dims (`list[int]`): + List of embedding dimensions of the categorical features. + """ + + def __init__(self, cardinalities: list[int], embedding_dims: list[int]) -> None: + super().__init__() + + self.num_features = len(cardinalities) + self.embedders = nn.ModuleList([nn.Embedding(c, d) for c, d in zip(cardinalities, embedding_dims)]) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + if self.num_features > 1: + # we slice the last dimension, giving an array of length + # self.num_features with shape (N,T) or (N) + cat_feature_slices = torch.chunk(features, self.num_features, dim=-1) + else: + cat_feature_slices = [features] + + return torch.cat( + [ + embed(cat_feature_slice.squeeze(-1)) + for embed, cat_feature_slice in zip(self.embedders, cat_feature_slices) + ], + dim=-1, + ) + + +class InformerStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. + """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
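`InformerFeatureEmbedder` above gives each static categorical feature its own embedding table, slices the incoming feature tensor along the last dimension, and concatenates the per-feature embeddings. A standalone sketch with two assumed features (cardinalities 5 and 10, embedding dims 3 and 4):

```python
import torch
from torch import nn

cardinalities = [5, 10]
embedding_dims = [3, 4]
embedders = nn.ModuleList([nn.Embedding(c, d) for c, d in zip(cardinalities, embedding_dims)])

features = torch.tensor([[1, 7]])                        # (batch, num_features)
slices = torch.chunk(features, len(cardinalities), dim=-1)
out = torch.cat([emb(s.squeeze(-1)) for emb, s in zip(embedders, slices)], dim=-1)
print(out.shape)                                         # (1, 3 + 4) = (1, 7)
```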
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +class InformerMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. + if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +class InformerNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. 
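The scalers above normalize each series using only the observed time steps: `InformerStdScaler` subtracts a masked mean and divides by a masked standard deviation, while `InformerMeanScaler` divides by the mean absolute value and returns a zero `loc`. A standalone sketch of the masked standardization with a toy series whose last step is unobserved:

```python
import torch

data = torch.tensor([[[1.0], [2.0], [3.0], [100.0]]])     # (batch=1, time=4, channels=1)
observed = torch.tensor([[[1.0], [1.0], [1.0], [0.0]]])   # last value is missing

denom = observed.sum(1, keepdim=True).clamp_min(1.0)      # 3 observed steps
loc = (data * observed).sum(1, keepdim=True) / denom      # masked mean = 2.0
var = (((data - loc) * observed) ** 2).sum(1, keepdim=True) / denom
scale = torch.sqrt(var + 1e-5)                            # minimum_scale keeps it positive
print(loc.squeeze(), scale.squeeze(), ((data - loc) / scale).squeeze())
```

Note that the unobserved 100.0 does not influence `loc` or `scale`; it is only rescaled along with the rest.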
+ """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +class InformerSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) + + +class InformerValueEmbedding(nn.Module): + def __init__(self, feature_size, d_model): + super().__init__() + self.value_projection = nn.Linear(in_features=feature_size, out_features=d_model, bias=False) + + def forward(self, x): + return self.value_projection(x) + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config: InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that 
requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
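`InformerSinusoidalPositionalEmbedding._init_weight` above builds a fixed table where, unlike the interleaved formulation, all sine features occupy the first half of each vector and all cosine features the second half. A small standalone reproduction of that layout:

```python
import numpy as np
import torch

n_pos, dim = 6, 8
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)

table = torch.empty(n_pos, dim)
half = dim // 2                                   # sentinel for an even dim
table[:, :half] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
table[:, half:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
print(table.shape)                                # (6, 8); row i embeds position i
```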
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class InformerAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[InformerConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be 
re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = 
attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
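The ProbSparse attention above scores each query against only a random subset of keys, keeps the `u` most "active" queries (largest max-minus-mean sparsity measurement) for full attention, and fills the remaining rows from a cumulative-sum (decoder) or mean (encoder) context. A worked example of the sampling sizes, with the default `sampling_factor = 5` and an assumed sequence length of 96:

```python
import numpy as np

factor = 5
L_K = L_Q = 96
log_L_K = int(np.ceil(np.log1p(L_K)))        # 5
log_L_Q = int(np.ceil(np.log1p(L_Q)))        # 5

u_part = min(factor * L_Q * log_L_K, L_K)    # keys sampled per head -> 96 (capped at L_K)
u = min(factor * log_L_Q, L_Q)               # "active" queries kept  -> 25
print(u_part, u)
```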
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(GradientCheckpointingLayer): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InformerConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
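`InformerConvLayer` above is the Informer "distilling" block: a circular-padded `Conv1d` over the time axis, `BatchNorm1d`, ELU, and a stride-2 max pool, which roughly halves the sequence length between encoder layers when `distil=True`. A standalone sketch with assumed sizes:

```python
import torch
from torch import nn

d_model, seq_len = 16, 96
x = torch.randn(2, seq_len, d_model)            # (batch, time, d_model)

conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1, padding_mode="circular")
norm = nn.BatchNorm1d(d_model)
pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

y = pool(nn.functional.elu(norm(conv(x.permute(0, 2, 1)))))
print(y.transpose(1, 2).shape)                  # (2, 48, 16): time dimension halved
```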
+ """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class InformerDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = InformerAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class InformerEncoder(InformerPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`InformerEncoderLayer`]. 
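One more note on the encoder layer above: when running in float16, activations are clamped just inside the half-precision range whenever an inf or nan appears, so a single overflow does not poison later layers. A tiny illustration of that guard:

```python
import torch

hidden_states = torch.tensor([1.0, -2.0, float("inf")], dtype=torch.float16)
if hidden_states.dtype == torch.float16 and (
    torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
print(hidden_states)  # the inf is replaced by a large finite float16 value
```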
+ + Args: + config: InformerConfig + """ + + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(InformerPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a + [`InformerDecoderLayer`] + + Args: + config: InformerConfig + """ + + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = inputs_embeds.size()[:-1] + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) + + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class InformerModel(InformerPreTrainedModel): + def __init__(self, config: InformerConfig): + super().__init__(config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @property + def _past_length(self) -> int: + return self.config.context_length + max(self.config.lags_sequence) + + def get_lagged_subsequences( + self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 + ) -> torch.Tensor: + """ + Returns lagged subsequences of a given sequence. Returns a tensor of shape (N, S, C, I), + where S = subsequences_length and I = len(indices), containing lagged subsequences. Specifically, lagged[i, + j, :, k] = sequence[i, -indices[k]-S+j, :]. + + Args: + sequence: Tensor + The sequence from which lagged subsequences should be extracted. Shape: (N, T, C). + subsequences_length : int + Length of the subsequences to be extracted. + shift: int + Shift the lags by this amount back. 
+ """ + sequence_length = sequence.shape[1] + indices = [lag - shift for lag in self.config.lags_sequence] + + if max(indices) + subsequences_length > sequence_length: + raise ValueError( + f"lags cannot go further than history length, found lag {max(indices)} " + f"while history length is only {sequence_length}" + ) + + lagged_values = [] + for lag_index in indices: + begin_index = -lag_index - subsequences_length + end_index = -lag_index if lag_index > 0 else None + lagged_values.append(sequence[:, begin_index:end_index, ...]) + return torch.stack(lagged_values, dim=-1) + + def create_network_inputs( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + ): + # time feature + time_feat = ( + torch.cat( + ( + past_time_features[:, self._past_length - self.config.context_length :, ...], + future_time_features, + ), + dim=1, + ) + if future_values is not None + else past_time_features[:, self._past_length - self.config.context_length :, ...] + ) + + # target + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + context = past_values[:, -self.config.context_length :] + observed_context = past_observed_mask[:, -self.config.context_length :] + _, loc, scale = self.scaler(context, observed_context) + + inputs = ( + (torch.cat((past_values, future_values), dim=1) - loc) / scale + if future_values is not None + else (past_values - loc) / scale + ) + + # static features + if loc.ndim == 3: + squeezed_loc = loc.squeeze(1) + squeezed_scale = scale.squeeze(1) + else: + squeezed_loc = loc + squeezed_scale = scale + log_abs_loc = squeezed_loc.abs().log1p() + log_scale = squeezed_scale.log() + static_feat = torch.cat((log_abs_loc, log_scale), dim=1) + + if static_real_features is not None: + static_feat = torch.cat((static_real_features, static_feat), dim=1) + if static_categorical_features is not None: + embedded_cat = self.embedder(static_categorical_features) + static_feat = torch.cat((embedded_cat, static_feat), dim=1) + expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1) + + # all features + features = torch.cat((expanded_static_feat, time_feat), dim=-1) + + # lagged features + subsequences_length = ( + self.config.context_length + self.config.prediction_length + if future_values is not None + else self.config.context_length + ) + lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length) + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) + + if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]: + raise ValueError( + f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match" + ) + + # transformer inputs + transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) + + return transformer_inputs, loc, scale, static_feat + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + past_observed_mask: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + future_values: 
Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Seq2SeqTSModelOutput, tuple]: + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. 
NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. 
Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_inputs, loc, scale, static_feat = self.create_network_inputs( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + ) + + if encoder_outputs is None: + enc_input = transformer_inputs[:, : self.config.context_length, ...] + encoder_outputs = self.encoder( + inputs_embeds=enc_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] 
+ + decoder_outputs = self.decoder( + inputs_embeds=dec_input, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + (loc, scale, static_feat) + + return Seq2SeqTSModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + loc=loc, + scale=scale, + static_features=static_feat, + ) + + +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. 
+ """ + return -input.log_prob(target) + + +@auto_docstring +class InformerForPrediction(InformerPreTrainedModel): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + def output_params(self, dec_output): + return self.parameter_projection(dec_output) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @torch.jit.ignore + def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: + sliced_params = params + if trailing_n is not None: + sliced_params = [p[:, -trailing_n:] for p in params] + return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + past_observed_mask: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + future_observed_mask: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Seq2SeqTSModelOutput, tuple]: + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. 
+ + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. 
+ + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... 
) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if future_values is not None: + use_cache = False + + outputs = self.model( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + use_cache=use_cache, + return_dict=return_dict, + cache_position=cache_position, + ) + + prediction_loss = None + params = None + if future_values is not None: + params = self.output_params(outputs[0]) # outputs.last_hidden_state + # loc is 3rd last and scale is 2nd last output + distribution = self.output_distribution(params, loc=outputs[-3], scale=outputs[-2]) + + loss = self.loss(distribution, future_values) + + if future_observed_mask is None: + future_observed_mask = torch.ones_like(future_values) + + if len(self.target_shape) == 0: + loss_weights = future_observed_mask + else: + loss_weights, _ = future_observed_mask.min(dim=-1, keepdim=False) + + prediction_loss = weighted_average(loss, weights=loss_weights) + + if not return_dict: + outputs = ((params,) + outputs[1:]) if params is not None else outputs[1:] + return ((prediction_loss,) + outputs) if prediction_loss is not None else outputs + + return Seq2SeqTSPredictionOutput( + loss=prediction_loss, + params=params, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + loc=outputs.loc, + scale=outputs.scale, + static_features=outputs.static_features, + ) + + @torch.no_grad() + def generate( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + future_time_features: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SampleTSPredictionOutput: + r""" + Greedily generate sequences of sample predictions from a model with a probability distribution head. 
+ + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size + of this tensor must be larger than the `context_length` of the model, since the model will use the + larger size to construct lag features, i.e. additional values from the past which are added in order to + serve as "extra context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if + no `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length + of the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, + such as `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number + of variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things + like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). + These could also be so-called "age" features, which basically help the model know "at which point in + life" a time-series is. Age features have small values for distant past time steps and increase + monotonically the more we approach the current time step. Holiday features are also a good example of + time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, + where the position encodings are learned from scratch internally as parameters of the model, the Time + Series Transformer requires to provide additional time features. The Time Series Transformer only + learns additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these + features must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to sampled + predictions. These could be things like "month of year", "day of the month", etc. encoded as vectors + (for instance as Fourier features). These could also be so-called "age" features, which basically help + the model know "at which point in life" a time-series is. Age features have small values for distant + past time steps and increase monotonically the more we approach the current time step. Holiday features + are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, + where the position encodings are learned from scratch internally as parameters of the model, the Time + Series Transformer requires to provide additional time features. 
The Time Series Transformer only + learns additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these + features must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to + the values of the time series. + + Static categorical features are features which have the same value for all time steps (static over + time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + + Return: + [`SampleTSPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, prediction_length)` or `(batch_size, number of samples, prediction_length, input_size)` for + multivariate predictions. 
+ """ + outputs = self( + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + past_time_features=past_time_features, + past_values=past_values, + past_observed_mask=past_observed_mask, + future_time_features=future_time_features, + future_values=None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + use_cache=True, + ) + + decoder = self.model.get_decoder() + enc_last_hidden = outputs.encoder_last_hidden_state + loc = outputs.loc + scale = outputs.scale + static_feat = outputs.static_features + + num_parallel_samples = self.config.num_parallel_samples + repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0) + repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_past_values = ( + past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc + ) / repeated_scale + + expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1) + features = torch.cat((expanded_static_feat, future_time_features), dim=-1) + repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_enc_last_hidden = enc_last_hidden.repeat_interleave(repeats=num_parallel_samples, dim=0) + + future_samples = [] + + # greedy decoding + for k in range(self.config.prediction_length): + lagged_sequence = self.model.get_lagged_subsequences( + sequence=repeated_past_values, + subsequences_length=1 + k, + shift=1, + ) + + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) + + decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1) + + dec_output = decoder(inputs_embeds=decoder_input, encoder_hidden_states=repeated_enc_last_hidden) + dec_last_hidden = dec_output.last_hidden_state + + params = self.parameter_projection(dec_last_hidden[:, -1:]) + distr = self.output_distribution(params, loc=repeated_loc, scale=repeated_scale) + next_sample = distr.sample() + + repeated_past_values = torch.cat( + (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1 + ) + future_samples.append(next_sample) + + concat_future_samples = torch.cat(future_samples, dim=1) + + return SampleTSPredictionOutput( + sequences=concat_future_samples.reshape( + (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape, + ) + ) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/informer/modular_informer.py b/venv/lib/python3.13/site-packages/transformers/models/informer/modular_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6e2ad30a9f858393bff04dd5b3f797574bfacf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/informer/modular_informer.py @@ -0,0 +1,981 @@ +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Informer model.""" + +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn + +from ...cache_utils import Cache, EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, +) +from ...utils.deprecation import deprecate_kwarg +from ..bart.modeling_bart import BartAttention +from ..time_series_transformer.modeling_time_series_transformer import ( + TimeSeriesFeatureEmbedder, + TimeSeriesMeanScaler, + TimeSeriesNOPScaler, + TimeSeriesSinusoidalPositionalEmbedding, + TimeSeriesStdScaler, + TimeSeriesTransformerDecoder, + TimeSeriesTransformerDecoderLayer, + TimeSeriesTransformerEncoder, + TimeSeriesTransformerEncoderLayer, + TimeSeriesTransformerForPrediction, + TimeSeriesTransformerModel, + TimeSeriesValueEmbedding, +) +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): + pass + + +class InformerStdScaler(TimeSeriesStdScaler): + pass + + +class InformerMeanScaler(TimeSeriesMeanScaler): + pass + + +class InformerNOPScaler(TimeSeriesNOPScaler): + pass + + +class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding): + pass + + +class InformerValueEmbedding(TimeSeriesValueEmbedding): + pass + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config: InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class InformerAttention(BartAttention): + pass + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if 
is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = 
attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
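+        # `self.embed_dim` equals `num_heads * head_dim`, so the reshape below merges the per-head outputs back into the model width before the final output projection.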
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(GradientCheckpointingLayer): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + +class InformerEncoder(TimeSeriesTransformerEncoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.gradient_checkpointing = False + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, 
*optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(TimeSeriesTransformerDecoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + +class InformerModel(TimeSeriesTransformerModel): + def __init__(self, config: InformerConfig): + PreTrainedModel.__init__(self, config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". 
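+ + For example, with `config.context_length = 24` and the default `config.lags_sequence` (whose largest look-back index is 7), `past_values` is expected to contain at least 24 + 7 = 31 time steps.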
+ + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model.
The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ...
) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + super().forward(**super_kwargs) + + +class InformerForPrediction(TimeSeriesTransformerForPrediction): + def __init__(self, config: InformerConfig): + PreTrainedModel.__init__(self, config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + @auto_docstring + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. 
+ + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + super().forward(**super_kwargs) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblip/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/instructblip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2bd053d2c8b9fdaec4211b2b74e964bb88a5e3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/instructblip/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + from typing import TYPE_CHECKING + + from ...utils import _LazyModule + from ...utils.import_utils import define_import_structure + + + if TYPE_CHECKING: + from .configuration_instructblip import * + from .modeling_instructblip import * + from .processing_instructblip import * + else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblip/configuration_instructblip.py b/venv/lib/python3.13/site-packages/transformers/models/instructblip/configuration_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8323f15f0525887d178239b6b221dbb488d952 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/instructblip/configuration_instructblip.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""InstructBLIP model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class InstructBlipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InstructBlipVisionModel`]. It is used to + instantiate a InstructBLIP vision encoder according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the InstructBLIP + [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1408): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 39): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries and values in the self-attention layers. + + Example: + + ```python + >>> from transformers import InstructBlipVisionConfig, InstructBlipVisionModel + + >>> # Initializing a InstructBlipVisionConfig with Salesforce/instruct-blip-flan-t5 style configuration + >>> configuration = InstructBlipVisionConfig() + + >>> # Initializing a InstructBlipVisionModel (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration + >>> model = InstructBlipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "instructblip_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1408, + intermediate_size=6144, + num_hidden_layers=39, + num_attention_heads=16, + image_size=224, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.qkv_bias = qkv_bias + + +class InstructBlipQFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InstructBlipQFormerModel`]. It is used to + instantiate a InstructBLIP Querying Transformer (Q-Former) model according to the specified arguments, defining the + model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the InstructBLIP [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) + architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + + Note that [`InstructBlipQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by + the `input_ids` passed when calling the model. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Token id used for padding sequences. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + cross_attention_frequency (`int`, *optional*, defaults to 2): + The frequency of adding cross-attention to the Transformer layers. + encoder_hidden_size (`int`, *optional*, defaults to 1408): + The hidden size of the hidden states for cross-attention. 
+ + Examples: + + ```python + >>> from transformers import InstructBlipQFormerConfig, InstructBlipQFormerModel + + >>> # Initializing a InstructBLIP Salesforce/instruct-blip-flan-t5 style configuration + >>> configuration = InstructBlipQFormerConfig() + + >>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration + >>> model = InstructBlipQFormerModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "instructblip_qformer" + base_config_key = "qformer_config" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + cross_attention_frequency=2, + encoder_hidden_size=1408, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.cross_attention_frequency = cross_attention_frequency + self.encoder_hidden_size = encoder_hidden_size + + +class InstructBlipConfig(PretrainedConfig): + r""" + [`InstructBlipConfig`] is the configuration class to store the configuration of a + [`InstructBlipForConditionalGeneration`]. It is used to instantiate a InstructBLIP model according to the specified + arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with + the defaults will yield a similar configuration to that of the InstructBLIP + [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`InstructBlipVisionConfig`]. + qformer_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`InstructBlipQFormerConfig`]. + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize any [`PretrainedConfig`]. + num_query_tokens (`int`, *optional*, defaults to 32): + The number of query tokens passed through the Transformer. + + image_token_index (`int`, *optional*): + Token index of special image token. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... InstructBlipVisionConfig, + ... InstructBlipQFormerConfig, + ... OPTConfig, + ... InstructBlipConfig, + ... InstructBlipForConditionalGeneration, + ... 
) + + >>> # Initializing a InstructBlipConfig with Salesforce/instruct-blip-flan-t5 style configuration + >>> configuration = InstructBlipConfig() + + >>> # Initializing a InstructBlipForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration + >>> model = InstructBlipForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a InstructBlipConfig from a InstructBlipVisionConfig, InstructBlipQFormerConfig and any PretrainedConfig + + >>> # Initializing InstructBLIP vision, InstructBLIP Q-Former and language model configurations + >>> vision_config = InstructBlipVisionConfig() + >>> qformer_config = InstructBlipQFormerConfig() + >>> text_config = OPTConfig() + + >>> config = InstructBlipConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config) + ```""" + + model_type = "instructblip" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = { + "text_config": AutoConfig, + "qformer_config": InstructBlipQFormerConfig, + "vision_config": InstructBlipVisionConfig, + } + + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_token_index=None, + **kwargs, + ): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. Initializing the InstructBlipVisionConfig with default values.") + + if qformer_config is None: + qformer_config = {} + logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.") + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).") + + self.vision_config = InstructBlipVisionConfig(**vision_config) + self.qformer_config = InstructBlipQFormerConfig(**qformer_config) + text_model_type = text_config.get("model_type", "opt") + self.text_config = CONFIG_MAPPING[text_model_type](**text_config) + + self.num_query_tokens = num_query_tokens + self.image_token_index = image_token_index + self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size + self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + + @classmethod + def from_vision_qformer_text_configs( + cls, + vision_config: InstructBlipVisionConfig, + qformer_config: InstructBlipQFormerConfig, + text_config: PretrainedConfig, + **kwargs, + ): + r""" + Instantiate a [`InstructBlipConfig`] (or a derived class) from a InstructBLIP vision model, Q-Former and + language model configurations.
+ + Returns: + [`InstructBlipConfig`]: An instance of a configuration object + """ + + return cls( + vision_config=vision_config.to_dict(), + qformer_config=qformer_config.to_dict(), + text_config=text_config.to_dict(), + **kwargs, + ) + + +__all__ = ["InstructBlipConfig", "InstructBlipQFormerConfig", "InstructBlipVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblip/modeling_instructblip.py b/venv/lib/python3.13/site-packages/transformers/models/instructblip/modeling_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..dad91cc312e5d247734ef8b48bb2dd89b3cc48ee --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/instructblip/modeling_instructblip.py @@ -0,0 +1,1532 @@ +# coding=utf-8 +# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch InstructBLIP model.""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM +from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class defining the outputs of [`InstructBlipForConditionalGeneration`]. + """ +) +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip +class InstructBlipForConditionalGenerationModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Language modeling loss from the language model. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head of the language model. + vision_outputs (`BaseModelOutputWithPooling`): + Outputs of the vision encoder. + qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`): + Outputs of the Q-Former (Querying Transformer). 
+ language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`): + Outputs of the language model. + """ + + loss: Optional[tuple[torch.FloatTensor]] = None + logits: Optional[tuple[torch.FloatTensor]] = None + vision_outputs: Optional[torch.FloatTensor] = None + qformer_outputs: Optional[tuple[torch.FloatTensor]] = None + language_model_outputs: Optional[tuple[torch.FloatTensor]] = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] + if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"] + else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip +class InstructBlipVisionEmbeddings(nn.Module): + def __init__(self, config: InstructBlipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + 
position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip +class InstructBlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.is_causal = False + self.attention_dropout = config.attention_dropout + + # small tweak here compared to CLIP, no bias here + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) + + if config.qkv_bias: + q_bias = nn.Parameter(torch.zeros(self.embed_dim)) + v_bias = nn.Parameter(torch.zeros(self.embed_dim)) + else: + q_bias = None + v_bias = None + + if q_bias is not None: + qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) + self.qkv.bias = nn.Parameter(qkv_bias) + + self.projection = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( + 2, 0, 3, 1, 4 + ) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + + return 
attn_output, attn_weights + + +# Copied from transformers.models.blip.modeling_blip.BlipMLP +class InstructBlipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip +class InstructBlipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InstructBlipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = InstructBlipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = InstructBlipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + **kwargs, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + return hidden_states + + +@auto_docstring +class InstructBlipPreTrainedModel(PreTrainedModel): + config: InstructBlipConfig + base_model_prefix = "blip" + supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + _no_split_modules = [ + "InstructBlipQFormerEmbeddings", + "InstructBlipAttention", + "InstructBlipQFormerMultiHeadAttention", + "InstructBlipQFormerSelfOutput", + ] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=factor) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, InstructBlipVisionEmbeddings): + nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)): + module.query_tokens.data.zero_() + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip +class InstructBlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InstructBlipEncoderLayer`]. + + Args: + config (`InstructBlipConfig`): + The corresponding vision configuration for the `InstructBlipEncoder`. 
+ """ + + def __init__(self, config: InstructBlipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlip, BLIP->INSTRUCTBLIP +class InstructBlipVisionModel(InstructBlipPreTrainedModel): + main_input_name = "pixel_values" + config: InstructBlipVisionConfig + _can_record_outputs = { + "hidden_states": InstructBlipEncoderLayer, + "attentions": InstructBlipAttention, + } + + def __init__(self, config: InstructBlipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = InstructBlipVisionEmbeddings(config) + self.encoder = InstructBlipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPooling]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + def get_input_embeddings(self): + return self.embeddings + + +class InstructBlipQFormerMultiHeadAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) + self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or 
self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_scores_dtype = attention_scores.dtype + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
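+        # Masked positions received a large negative bias above, so the softmax assigns them ~0 probability;
+        # the result is cast back to the pre-mask dtype for fp16/bf16 compatibility.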
+ attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer +class InstructBlipQFormerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip +class InstructBlipQFormerAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention) + self.output = InstructBlipQFormerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + attn_output, _ = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + attention_output = self.output(attn_output, hidden_states) + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with 
Bert->InstructBlipQFormer +class InstructBlipQFormerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer +class InstructBlipQFormerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class InstructBlipQFormerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = InstructBlipQFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = InstructBlipQFormerIntermediate(config) + self.output = InstructBlipQFormerOutput(config) + + self.intermediate_query = InstructBlipQFormerIntermediate(config) + self.output_query = InstructBlipQFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + **kwargs, + ) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + query_attention_output = self.crossattention( + query_attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ).to(layer_output.device) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + return layer_output + + def feed_forward_chunk(self, attention_output): + intermediate_output = 
self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip +class InstructBlipQFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + layer_head_mask = head_mask[i] if head_mask is not None else None + + hidden_states = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + query_length=query_length, + **kwargs, + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + ) + + +class InstructBlipQFormerEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids.to(embeddings.device)) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = embeddings.to(self.layernorm.weight.dtype) + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class InstructBlipQFormerModel(InstructBlipPreTrainedModel): + """ + Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the + instruction as input. 
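+
+    A minimal, illustrative sketch of how the Q-Former is driven inside the full model (the tiny sizes below are
+    for demonstration only and do not correspond to a trained configuration):
+
+    ```python
+    >>> import torch
+    >>> from transformers import InstructBlipQFormerConfig, InstructBlipQFormerModel
+
+    >>> config = InstructBlipQFormerConfig(
+    ...     hidden_size=32, num_attention_heads=4, num_hidden_layers=2, intermediate_size=64, encoder_hidden_size=32
+    ... )
+    >>> qformer = InstructBlipQFormerModel(config)
+    >>> query_tokens = torch.zeros(1, 8, config.hidden_size)  # learnable queries (a parameter of the full model)
+    >>> image_embeds = torch.zeros(1, 16, config.encoder_hidden_size)  # output of the vision encoder
+    >>> instruction_ids = torch.ones(1, 4, dtype=torch.long)  # tokenized instruction text
+    >>> outputs = qformer(
+    ...     input_ids=instruction_ids, query_embeds=query_tokens, encoder_hidden_states=image_embeds
+    ... )
+    >>> list(outputs.last_hidden_state.shape)  # 8 query positions + 4 instruction tokens
+    [1, 12, 32]
+    ```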
+ """ + + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + + _can_record_outputs = { + "hidden_states": InstructBlipQFormerLayer, + "attentions": [ + OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"), + ], + "cross_attentions": [ + OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"), + ], + } + + def __init__(self, config: InstructBlipQFormerConfig): + super().__init__(config) + self.config = config + + self.embeddings = InstructBlipQFormerEmbeddings(config) + + self.encoder = InstructBlipQFormerEncoder(config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`tuple[int]`): + The shape of the input to the model. + device: (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})", + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + query_embeds: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states to be used in the attention computation. If cross-attention, + will be used for the query (i.e., key and value will use the encoder_hidden_states). + """ + if input_ids is None and query_embeds is None: + raise ValueError("You have to specify query_embeds when input_ids is None") + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs: BaseModelOutput = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + query_length=query_length, + **kwargs, + ) + sequence_output = 
encoder_outputs.last_hidden_state + pooled_output = sequence_output[:, 0, :] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +@auto_docstring( + custom_intro=""" + InstructBLIP base Model consisting of language model, qformer and vision encoder. + """ +) +class InstructBlipModel(InstructBlipPreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipConfig): + super().__init__(config) + + self.vision_model = InstructBlipVisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipQFormerModel(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + self.language_model = AutoModel.from_config(config.text_config) + + if self.language_model._no_split_modules is not None: + self._no_split_modules.extend(self.language_model._no_split_modules) + + if self.language_model._keep_in_fp32_modules is not None: + self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." + " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. 
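+        The returned mask has the same shape as `inputs_embeds` and is `True` at every position holding an image
+        placeholder token, so `forward` can inject the projected Q-Former features via `masked_scatter`.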
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, InstructBlipForConditionalGenerationModelOutput]: + r""" + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. 
+ """ + + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **kwargs, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + + return InstructBlipForConditionalGenerationModelOutput( + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + +@auto_docstring( + custom_intro=""" + InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision + encoder, Querying Transformer (Q-Former) and a language model. + + One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue + the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token. 
+ """ +) +class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): + config: InstructBlipConfig + main_input_name = "pixel_values" + + _can_compile_fullgraph = True + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipConfig): + super().__init__(config) + + self.vision_model = InstructBlipVisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + + if language_model._no_split_modules is not None: + self._no_split_modules.extend(language_model._no_split_modules) + + if language_model._keep_in_fp32_modules is not None: + self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules) + + self.language_model = language_model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." 
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, InstructBlipForConditionalGenerationModelOutput]: + r""" + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size - + 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size]` + + Examples: + + ```python + >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b") + >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b") + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + >>> prompt = "What is unusual about this image?" + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) + + >>> outputs = model.generate( + ... **inputs, + ... do_sample=False, + ... num_beams=5, + ... max_length=256, + ... min_length=1, + ... top_p=0.9, + ... repetition_penalty=1.5, + ... length_penalty=1.0, + ... 
temperature=1, + ... ) + >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() + >>> print(generated_text) + The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation. + ```""" + + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + logits = outputs[0] + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + else: + kwargs["return_dict"] = True + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + **kwargs, + ) + loss = outputs.loss + logits = outputs.logits + + return InstructBlipForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: Optional[torch.LongTensor] = None, + qformer_attention_mask: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **generate_kwargs, + ) -> torch.LongTensor: + """ + Overrides `generate` function to be able to use the model as a conditional generator. + + Args: + pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): + Input images to be processed. + qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + The sequence used as a prompt to be fed to the Q-Former module. + qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + Mask to avoid performing attention on padding token indices. + input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + The sequence used as a prompt for the generation. + attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + Mask to avoid performing attention on padding token indices. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. + + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + if hasattr(self, "hf_device_map"): + # preprocess for `accelerate` + self._preprocess_accelerate() + + batch_size = pixel_values.shape[0] + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + if inputs_embeds is None: + if input_ids is None: + image_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens = image_tokens + [self.config.text_config.bos_token_id] + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} + if not self.language_model.config.is_encoder_decoder: + inputs["input_ids"] = input_ids + + outputs = self.language_model.generate(**inputs, **generate_kwargs) + + return outputs + + +__all__ = [ + "InstructBlipQFormerModel", + "InstructBlipPreTrainedModel", + "InstructBlipModel", + "InstructBlipForConditionalGeneration", + "InstructBlipVisionModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblip/processing_instructblip.py b/venv/lib/python3.13/site-packages/transformers/models/instructblip/processing_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..122fc11622ff51e72467b4f72bf46091ebb7c20c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/instructblip/processing_instructblip.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for InstructBLIP. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former. 
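+
+Each call produces two tokenizations of the same prompt: `input_ids`/`attention_mask` for the language model (with
+image placeholder tokens prepended when images are passed) and `qformer_input_ids`/`qformer_attention_mask` for the
+Q-Former.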
+""" + +import os +from typing import Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput +from ...utils import logging +from ..auto import AutoTokenizer + + +logger = logging.get_logger(__name__) + + +class InstructBlipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + +class InstructBlipProcessor(ProcessorMixin): + r""" + Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single + processor. + + [`InstructBlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the + docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information. + + Args: + image_processor (`BlipImageProcessor`): + An instance of [`BlipImageProcessor`]. The image processor is a required input. + tokenizer (`AutoTokenizer`): + An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. + qformer_tokenizer (`AutoTokenizer`): + An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. + num_query_tokens (`int`, *optional*):" + Number of tokens used by the Qformer as queries, should be same as in model's config. + """ + + attributes = ["image_processor", "tokenizer", "qformer_tokenizer"] + image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast") + tokenizer_class = "AutoTokenizer" + qformer_tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs): + if not hasattr(tokenizer, "image_token"): + self.image_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([self.image_token], special_tokens=True) + else: + self.image_token = tokenizer.image_token + self.num_query_tokens = num_query_tokens + + super().__init__(image_processor, tokenizer, qformer_tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[InstructBlipProcessorKwargs], + ) -> BatchFeature: + """ + This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and + [`BertTokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ """ + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + + output_kwargs = self._merge_kwargs( + InstructBlipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + encoding = {} + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"]) + encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids") + encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask") + + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + if output_kwargs["text_kwargs"].get("max_length") is not None: + output_kwargs["text_kwargs"]["max_length"] -= self.num_query_tokens + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if images is not None: + # Image tokens should not be padded/truncated or prepended with special BOS token + image_tokens = self.image_token.content * self.num_query_tokens + output_kwargs["text_kwargs"]["add_special_tokens"] = False + output_kwargs["text_kwargs"]["padding"] = False + output_kwargs["text_kwargs"]["truncation"] = False + image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"]) + for k in text_encoding: + text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]] + encoding.update(text_encoding) + + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + encoding.update(image_encoding) + + # Cast to desired return tensors type + encoding = BatchFeature(encoding, tensor_type=return_tensors) + return encoding + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + qformer_input_names = ["qformer_input_ids", "qformer_attention_mask"] + return tokenizer_input_names + image_processor_input_names + qformer_input_names + + # overwrite to save the Q-Former tokenizer in a separate folder + def save_pretrained(self, save_directory, **kwargs): + if os.path.isfile(save_directory): + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") + os.makedirs(save_directory, exist_ok=True) + qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer") + self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path) + + # We modify the attributes so that only the tokenizer and image processor are saved in the main folder + qformer_present = "qformer_tokenizer" in self.attributes + if qformer_present: + self.attributes.remove("qformer_tokenizer") + + outputs = super().save_pretrained(save_directory, **kwargs) + + if qformer_present: + self.attributes += ["qformer_tokenizer"] + return outputs + + # overwrite to load the Q-Former tokenizer from a separate folder + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + # if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs' + if isinstance(processor, tuple): + processor = processor[0] + qformer_tokenizer = 
AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer") + processor.qformer_tokenizer = qformer_tokenizer + return processor + + +__all__ = ["InstructBlipProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/jamba/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/jamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8789007ad0bc8f2a3ee1318370323a054474857b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/jamba/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_jamba import * + from .modeling_jamba import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/jamba/configuration_jamba.py b/venv/lib/python3.13/site-packages/transformers/models/jamba/configuration_jamba.py new file mode 100644 index 0000000000000000000000000000000000000000..a557562ea01887bba88af450715914797e0ee34b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/jamba/configuration_jamba.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Jamba model configuration""" + +import math + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class JambaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a + Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Jamba-v0.1 model. + + [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the Jamba model. 
Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`JambaModel`].
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+            model has an output word embedding layer.
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group (see the sketch below). For more details, check
+            out [this paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
+            significantly.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary load-balancing loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, will default to `None`.
+        max_position_embeddings (`int`, *optional*, defaults to 262144):
+            This value doesn't have any real effect. The maximum sequence length that this model is intended to be
+            used with. It can be used with longer sequences, but performance may degrade.
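+
+            A hedged sketch of that mean-pooling conversion (shapes assume the defaults above):
+
+            ```python
+            import torch
+
+            num_heads, num_kv_heads, head_dim, hidden = 32, 8, 128, 4096
+            k_proj = torch.randn(num_heads * head_dim, hidden)  # original MHA key projection
+            grouped = k_proj.view(num_kv_heads, num_heads // num_kv_heads, head_dim, hidden)
+            k_proj_gqa = grouped.mean(dim=1).reshape(num_kv_heads * head_dim, hidden)
+            print(k_proj_gqa.shape)  # torch.Size([1024, 4096])
+            ```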
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; this can also be interpreted as the `top-k` routing
+            parameter.
+        num_experts (`int`, *optional*, defaults to 16):
+            Number of experts per Sparse MLP layer.
+        expert_layer_period (`int`, *optional*, defaults to 2):
+            Once in this many layers, we will have an expert layer.
+        expert_layer_offset (`int`, *optional*, defaults to 1):
+            The first layer index that contains an expert MLP layer.
+        attn_layer_period (`int`, *optional*, defaults to 8):
+            Once in this many layers, we will have a vanilla attention layer.
+        attn_layer_offset (`int`, *optional*, defaults to 4):
+            The first layer index that contains a vanilla attention layer.
+        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
+            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises a ValueError if
+            `True` and the kernels are not available.
+        mamba_d_state (`int`, *optional*, defaults to 16):
+            The dimension of the mamba state space latents.
+        mamba_d_conv (`int`, *optional*, defaults to 4):
+            The size of the mamba convolution kernel.
+        mamba_expand (`int`, *optional*, defaults to 2):
+            Expanding factor (relative to `hidden_size`) used to determine the mamba intermediate size.
+        mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the mamba discretization projection matrix. `"auto"` means that it will default to
+            `math.ceil(self.hidden_size / 16)` (see the arithmetic sketch below).
+        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
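+
+            A quick arithmetic check of the sizes derived from these defaults (hedged sketch):
+
+            ```python
+            import math
+
+            hidden_size, mamba_expand = 4096, 2
+            intermediate_size = mamba_expand * hidden_size  # 8192, width of the mamba mixer
+            dt_rank = math.ceil(hidden_size / 16)           # 256, the "auto" value of mamba_dt_rank
+            print(intermediate_size, dt_rank)
+            ```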
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + + """ + + model_type = "jamba" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=65536, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + num_logits_to_keep=1, + output_router_logits=False, + router_aux_loss_coef=0.001, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=262144, + attention_dropout=0.0, + num_experts_per_tok=2, + num_experts=16, + expert_layer_period=2, + expert_layer_offset=1, + attn_layer_period=8, + attn_layer_offset=4, + use_mamba_kernels=True, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_dt_rank="auto", + mamba_conv_bias=True, + mamba_proj_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.expert_layer_period = expert_layer_period + self.expert_layer_offset = expert_layer_offset + self.attn_layer_period = attn_layer_period + self.attn_layer_offset = attn_layer_offset + + self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset) + self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset) + + self.use_mamba_kernels = use_mamba_kernels + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return [ + "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba" + for i in range(self.num_hidden_layers) + ] + + @property + def layers_num_experts(self): + return [ + self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1 + for i in range(self.num_hidden_layers) + ] + + def _check_supported_offset(self, property_: str, period: int, offset: int): + if offset >= period: + raise ValueError( + f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})" + ) + + +__all__ = ["JambaConfig"] diff --git 
a/venv/lib/python3.13/site-packages/transformers/models/jamba/modeling_jamba.py b/venv/lib/python3.13/site-packages/transformers/models/jamba/modeling_jamba.py new file mode 100644 index 0000000000000000000000000000000000000000..7196045390b1f9f10ac801dbc3a533c8b87c594f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/jamba/modeling_jamba.py @@ -0,0 +1,1462 @@ +# coding=utf-8 +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Jamba model.""" + +import math +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import ( + GenericForSequenceClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available +from .configuration_jamba import JambaConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func with gate->router +def load_balancing_loss_func( + router_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. 
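+    For intuition, a hedged numeric sketch of the quantity computed below (it equals `top_k`
+    under perfectly uniform routing and grows as routing concentrates):
+
+    ```python
+    import torch
+
+    num_experts, top_k = 4, 2
+    logits = torch.randn(6, num_experts)                     # 6 tokens
+    probs = torch.softmax(logits, dim=-1)
+    _, chosen = torch.topk(probs, top_k, dim=-1)
+    mask = torch.nn.functional.one_hot(chosen, num_experts)  # (tokens, top_k, experts)
+    tokens_per_expert = mask.float().mean(dim=0)             # routing fractions per slot
+    prob_per_expert = probs.mean(dim=0)                      # mean router probability
+    aux = num_experts * (tokens_per_expert * prob_per_expert.unsqueeze(0)).sum()
+    print(aux)
+    ```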
This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + router_logits: + Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if router_logits is None or not isinstance(router_logits, tuple): + return 0 + + if isinstance(router_logits, tuple): + compute_device = router_logits[0].device + concatenated_router_logits = torch.cat( + [layer_router.to(compute_device) for layer_router in router_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + device_index = routing_weights.device.index if routing_weights.device.index is not None else 0 + rank = routing_weights.shape[1] * int(device_index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba +class JambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + JambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * 
torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class HybridMambaAttentionDynamicCache:
+    """
+    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+    (which has a constant shape regardless of seq_len).
+
+    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for the attention cache, and
+    `conv_states` and `ssm_states` for the mamba cache. Each of these lists has `num_layers` tensors, with the
+    following expected shapes.
+    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
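+
+    A hedged shape check (small config values chosen purely for illustration):
+
+    ```python
+    import torch
+    from transformers import JambaConfig
+    from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache
+
+    cfg = JambaConfig(hidden_size=64, num_hidden_layers=4)
+    cache = HybridMambaAttentionDynamicCache(cfg, batch_size=2, dtype=torch.float32)
+    # all four layers are mamba layers under the default period/offset schedule
+    print(cache.conv_states[0].shape)  # (2, mamba_expand * 64, mamba_d_conv) == (2, 128, 4)
+    print(cache.ssm_states[0].shape)   # (2, 128, mamba_d_state) == (2, 128, 16)
+    ```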
+ """ + + is_compileable = False + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba +class JambaAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = 
attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_values + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +class JambaFlashAttention2(JambaAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
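+        # (the cast target below is resolved in priority order: the active autocast
+        # dtype, the dtype recorded before quantization, then the weights' own dtype)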
+        input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the"
+                f" input back to {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        # flash attention never materializes the attention weights
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_values
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
+class JambaSdpaAttention(JambaAttention):
+    """
+    Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `JambaAttention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt
+    to the SDPA API.
+    """
+
+    # Adapted from JambaAttention.forward
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = self.is_causal and causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_values + + +JAMBA_ATTENTION_CLASSES = { + "eager": JambaAttention, + "flash_attention_2": JambaFlashAttention2, + "sdpa": JambaSdpaAttention, +} + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class JambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
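+    Concretely, a sketch of the recurrence implemented below (zero-order-hold discretization, ⊙ elementwise):
+        A_bar_t = exp(Δ_t · A),   B_bar_t = Δ_t · B_t
+        h_t = A_bar_t ⊙ h_{t-1} + B_bar_t · x_t,   y_t = C_t · h_t + D · x_t
+    with the output finally gated as y_t * act(z_t) (SiLU by default).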
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: JambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.intermediate_size, + padding=self.conv_kernel_size - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.use_fast_kernels = config.use_mamba_kernels + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + batch_size, seq_len, _ = hidden_states.shape + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + ) + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + # We can't use `mamba_inner_fn` even if in training and without cache params because we have the + # inner layernorms which isn't supported by this fused kernel + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if use_precomputed_states: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + time_step = self.dt_layernorm(time_step) + B = self.b_layernorm(B) + C = self.c_layernorm(C) + + # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed + # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized + # linear layers, and requires to call the forward pass directly. + # Quantized model can't work with the original code: + # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)``` + time_proj_bias = self.dt_proj.bias.data + with torch.no_grad(): + self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data) + discrete_time_step = self.dt_proj(time_step).transpose(1, 2) + with torch.no_grad(): + self.dt_proj.bias.data = time_proj_bias + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None + if use_precomputed_states: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + + return contextualized_states + + # fmt: off + def slow_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache] = None, attention_mask: Optional[torch.LongTensor] = None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) + # 2. Convolution sequence transformation + if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: + if self.training: + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = cache_params.ssm_states[self.layer_idx] + + ssm_state = ssm_state.to(hidden_states.device) + + if cache_params.has_previous_state and seq_len == 1 and \ + cache_params.conv_states[self.layer_idx].shape[0] == batch_size: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + else: + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + time_step = self.dt_layernorm(time_step) + B = self.b_layernorm(B) + C = self.c_layernorm(C) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + + # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
+        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()        # [batch, intermediate_size, seq_len, ssm_state_size]
+        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+        scan_outputs = []
+        for i in range(seq_len):
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
+            scan_outputs.append(scan_output[:, :, 0])
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
+        scan_output = scan_output + (hidden_states * self.D[None, :, None])
+        scan_output = (scan_output * self.act(gate))
+
+        if use_cache:
+            cache_params.ssm_states[self.layer_idx] = ssm_state
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        return contextualized_states
+    # fmt: on
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        if self.use_fast_kernels:
+            if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
+                raise ValueError(
+                    "Fast Mamba kernels are not available. Make sure they are installed and that the mamba module is on a CUDA device."
+                )
+            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, attention_mask)
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
+class JambaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mixtral->Jamba
+class JambaSparseMoeBlock(nn.Module):
+    """
+    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster
+    since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of
+    tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance or (2) sets the
+    capacity factor to the number of experts and thus wastes computation and memory on padding.
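+
+    A hedged sketch of the dispatch trick used in `forward` below: one-hot the top-k expert
+    choices, then recover each expert's token indices with `torch.where`:
+
+    ```python
+    import torch
+
+    selected = torch.tensor([[0, 2], [1, 2], [0, 1]])  # (tokens=3, top_k=2)
+    mask = torch.nn.functional.one_hot(selected, num_classes=4).permute(2, 1, 0)
+    idx, top_x = torch.where(mask[2])                  # which tokens routed to expert 2?
+    print(top_x.tolist(), idx.tolist())                # tokens [0, 1], in top-k slot [1, 1]
+    ```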
+ """ + + def __init__(self, config: JambaConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + + self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.router(hidden_states) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class JambaAttentionDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: JambaConfig, layer_idx: int): + super().__init__() + num_experts = config.layers_num_experts[layer_idx] + self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config) + self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
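+
+        The layer applies the usual pre-norm pattern: RMSNorm, self-attention and a residual
+        add, followed by RMSNorm, the feed-forward block (MoE or plain MLP) and a second
+        residual add.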
+ """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward (experts/MLP) + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + ff_outputs = self.feed_forward(hidden_states) + if isinstance(ff_outputs, tuple): + hidden_states, router_logits = ff_outputs + else: + hidden_states, router_logits = ff_outputs, None + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class JambaMambaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: JambaConfig, layer_idx: int): + super().__init__() + num_experts = config.layers_num_experts[layer_idx] + self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx) + + ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config) + self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
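+
+        The Mamba mixer is order-aware through its recurrence, so `position_ids` is accepted
+        only for interface parity with the attention layers, and `None` takes the place of the
+        attention weights when `output_attentions=True`.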
+ """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_values, + attention_mask=attention_mask, + ) + self_attn_weights = None + + # residual connection after mamba + hidden_states = residual + hidden_states + + # feed-forward (experts/MLP) + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + ff_outputs = self.feed_forward(hidden_states) + if isinstance(ff_outputs, tuple): + hidden_states, router_logits = ff_outputs + else: + hidden_states, router_logits = ff_outputs, None + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_values,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@auto_docstring +class JambaPreTrainedModel(PreTrainedModel): + config: JambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, JambaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, JambaMambaMixer): + A = torch.arange(1, module.ssm_state_size + 1)[None, :] + A = A.expand(module.intermediate_size, -1).contiguous() + module.A_log.data.copy_(torch.log(A)) + module.D.data.fill_(1.0) + + +ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer} + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba +@auto_docstring +class JambaModel(JambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is either a [`JambaAttentionDecoderLayer`] or a [`JambaMambaDecoderLayer`].
+
+    Args:
+        config: JambaConfig
+    """
+
+    def __init__(self, config: JambaConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        decoder_layers = []
+        for i in range(config.num_hidden_layers):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+            decoder_layers.append(layer_class(config, layer_idx=i))
+        self.layers = nn.ModuleList(decoder_layers)
+
+        self._attn_implementation = config._attn_implementation
+        self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeModelOutputWithPast:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
+
+        if use_cache and past_key_values is None:
+            logger.warning_once(
+                "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+                "provided, so no cache will be returned."
+ ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + if layer_outputs[-1] is not None: + # append router logits only of expert layers. Regular MLP layers return `None` as the router logits + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows 
when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba +class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: JambaConfig): + super().__init__(config) + self.model = JambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, JambaForCausalLM + + >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
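+
+        >>> # illustrative output; the exact continuation depends on the checkpoint and generation settings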
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: MoeModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits,
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None:
+                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure it is on the same device
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_router_logits=False,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        **kwargs,
+    ):
+        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
+
+        empty_past_kv = past_key_values is None
+
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want the dummy token in that case.
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or cache_position[-1] >= input_ids.shape[1] # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +class JambaForSequenceClassification(GenericForSequenceClassification, JambaPreTrainedModel): ... + + +__all__ = ["JambaForCausalLM", "JambaForSequenceClassification", "JambaModel", "JambaPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kosmos2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51dbdaf40a1050b8ff624d86ed58cd07fcd9174f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_kosmos2 import *
+    from .modeling_kosmos2 import *
+    from .processing_kosmos2 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/kosmos2/configuration_kosmos2.py b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/configuration_kosmos2.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b26eb171780f5a2f1a022ad991f726773e8bb7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/configuration_kosmos2.py
@@ -0,0 +1,264 @@
+# coding=utf-8
+# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""KOSMOS-2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Kosmos2TextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Kosmos2TextModel`]. It is used to instantiate a
+    KOSMOS-2 text decoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the text decoder of the KOSMOS-2
+    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 65037):
+            Vocabulary size of the Kosmos2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Kosmos2Model`].
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        embed_dim (`int`, *optional*, defaults to 2048):
+            Dimensionality of the layers and the pooler layer.
+        layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer decoder.
+        ffn_dim (`int`, *optional*, defaults to 8192):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
+        attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings and decoder.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Scale embeddings by dividing by sqrt(embed_dim).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Token id used for padding.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            Token id used for beginning of string.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            Token id used for end of string.
+    """
+
+    model_type = "kosmos_2_text_model"
+    base_config_key = "text_config"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_attention_heads": "attention_heads",
+        "hidden_size": "embed_dim",
+        "num_hidden_layers": "layers",
+    }
+
+    def __init__(
+        self,
+        vocab_size=65037,
+        max_position_embeddings=2048,
+        embed_dim=2048,
+        layers=24,
+        ffn_dim=8192,
+        attention_heads=32,
+        activation_function="gelu",
+        dropout=0.1,
+        attention_dropout=0.1,
+        activation_dropout=0.0,
+        layerdrop=0.0,
+        layer_norm_eps=1e-5,
+        init_std=0.02,
+        scale_embedding=True,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.embed_dim = embed_dim
+        self.layers = layers
+        self.ffn_dim = ffn_dim
+        self.attention_heads = attention_heads
+        self.activation_function = activation_function
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.layerdrop = layerdrop
+        self.layer_norm_eps = layer_norm_eps
+        self.init_std = init_std
+        self.scale_embedding = scale_embedding
+        self.use_cache = use_cache
+
+
+class Kosmos2VisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Kosmos2VisionModel`]. It is used to instantiate a
+    KOSMOS-2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the KOSMOS-2
+    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+    """
+
+    model_type = "kosmos_2_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=1024,
+        intermediate_size=4096,
+        num_hidden_layers=24,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=224,
+        patch_size=14,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+
+class Kosmos2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Kosmos2Model`]. It is used to instantiate a
+    KOSMOS-2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the KOSMOS-2
+    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`Kosmos2TextConfig`].
+        vision_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`Kosmos2VisionConfig`].
+        latent_query_num (`int`, *optional*, defaults to 64):
+            The number of latent query tokens that represent the image features used in the text decoder component.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
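+
+    Sub-configs may be given as plain dicts with partial overrides, e.g.
+    `Kosmos2Config(text_config={"layers": 2})` (illustrative value); any field left out keeps
+    its default, since the `__init__` below simply forwards the dict to the sub-config class.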
+ + Example: + + ```python + >>> from transformers import Kosmos2Config, Kosmos2Model + + >>> # Initializing a Kosmos-2 kosmos-2-patch14-224 style configuration + >>> configuration = Kosmos2Config() + + >>> # Initializing a model (with random weights) from the kosmos-2-patch14-224 style configuration + >>> model = Kosmos2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "kosmos-2" + sub_configs = {"text_config": Kosmos2TextConfig, "vision_config": Kosmos2VisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + latent_query_num=64, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `Kosmos2TextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. Initializing the `Kosmos2VisionConfig` with default values.") + + self.text_config = Kosmos2TextConfig(**text_config) + self.vision_config = Kosmos2VisionConfig(**vision_config) + + self.latent_query_num = latent_query_num + + +__all__ = ["Kosmos2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kosmos2/modeling_kosmos2.py b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/modeling_kosmos2.py new file mode 100644 index 0000000000000000000000000000000000000000..76acda9f0de98da05318f49a08c058222a4aaa75 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/modeling_kosmos2.py @@ -0,0 +1,1846 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch KOSMOS-2 model.""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + CausalLMOutputWithCrossAttentions, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.deprecation import deprecate_kwarg +from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig + + +logger = logging.get_logger(__name__) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
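+
+    For example (illustrative), a padding mask `[[1, 1, 0]]` becomes an additive bias that is
+    `0.0` where attention is allowed and `torch.finfo(dtype).min` at the padded position, ready
+    to be added to the raw attention scores.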
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for uni-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+        padding_idx: int
+        past_key_values_length: int
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for text model's outputs that also contains a pooling of the last hidden states.
+    """
+)
+class Kosmos2ModelOutput(ModelOutput):
+    r"""
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+        input) to speed up sequential decoding.
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
+        Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
+    projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+        sequence_length)`.
+
+        Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
+        the weighted average in the self-attention heads.
+    vision_model_output (`BaseModelOutputWithPooling`, *optional*):
+        The output of the [`Kosmos2VisionModel`].
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + projection_attentions: Optional[tuple[torch.FloatTensor]] = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Model output class for `Kosmos2ForConditionalGeneration`. + """ +) +class Kosmos2ForConditionalGenerationModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): + Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. + projection_attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute + the weighted average in the self-attention heads. + vision_model_output (`BaseModelOutputWithPooling`, *optional*): + The output of the [`Kosmos2VisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + projection_attentions: Optional[tuple[torch.FloatTensor]] = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2 +class Kosmos2VisionEmbeddings(nn.Module): + def __init__(self, config: Kosmos2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." 
+ ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Kosmos2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask + else: + self.is_causal = causal_attention_mask is not None + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision +class Kosmos2VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision +class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Kosmos2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Kosmos2VisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Kosmos2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> 
tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Kosmos2Vision
+class Kosmos2VisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+    [`Kosmos2VisionEncoderLayer`].
+
+    Args:
+        config: Kosmos2VisionConfig
+    """
+
+    def __init__(self, config: Kosmos2VisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    @can_return_tuple
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward` +class Kosmos2VisionTransformer(nn.Module): + # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision + def __init__(self, config: Kosmos2VisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Kosmos2VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Kosmos2VisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids` +class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + # Copied from 
transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + position_ids: Optional[torch.Tensor] = None, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + if position_ids is None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids( + input_ids, self.padding_idx, past_key_values_length + ).to(input_ids.device) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + if position_ids is None: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
+
+
+class KosmosTextAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
+    def __init__(
+        self,
+        config,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        add_inner_attn_layernorm: bool = False,
+        bias: bool = True,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.layer_idx = layer_idx
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        # End copy
+        self.inner_attn_ln = None
+        if add_inner_attn_layernorm:
+            self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = encoder_hidden_states is not None
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        is_updated = False
+        if past_key_values is not None:
+            if isinstance(past_key_values, EncoderDecoderCache):
+                is_updated = past_key_values.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_values.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_values.self_attention_cache
+            else:
+                curr_past_key_value = past_key_values
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
+        else:
+            key_states = self.k_proj(current_states)
+            value_states = self.v_proj(current_states)
+            key_states =
key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + if self.inner_attn_ln is not None: + attn_output = self.inner_attn_ln(attn_output) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Kosmos2TextFFN(nn.Module): + def __init__(self, config: Kosmos2TextConfig): + super().__init__() + + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim) + + self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.ffn_layernorm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + return hidden_states + + +class Kosmos2TextBlock(GradientCheckpointingLayer): + def __init__(self, config: Kosmos2TextConfig, layer_idx=None): + super().__init__() + self.embed_dim = config.embed_dim + + self.self_attn = KosmosTextAttention( + config, + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + add_inner_attn_layernorm=True, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + if config.add_cross_attention: + self.encoder_attn = KosmosTextAttention( + config, + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + add_inner_attn_layernorm=False, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.ffn = Kosmos2TextFFN(config) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: 
Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + if not hasattr(self, "encoder_attn"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + hidden_states = self.final_layer_norm(hidden_states) + + # FFN + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + return outputs + + +class Kosmos2TextTransformer(nn.Module): + """ + Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`]. 
+ + Args: + config: Kosmos2TextConfig + """ + + def __init__(self, config: Kosmos2TextConfig): + super().__init__() + self.config = config + self.dropout = config.dropout + self.layerdrop = config.layerdrop + + self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id) + + self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding( + num_positions=config.max_position_embeddings, + embedding_dim=config.embed_dim, + padding_idx=config.pad_token_id, + ) + + self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)]) + self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + + self.gradient_checkpointing = False + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward_embedding( + self, + input_ids, + inputs_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + img_input_mask: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + position_ids: Optional[torch.Tensor] = None, + ): + # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`. 
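+        # Illustrative example of the masked scatter below (a sketch, assuming 3 image
+        # tokens in a 6-token sequence): an `img_input_mask` row of
+        #     [0, 1, 1, 1, 0, 0]
+        # selects the three placeholder positions, and the boolean-mask assignment writes
+        # the flattened rows of `image_embeds` into exactly those slots before the whole
+        # sequence is scaled by `self.embed_scale`.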
+ if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if image_embeds is not None: + inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view( + -1, image_embeds.size(-1) + ) + + inputs_embeds = inputs_embeds * self.embed_scale + + # embed positions + positions = self.embed_positions( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + position_ids=position_ids, + ) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + return hidden_states + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # We don't need img info. 
when `past_key_values_length` > 0
+        if past_key_values_length > 0:
+            image_embeds = None
+            image_embeds_position_mask = None
+
+        hidden_states = self.forward_embedding(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            image_embeds=image_embeds,
+            img_input_mask=image_embeds_position_mask,
+            past_key_values_length=past_key_values_length,
+            position_ids=position_ids,
+        )
+
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, input_shape, hidden_states, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            # use `hidden_states.dtype`: `inputs_embeds` may be `None` when only `input_ids` is passed
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1])
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add final layer norm
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@auto_docstring
+class Kosmos2PreTrainedModel(PreTrainedModel):
+    config: Kosmos2Config
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
+    _supports_attention_backend = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights"""
+        if isinstance(self, Kosmos2VisionModel):
+            factor = self.config.initializer_factor
+        elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
+            factor = self.config.vision_config.initializer_factor
+
+        if
isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)): + std = self.config.init_std + elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)): + std = self.config.text_config.init_std + + if isinstance(module, Kosmos2VisionEmbeddings): + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, Kosmos2VisionAttention): + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, Kosmos2VisionMLP): + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, KosmosTextAttention): + nn.init.normal_(module.q_proj.weight, std=std) + nn.init.normal_(module.k_proj.weight, std=std) + nn.init.normal_(module.v_proj.weight, std=std) + nn.init.normal_(module.out_proj.weight, std=std) + elif isinstance(module, Kosmos2TextFFN): + nn.init.normal_(module.fc1.weight, std=std) + nn.init.normal_(module.fc2.weight, std=std) + elif isinstance(module, Kosmos2TextForCausalLM): + nn.init.normal_(module.lm_head.weight, std=std) + elif isinstance(module, Kosmos2ImageToTextProjection): + nn.init.normal_(module.dense.weight, std=std) + nn.init.normal_(module.latent_query) + elif isinstance(module, Kosmos2TextTransformer): + module.embed_tokens.weight.data.normal_(mean=0.0, std=std) + if module.embed_tokens.padding_idx is not None: + module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class Kosmos2VisionModel(Kosmos2PreTrainedModel): + config: Kosmos2VisionConfig + main_input_name = "pixel_values" + + # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model + def __init__(self, config: Kosmos2VisionConfig): + super().__init__(config) + self.model = Kosmos2VisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model + def get_input_embeddings(self) -> nn.Module: + return self.model.embeddings.patch_embedding + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + return self.model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + 
interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+
+class Kosmos2TextModel(Kosmos2PreTrainedModel):
+    config: Kosmos2TextConfig
+
+    def __init__(self, config: Kosmos2TextConfig):
+        super().__init__(config)
+        self.model = Kosmos2TextTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.embed_tokens
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
+        image_embeds_position_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        r"""
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
+        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to indicate the location in a sequence to insert the image features. Mask values selected in `[0,
+            1]`:
+
+            - 1 for places where to put the image features,
+            - 0 for places that are not for image features (i.e. for text tokens).
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        """
+        return self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            image_embeds=image_embeds,
+            image_embeds_position_mask=image_embeds_position_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+    """
+)
+class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
+    config: Kosmos2TextConfig
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: Kosmos2TextConfig):
+        super().__init__(config)
+
+        self.model = Kosmos2TextTransformer(config)
+        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.embed_tokens
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.lm_head
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
+        image_embeds_position_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
+        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to indicate the location in a sequence to insert the image features. Mask values selected in `[0,
+            1]`:
+
+            - 1 for places where to put the image features,
+            - 0 for places that are not for image features (i.e. for text tokens).
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + image_embeds=None, + image_embeds_position_mask=None, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + cache_position=None, + **model_kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + if cache_position[0] != 0: + image_embeds = None + image_embeds_position_mask = None + + # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) + elif image_embeds_position_mask is not None: + batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size() + mask_len = image_embeds_position_mask.size()[-1] + image_embeds_position_mask = torch.cat( + ( + image_embeds_position_mask, + torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device), + ), + dim=1, + ) + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **model_kwargs, + ) + # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer + model_inputs.pop("position_ids", None) + + return model_inputs + + +class Kosmos2ImageToTextProjection(nn.Module): + """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" + + def __init__(self, config: Kosmos2Config): + super().__init__() + self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim) + self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim)) + + self.x_attn = KosmosTextAttention( + 
config.text_config, + config.text_config.embed_dim, + config.text_config.attention_heads, + dropout=config.text_config.attention_dropout, + is_decoder=False, + add_inner_attn_layernorm=False, + ) + + def forward(self, features): + hidden_states = self.dense(features) + + # shape = [batch, latent_query_num, h_dim] + latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) + key_value_states = torch.cat([hidden_states, latent_query], dim=1) + + hidden_states, attn_weights = self.x_attn( + hidden_states=latent_query, + encoder_hidden_states=key_value_states, + past_key_values=None, + attention_mask=None, + output_attentions=None, + ) + + return hidden_states, attn_weights + + +@auto_docstring( + custom_intro=""" + KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model. + """ +) +class Kosmos2Model(Kosmos2PreTrainedModel): + config: Kosmos2Config + main_input_name = "pixel_values" + + def __init__(self, config: Kosmos2Config): + super().__init__(config) + + self.text_model = Kosmos2TextModel(config.text_config) + self.vision_model = Kosmos2VisionModel(config.vision_config) + self.image_to_text_projection = Kosmos2ImageToTextProjection(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.model.embed_tokens + + def set_input_embeddings(self, value): + self.text_model.model.embed_tokens = value + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + return_attentions: Optional[bool] = False, + interpolate_pos_encoding: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + return_attentions (`bool`, *optional*, defaults to `False`): + Whether to return `projection_attentions` or not. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate positional embeddings or not. + """ + vision_model_output = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. 
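+        # Shape walk-through (illustrative sketch, assuming the released
+        # kosmos-2-patch14-224 checkpoint: 224x224 images, patch size 14, vision hidden
+        # size 1024, text embed dim 2048, latent_query_num 64):
+        #     vision_model_output[0] -> (batch, 257, 1024)  # CLS token + 16 * 16 patches
+        #     after the projection   -> (batch, 64, 2048)   # one row per latent query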
+        image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
+        # normalized features
+        image_embeds = nn.functional.normalize(image_embeds, dim=-1)
+        image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
+
+        if return_attentions:
+            return image_embeds, projection_attentions
+        return image_embeds
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        image_embeds_position_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        image_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, Kosmos2ModelOutput]:
+        r"""
+        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to indicate the location in a sequence to insert the image features. Mask values selected in `[0,
+            1]`:
+
+            - 1 for places where to put the image features,
+            - 0 for places that are not for image features (i.e. for text tokens).
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
+
+        Examples:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Kosmos2Model
+
+        >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
+        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
+
+        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> text = (
+        ...     "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863>"
+        ...     "</object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911>"
+        ...     "</object>"
+        ... )
+
+        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True)
+
+        >>> last_hidden_state = model(
+        ...     pixel_values=inputs["pixel_values"],
+        ...     input_ids=inputs["input_ids"],
+        ...     attention_mask=inputs["attention_mask"],
+        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
+        ...
).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 91, 2048] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_model_output = None + projection_attentions = None + if image_embeds is None: + if pixel_values is None: + raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") + image_embeds, projection_attentions = self.get_image_features( + pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding + ) + + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + **kwargs, + ) + + return Kosmos2ModelOutput( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_embeds=image_embeds, + projection_attentions=projection_attentions, + vision_model_output=vision_model_output, + ) + + +@auto_docstring( + custom_intro=""" + KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a + language model. + """ +) +class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): + config: Kosmos2Config + main_input_name = "pixel_values" + _tied_weights_keys = ["text_model.lm_head.weight"] + + def __init__(self, config: Kosmos2Config): + super().__init__(config) + + self.text_model = Kosmos2TextForCausalLM(config.text_config) + self.vision_model = Kosmos2VisionModel(config.vision_config) + + self.image_to_text_projection = Kosmos2ImageToTextProjection(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.model.embed_tokens + + def set_input_embeddings(self, value): + self.text_model.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Module: + return self.text_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.text_model.set_output_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + input_ids: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + image_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Kosmos2ForConditionalGenerationModelOutput]: + r""" + image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the location in a sequence to 
insert the image features. Mask values selected in `[0,
+            1]`:
+
+            - 1 for places where to put the image features,
+            - 0 for places that are not for image features (i.e. for text tokens).
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Examples:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
+
+        >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
+        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
+
+        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> prompt = "<grounding> An image of"
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> generated_ids = model.generate(
+        ...     pixel_values=inputs["pixel_values"],
+        ...     input_ids=inputs["input_ids"],
+        ...     attention_mask=inputs["attention_mask"],
+        ...     image_embeds=None,
+        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
+        ...     use_cache=True,
+        ...     max_new_tokens=64,
+        ... )
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
+        >>> processed_text
+        '<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.'
+
+        >>> caption, entities = processor.post_process_generation(generated_text)
+        >>> caption
+        'An image of a snowman warming himself by a fire.'
+
+        >>> entities
+        [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        vision_model_output = None
+        projection_attentions = None
+        if image_embeds is None:
+            if pixel_values is None:
+                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
+
+            vision_model_output = self.vision_model(
+                pixel_values=pixel_values,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
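+            # Unlike CLIP-style contrastive heads, which keep only the pooled CLS vector,
+            # the projector consumes every patch token, which is why the full
+            # `last_hidden_state` is layer-normalized here rather than `pooled_output`.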
+        image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
+        # normalized features
+        image_embeds = nn.functional.normalize(image_embeds, dim=-1)
+        image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
+
+        lm_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            image_embeds=image_embeds,
+            image_embeds_position_mask=image_embeds_position_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
+        )
+
+        return Kosmos2ForConditionalGenerationModelOutput(
+            loss=lm_outputs.loss,
+            logits=lm_outputs.logits,
+            past_key_values=lm_outputs.past_key_values,
+            hidden_states=lm_outputs.hidden_states,
+            attentions=lm_outputs.attentions,
+            image_embeds=image_embeds,
+            projection_attentions=projection_attentions,
+            vision_model_output=vision_model_output,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_embeds_position_mask: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        # in order to allow `inputs` argument (as in `GenerationMixin`)
+        inputs = kwargs.pop("inputs", None)
+        if pixel_values is not None and inputs is not None:
+            raise ValueError(
+                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed. "
+                f"Make sure to either pass `inputs` or pixel_values=..."
+            )
+        if pixel_values is None and inputs is not None:
+            pixel_values = inputs
+
+        if image_embeds is None:
+            vision_model_output = self.vision_model(pixel_values)
+            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
+            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
+            # normalized features
+            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
+            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
+
+        output = self.text_model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            image_embeds=image_embeds,
+            image_embeds_position_mask=image_embeds_position_mask,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        return output
+
+
+__all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/kosmos2/processing_kosmos2.py b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/processing_kosmos2.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b3dff1e07a545973f53d03b49f284ca5a3cee2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/kosmos2/processing_kosmos2.py
@@ -0,0 +1,696 @@
+# coding=utf-8
+# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for KOSMOS-2."""
+
+import copy
+import math
+import re
+from typing import Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_base import BatchEncoding, TextInput
+
+
+BboxInput = Union[
+    list[tuple[int, int]],
+    list[tuple[float, float, float, float]],
+    list[list[tuple[int, int]]],
+    list[list[tuple[float, float, float, float]]],
+]
+
+
+class Kosmos2ImagesKwargs(ImagesKwargs, total=False):
+    bboxes: Optional[list[float]]
+    num_image_tokens: Optional[int]
+    first_image_token_id: Optional[int]
+
+
+class Kosmos2TextKwargs(TextKwargs, total=False):
+    add_eos_token: Optional[bool]
+
+
+class Kosmos2ProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: Kosmos2TextKwargs
+    images_kwargs: Kosmos2ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "verbose": True,
+            "add_eos_token": False,
+        },
+        "images_kwargs": {
+            "num_image_tokens": 64,
+        },
+    }
+
+
+class Kosmos2Processor(ProcessorMixin):
+    r"""
+    Constructs a KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single
+    processor.
+
+    [`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and some functionalities of
+    [`XLMRobertaTokenizerFast`]. See the docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`]
+    for more information.
+
+    Args:
+        image_processor (`CLIPImageProcessor`):
+            An instance of [`CLIPImageProcessor`]. The image processor is a required input.
+        tokenizer (`XLMRobertaTokenizerFast`):
+            An instance of [`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
+        num_patch_index_tokens (`int`, *optional*, defaults to 1024):
+            The number of tokens that represent patch indices.
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast") + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, *kwargs): + tokenizer.return_token_type_ids = False + + self.eod_token = "" + + self.boi_token = "" + self.eoi_token = "" + + self.eoc_token = "" + self.eol_token = "" + + self.bop_token = "" + self.eop_token = "" + + self.boo_token = "" + self.eoo_token = "" + + self.dom_token = "" + + self.grd_token = "" + + self.tag_tokens = [ + self.eod_token, + self.boi_token, + self.eoi_token, + self.eoc_token, + self.eol_token, + self.bop_token, + self.eop_token, + self.boo_token, + self.eoo_token, + self.dom_token, + self.grd_token, + ] + + self.num_patch_index_tokens = num_patch_index_tokens + patch_index_tokens = [f"" for x in range(self.num_patch_index_tokens)] + + tokens_to_add = [] + for token in self.tag_tokens + patch_index_tokens: + tokens_to_add.append(AddedToken(token, lstrip=True, rstrip=False, normalized=False)) + tokenizer.add_tokens(tokens_to_add) + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, list[TextInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Kosmos2ProcessorKwargs], + ) -> BatchFeature: + """ + This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and + [`XLMRobertaTokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. + + The rest of this documentation shows the arguments specific to `Kosmos2Processor`. + + Args: + bboxes (`Union[list[tuple[int]], list[tuple[float]], list[list[tuple[int]]], list[list[tuple[float]]]]`, *optional*): + The bounding bboxes associated to `texts`. + num_image_tokens (`int`, *optional* defaults to 64): + The number of (consecutive) places that are used to mark the placeholders to store image information. + This should be the same as `latent_query_num` in the instance of `Kosmos2Config` you are using. + first_image_token_id (`int`, *optional*): + The token id that will be used for the first place of the subsequence that is reserved to store image + information. If unset, will default to `self.tokenizer.unk_token_id + 1`. + add_eos_token (`bool`, defaults to `False`): + Whether or not to include `EOS` token id in the encoding when `add_special_tokens=True`. 
+ """ + if images is None and text is None: + raise ValueError("You have to specify either images or text.") + + output_kwargs = self._merge_kwargs( + Kosmos2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + bboxes = output_kwargs["images_kwargs"].pop("bboxes", None) + num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 64) + first_image_token_id = output_kwargs["images_kwargs"].pop("first_image_token_id", None) + add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False) + + add_special_tokens = output_kwargs["text_kwargs"]["add_special_tokens"] + padding = output_kwargs["text_kwargs"]["padding"] + return_tensors = output_kwargs["text_kwargs"].setdefault("return_tensors", None) + + encoding = BatchFeature() + + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + encoding.update(image_encoding) + + if text is not None: + text = self.preprocess_examples(text, images, bboxes, num_image_tokens=num_image_tokens) + + if add_special_tokens and not add_eos_token: + if isinstance(text, str): + text = f"{self.tokenizer.bos_token}{text}" + elif isinstance(text, list): + text = [f"{self.tokenizer.bos_token}{s}" for s in text] + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token + ) + output_kwargs["text_kwargs"]["padding"] = padding if images is None else False + output_kwargs["text_kwargs"]["return_tensors"] = return_tensors if images is None else None + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + encoding.update(text_encoding) + + output_kwargs["text_kwargs"]["add_special_tokens"] = add_special_tokens + output_kwargs["text_kwargs"]["padding"] = padding + output_kwargs["text_kwargs"]["return_tensors"] = return_tensors + + if text is not None and images is not None: + # Use the id of the first token after + if first_image_token_id is None: + first_image_token_id = self.tokenizer.unk_token_id + 1 + + # To see if we need one more `0` (for ``) at the beginning of `image_embeds_position_mask`. + with_bos = add_special_tokens + + # The first (actual) `` token is always at the 1st or 2nd place (after `` if any). Here we look + # for the second `` token (which indicate the first image token). + start_index = int(with_bos) + 1 + + # Add `image_embeds_position_mask`: the leading and trailing `0` are for `boi` and `eoi` tokens. The `1` indicates + # the places of image tokens. 
+            image_token_ids = list(range(first_image_token_id, first_image_token_id + num_image_tokens))
+            base_image_embeds_position_mask = [0] + [1] * num_image_tokens + [0]
+
+            # loop over `encoding["input_ids"]`
+            input_ids = []
+            image_embeds_position_mask = []
+            all_input_ids = encoding["input_ids"]
+            # not batched -> (changed to) batch of size 1
+            if isinstance(text, str):
+                all_input_ids = [all_input_ids]
+                encoding["attention_mask"] = [encoding["attention_mask"]]
+            for text_ids in all_input_ids:
+                # change the ids for the fake `<image>` tokens in `input_ids`
+                text_ids = text_ids[:start_index] + image_token_ids + text_ids[start_index + num_image_tokens :]
+                input_ids.append(text_ids)
+
+                mask = copy.copy(base_image_embeds_position_mask)
+                if with_bos:
+                    # for `<s>`
+                    mask = [0] + mask
+                # trailing part (which are not related to the image)
+                mask += [0] * (len(text_ids) - len(mask))
+                image_embeds_position_mask.append(mask)
+
+            if isinstance(text, list):
+                sorted_length = sorted(
+                    [(idx, len(x)) for idx, x in enumerate(text_encoding.input_ids)], key=lambda x: x[-1]
+                )
+                _, min_len_not_padded = sorted_length[0]
+                idx, _ = sorted_length[-1]
+                output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                    output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
+                )
+                output_kwargs["text_kwargs"]["return_tensors"] = None
+
+                text_encoding = self.tokenizer(text=[text[idx]], **output_kwargs["text_kwargs"])
+                max_len_padded = len(text_encoding.input_ids[0])
+
+                if min_len_not_padded != max_len_padded:
+                    if self.tokenizer.padding_side == "right":
+                        input_ids = [x + [self.tokenizer.pad_token_id] * (max_len_padded - len(x)) for x in input_ids]
+                        image_embeds_position_mask = [
+                            x + [0] * (max_len_padded - len(x)) for x in image_embeds_position_mask
+                        ]
+                        encoding["attention_mask"] = [
+                            x + [0] * (max_len_padded - len(x)) for x in encoding["attention_mask"]
+                        ]
+                    elif self.tokenizer.padding_side == "left":
+                        input_ids = [[self.tokenizer.pad_token_id] * (max_len_padded - len(x)) + x for x in input_ids]
+                        image_embeds_position_mask = [
+                            [0] * (max_len_padded - len(x)) + x for x in image_embeds_position_mask
+                        ]
+                        encoding["attention_mask"] = [
+                            [0] * (max_len_padded - len(x)) + x for x in encoding["attention_mask"]
+                        ]
+
+            # un-batch if necessary
+            if isinstance(text, str) and return_tensors is None:
+                input_ids = input_ids[0]
+                encoding["attention_mask"] = encoding["attention_mask"][0]
+                image_embeds_position_mask = image_embeds_position_mask[0]
+
+            # update (with the target tensor type if specified)
+            encoding.update(
+                BatchEncoding(
+                    data={
+                        "input_ids": input_ids,
+                        "attention_mask": encoding["attention_mask"],
+                        "image_embeds_position_mask": image_embeds_position_mask,
+                    },
+                    tensor_type=return_tensors,
+                )
+            )
+
+        return encoding
+
+    def _check_bboxes_for_single_text(self, bboxes):
+        """
+        Check `bboxes` for a single text example. It could be
+            - `None`: no bounding box associated with a text.
+            - A list with each element being the bounding boxes associated with one `<phrase> ... </phrase>` pair found
+              in a text. This could be:
+                  - `None`: no bounding box associated with a `<phrase> ... </phrase>` pair.
+                  - A tuple of 2 integers: A single bounding box specified by patch indices.
+                  - A tuple of 4 floating point numbers: A single bounding box specified by (normalized) coordinates.
+                  - A list containing the above 2 tuple types: Multiple bounding boxes for a
+                    `<phrase> ... </phrase>` pair.
+ """ + if bboxes is None: + return + elif not isinstance(bboxes, list): + raise ValueError("`bboxes` (for a single text example) should be `None` or a list.") + + # `bbox` is the bounding boxes for a single pair + for bbox in bboxes: + if bbox is None: + continue + elif not isinstance(bbox, list): + bbox = [bbox] + for element in bbox: + if not isinstance(element, tuple) or not ( + (len(element) == 2 and all(isinstance(x, int) for x in element)) + or (len(element) == 4 and all(isinstance(x, float) for x in element)) + ): + raise ValueError( + "Each element in `bboxes` (for a single text example) should be either `None`, a tuple containing " + "2 integers or 4 float point numbers, or a list containing such tuples. Also " + "make sure the arguments `texts` and `bboxes` passed to `preprocess_text` are both in " + "batches or both for a single example." + ) + + def _preprocess_single_example(self, text, image, bboxes, img_info_tokens): + text = text.strip() + if image is not None: + # Add ` ... (fake) image tokens ... ` + text = f"{img_info_tokens} {text}" + + # Add ` ` after ` phrase text ` + text = self._insert_patch_index_tokens(text, bboxes) + return text + + def preprocess_examples( + self, + texts: Union[TextInput, list[TextInput]], + images: Optional[ImageInput] = None, + bboxes: BboxInput = None, + num_image_tokens: Optional[int] = 64, + ) -> Union[str, list[str]]: + """Add image and bounding box information to `texts` as image and patch index tokens. + + Args: + texts (`Union[TextInput, list[TextInput]]`): The texts to be processed. + images (`ImageInput`, *optional*): The images associated to `texts`. + bboxes (`Union[list[tuple[int]], list[tuple[float]], list[list[tuple[int]]], list[list[tuple[float]]]]`, *optional*): + The bounding bboxes associated to `texts`. + num_image_tokens (`int`, *optional*, defaults to 64): + The number of image tokens (used as latent queries). This should corresponds to the `latent_query_num` + attribute in `Kosmos2Config`. + + Returns: + `Union[TextInput, list[TextInput]]`: The processed texts with image and patch index tokens. + """ + # These are fake `` tokens enclosed between (the actual) `` token and ``. + img_tokens = [self.boi_token] * num_image_tokens + img_info_tokens = " ".join([self.boi_token] + img_tokens + [self.eoi_token]) + + # make batch to simplify processing logic + batched = True + if isinstance(texts, str): + batched = False + texts = [texts] + + if images is None: + images = [None] * len(texts) + elif not isinstance(images, list): + images = [images] + if len(texts) != len(images): + raise ValueError( + f"The number of examples in `texts` and `images` should be the same. Got {len(texts)} v.s. {len(images)} instead." + ) + + if not batched: + self._check_bboxes_for_single_text(bboxes) + bboxes = [bboxes] + elif bboxes is not None: + if not isinstance(bboxes, list): + raise ValueError("`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.") + for x in bboxes: + self._check_bboxes_for_single_text(x) + else: + bboxes = [None] * len(texts) + + if len(bboxes) != len(texts): + raise ValueError( + f"The number of examples in `texts` and `bboxes` should be the same. Got {len(texts)} v.s. {len(bboxes)} instead." 
+            )
+
+        result = [
+            self._preprocess_single_example(text, image, bbox, img_info_tokens)
+            for text, image, bbox in zip(texts, images, bboxes)
+        ]
+        # un-batch if necessary
+        if not batched:
+            result = result[0]
+
+        return result
+
+    def post_process_generation(self, text, cleanup_and_extract=True):
+        caption = text.split(self.eoi_token)[-1]
+        if cleanup_and_extract:
+            return clean_text_and_extract_entities_with_bboxes(caption)
+        return caption
+
+    def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+            skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+            **kwargs:
+                Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+        Returns:
+            `list[str]`: The decoded text.
+        """
+        generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
+        return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return tokenizer_input_names + image_processor_input_names + ["image_embeds_position_mask"]
+
+    def _insert_patch_index_tokens(self, text: str, bboxes: Union[list[tuple[int]], list[tuple[float]]]) -> str:
+        if bboxes is None or len(bboxes) == 0:
+            return text
+
+        matched_phrases = list(re.finditer(r"<phrase>.+?</phrase>", string=text))
+        if len(matched_phrases) != len(bboxes):
+            raise ValueError(
+                f"The number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got {len(matched_phrases)} vs. {len(bboxes)} instead."
+            )
+
+        # insert object's patch index tokens after the found `<phrase> ... </phrase>` pairs
+        curr_pos = 0
+        buffer = []
+        for matched, bbox in zip(matched_phrases, bboxes):
+            _, end = matched.span()
+            buffer.append(text[curr_pos:end])
+            curr_pos = end
+            # A phrase without bbox
+            if bbox is None:
+                continue
+            # A phrase with a single bbox
+            if isinstance(bbox, tuple):
+                bbox = [bbox]
+            patch_index_strings = []
+            # A phrase could have multiple bboxes
+            if not all(box is not None for box in bbox):
+                raise ValueError(
+                    "The multiple bounding boxes for a single phrase should not contain any `None` value."
+                )
+            for box in bbox:
+                patch_index_1, patch_index_2 = self._convert_bbox_to_patch_index_tokens(box)
+                patch_index_strings.append(f"{patch_index_1} {patch_index_2}")
+            # `bbox` being an empty list
+            if len(patch_index_strings) == 0:
+                continue
+            position_str = " ".join(patch_index_strings)
+            buffer.append(f"<object> {position_str} </object>")
+        # remaining
+        if curr_pos < len(text):
+            buffer.append(text[curr_pos:])
+
+        text = "".join(buffer)
+        return text
+
+    def _convert_bbox_to_patch_index_tokens(
+        self, bbox: Union[tuple[int, int], tuple[float, float, float, float]]
+    ) -> tuple[str, str]:
+        # already computed patch indices
+        if len(bbox) == 2:
+            idx_1, idx_2 = bbox
+        # bbox specified with (normalized) coordinates
+        else:
+            # use `self.tokenizer` to get `num_patches_per_side`
+            num_patches_per_side = int(math.sqrt(self.num_patch_index_tokens))
+            idx_1, idx_2 = coordinate_to_patch_index(bbox, num_patches_per_side)
+
+        token_1 = f"<patch_index_{str(idx_1).zfill(4)}>"
+        token_2 = f"<patch_index_{str(idx_2).zfill(4)}>"
+
+        return token_1, token_2
+
+
+def coordinate_to_patch_index(bbox: tuple[float, float, float, float], num_patches_per_side: int) -> tuple[int, int]:
+    """Convert a bounding box to a pair of patch indices.
+
+    Args:
+        bbox (`tuple[float, float, float, float]`):
+            The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left and
+            lower-right corners of the box. It should have x2 > x1 and y2 > y1.
+        num_patches_per_side (`int`): the number of patches along each side.
+
+    Returns:
+        `tuple[int, int]`: A pair of patch indices representing the upper-left patch and lower-right patch.
+    """
+    (x1, y1, x2, y2) = bbox
+
+    if not (x2 > x1 and y2 > y1):
+        raise ValueError("The coordinates in `bbox` should be `(x1, y1, x2, y2)` with `x2 > x1` and `y2 > y1`.")
+
+    ul_x = math.floor(x1 * num_patches_per_side)
+    ul_y = math.floor(y1 * num_patches_per_side)
+
+    lr_x = math.ceil(x2 * num_patches_per_side - 1)
+    lr_y = math.ceil(y2 * num_patches_per_side - 1)
+
+    ul_idx = ul_y * num_patches_per_side + ul_x
+    lr_idx = lr_y * num_patches_per_side + lr_x
+
+    return ul_idx, lr_idx
+
+
+# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
+# (with format modifications)
+def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
+    """
+    Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
+    bounding box, returns the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
+
+    Args:
+        ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box.
+        lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box.
+        num_patches_per_side (`int`): the number of patches along each side.
+
+    Returns:
+        `tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
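+
+    Example (an illustrative round trip with `coordinate_to_patch_index` above; the values follow from the formulas,
+    and the coordinates are only recovered up to the resolution of the patch grid):
+
+    ```python
+    >>> coordinate_to_patch_index((0.25, 0.25, 0.75, 0.75), num_patches_per_side=32)
+    (264, 759)
+    >>> patch_index_to_coordinate(264, 759, num_patches_per_side=32)
+    (0.265625, 0.265625, 0.734375, 0.734375)
+    ```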
+ """ + # Compute the size of each cell in the grid + cell_size = 1.0 / num_patches_per_side + + # Compute the x and y indices of the upper-left and lower-right corners of the bounding box + ul_x = ul_idx % num_patches_per_side + ul_y = ul_idx // num_patches_per_side + + lr_x = lr_idx % num_patches_per_side + lr_y = lr_idx // num_patches_per_side + + # Compute the normalized coordinates of the bounding box + if ul_idx == lr_idx: + x1 = ul_x * cell_size + y1 = ul_y * cell_size + x2 = lr_x * cell_size + cell_size + y2 = lr_y * cell_size + cell_size + elif ul_x == lr_x or ul_y == lr_y: + x1 = ul_x * cell_size + y1 = ul_y * cell_size + x2 = lr_x * cell_size + cell_size + y2 = lr_y * cell_size + cell_size + else: + x1 = ul_x * cell_size + cell_size / 2 + y1 = ul_y * cell_size + cell_size / 2 + x2 = lr_x * cell_size + cell_size / 2 + y2 = lr_y * cell_size + cell_size / 2 + + return x1, y1, x2, y2 + + +# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33 +# (with format modifications) +def extract_entities_with_patch_indices(text): + """Extract entities contained in `text`. The bounding bboxes is given in the form of patch indices. + + This function is only intended to be used within `clean_text_and_extract_entities_with_bboxes` where further + processing happens, including converting to normalized coordinates and whitespace character cleaning up. + + Examples: + + ```python + >>> text = " An image of a snowman warming himself by a fire." + >>> entities = extract_entities_with_patch_indices(text) + >>> entities + [(' a snowman', (31, 41), [(44, 863)]), (' a fire', (130, 137), [(5, 911)])] + ```""" + # The regular expression pattern for matching the required formats + pattern = r"(?:(([^<]+)))?((?:)*)" + + # Find all matches in the given string + matches = re.finditer(pattern, text) + + # Initialize an empty list to store the valid patch_index combinations + entities_with_patch_indices = [] + + for match in matches: + # span of a `phrase` that is between and + span = match.span(2) + phrase_tag, phrase, match_content = match.groups() + if not phrase_tag: + phrase = None + # We take the starting position of `` + span = (match.span(0)[0], match.span(0)[0]) + + # Split the match_content by the delimiter to get individual patch_index pairs + patch_index_pairs = match_content.split("") + + entity_bboxes = [] + for pair in patch_index_pairs: + # Extract the xxxx and yyyy values from the patch_index pair + x = re.search(r"", pair) + y = re.search(r"", pair[1:]) + + if x and y: + if phrase: + entity_bboxes.append((int(x.group(1)), int(y.group(1)))) + else: + entity_bboxes.append((int(x.group(1)), int(y.group(1)))) + + if phrase: + entities_with_patch_indices.append((phrase, span, entity_bboxes)) + else: + for bbox in entity_bboxes: + # fake entity name + entity = f"" + entities_with_patch_indices.append((entity, span, [bbox])) + + return entities_with_patch_indices + + +def adjust_entity_positions(entity, text): + """Adjust the positions of the entities in `text` to be relative to the text with special fields removed.""" + entity_name, (start, end) = entity + # computed the length of strings with special fields (tag tokens, patch index tokens, etc.) 
+    adjusted_start = len(re.sub("<.*?>", "", text[:start]))
+    adjusted_end = len(re.sub("<.*?>", "", text[:end]))
+    adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
+    return adjusted_entity
+
+
+def _cleanup_spaces(text, entities):
+    """Remove the spaces around the text and the entities in it."""
+    new_text = text.strip()
+    leading_spaces = len(text) - len(text.lstrip())
+
+    new_entities = []
+    for entity_name, (start, end), bboxes in entities:
+        entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
+        entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())
+
+        start = start - leading_spaces + entity_name_leading_spaces
+        end = end - leading_spaces - entity_name_trailing_spaces
+        entity_name = entity_name.strip()
+
+        new_entities.append((entity_name, (start, end), bboxes))
+
+    return new_text, new_entities
+
+
+# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
+# (with format modifications)
+def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
+    """Remove the tag tokens from `text`, extract entities in it with some cleaning up of whitespace characters.
+
+    Examples:
+
+    ```python
+    >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
+    >>> clean_text, entities = clean_text_and_extract_entities_with_bboxes(text)
+    >>> clean_text
+    'An image of a snowman warming himself by a fire.'
+
+    >>> entities
+    [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
+    ```"""
+    # remove special fields (tag tokens, patch index tokens, etc.)
+    processed_text = re.sub("<.*?>", "", text)
+
+    entities_with_patch_indices = extract_entities_with_patch_indices(text)
+    entities = []
+    for item in entities_with_patch_indices:
+        entity, bboxes = item[0:2], item[2]
+        adjusted_entity = adjust_entity_positions(entity, text)
+        bboxes_in_coords = [patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side) for bbox in bboxes]
+
+        entities.append(adjusted_entity + (bboxes_in_coords,))
+
+    return _cleanup_spaces(processed_text, entities)
+
+
+__all__ = ["Kosmos2Processor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5823883c6cb8ba35c0964d62d4f32e5ca3d24d64
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
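+
+# At runtime the submodules are loaded lazily: the star imports under `TYPE_CHECKING` only inform type checkers,
+# while `_LazyModule` replaces this module in `sys.modules` and resolves attributes on first access.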
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_kyutai_speech_to_text import *
+    from .feature_extraction_kyutai_speech_to_text import *
+    from .modeling_kyutai_speech_to_text import *
+    from .processing_kyutai_speech_to_text import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..8693fd66679d6e6a747eaf30627d65d8f48a0ca7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`KyutaiSpeechToTextForConditionalGeneration`].
+    It is used to instantiate a Kyutai Speech-to-Text model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    2.6b-en model.
+
+    e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        codebook_vocab_size (`int`, *optional*, defaults to 2049):
+            Vocabulary size of the codebook. Defines the number of different audio tokens that can be represented by each codebook.
+        vocab_size (`int`, *optional*, defaults to 4001):
+            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling the model.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the layers and the pooler layer of the main decoder.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of decoder layers.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the main decoder block.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        max_position_embeddings (`int`, *optional*, defaults to 750):
+            The maximum sequence length that this model might ever be used with. Typically, set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        rope_theta (`float`, *optional*, defaults to 100000.0):
+            The base period of the RoPE embeddings.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        sliding_window (`int`, *optional*, defaults to 375):
+            Sliding window attention window size.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        ffn_dim (`int`, *optional*, defaults to 11264):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-08):
+            The epsilon used by the rms normalization layers.
+        num_codebooks (`int`, *optional*, defaults to 32):
+            The number of audio codebooks for each audio channel.
+        audio_bos_token_id (`int`, *optional*, defaults to 2048):
+            Beginning of stream token id for codebook tokens.
+        audio_pad_token_id (`int`, *optional*, defaults to 69569):
+            Padding token id for codebook tokens.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        pad_token_id (`int`, *optional*, defaults to 3):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 48000):
+            Beginning of stream token id for text tokens.
+        codec_config (`PretrainedConfig`, *optional*):
+            Configuration for the codec.
+        kwargs (*optional*):
+            Dictionary of keyword arguments. Notably:
+                - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                  defines the audio encoder config.
+                - **depth_decoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                  defines the depth decoder config.
+ + + Example: + ```python + >>> from transformers import KyutaiSpeechToTextConfig, KyutaiSpeechToTextForConditionalGeneration + + >>> # Initializing a KyutaiSpeechToTextConfig + >>> configuration = KyutaiSpeechToTextConfig() + + >>> # Initializing a model + >>> model = KyutaiSpeechToTextForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "kyutai_speech_to_text" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"codec_config": AutoConfig} + + def __init__( + self, + codebook_vocab_size=2049, + vocab_size=4001, + hidden_size=2048, + num_hidden_layers=48, + num_attention_heads=32, + num_key_value_heads=None, + max_position_embeddings=750, + rope_theta=100000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=375, + attention_dropout=0.0, + ffn_dim=11264, + rms_norm_eps=1e-8, + num_codebooks=32, + audio_bos_token_id=2048, + audio_pad_token_id=69569, + tie_word_embeddings=False, + pad_token_id=3, + bos_token_id=48000, + codec_config=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, bos_token_id=bos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs + ) + + if codec_config is None: + self.codec_config = AutoConfig.for_model("mimi") + logger.info("codec_config is None, using default audio encoder config.") + elif isinstance(codec_config, dict): + self.codec_config = AutoConfig.for_model(**codec_config) + elif isinstance(codec_config, PretrainedConfig): + self.codec_config = codec_config + + self.num_codebooks = num_codebooks + self.frame_size = self.codec_config.frame_size + + self.audio_bos_token_id = audio_bos_token_id + self.audio_pad_token_id = audio_pad_token_id + self.codebook_vocab_size = codebook_vocab_size + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.sliding_window = sliding_window + + +__all__ = ["KyutaiSpeechToTextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..d076ccb1de787c425c54f8e9baadb99d4c58e6e6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py @@ -0,0 +1,237 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the
+# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs a KyutaiSpeechToText feature extractor.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 1):
+            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+        sampling_rate (`int`, *optional*, defaults to 24000):
+            The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
+        padding_value (`float`, *optional*, defaults to 0.0):
+            The value that is used to fill the padding values.
+        chunk_length_s (`float`, *optional*):
+            If defined, the audio is pre-processed into chunks of length `chunk_length_s` and then encoded.
+        overlap (`float`, *optional*):
+            Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
+            formula: `int((1.0 - self.overlap) * self.chunk_length)`.
+        audio_delay_seconds (`float`, *optional*, defaults to 0.0):
+            The delay in seconds to add after the audio (right padding).
+        audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
+            The silence prefix in seconds to add before the audio (left padding).
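+
+    Example (a minimal usage sketch; the shapes shown assume the default settings above, under which one extra
+    second of right padding is always added after the audio):
+
+    ```python
+    >>> import numpy as np
+
+    >>> from transformers import KyutaiSpeechToTextFeatureExtractor
+
+    >>> feature_extractor = KyutaiSpeechToTextFeatureExtractor()
+    >>> audio = np.zeros(24000, dtype=np.float32)  # 1 second of silence at 24 kHz
+    >>> inputs = feature_extractor(audio, sampling_rate=24000, return_tensors="np")
+    >>> sorted(inputs.keys())
+    ['input_values', 'padding_mask']
+    >>> inputs["input_values"].shape  # (batch, channels, num_samples), incl. 1 s of delay padding
+    (1, 1, 48000)
+    ```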
+ """ + + model_input_names = ["input_values", "padding_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 24000, + padding_value: float = 0.0, + chunk_length_s: Optional[float] = None, + overlap: Optional[float] = None, + audio_delay_seconds: Optional[float] = 0.0, + audio_silence_prefix_seconds: Optional[float] = 0.0, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.chunk_length_s = chunk_length_s + self.overlap = overlap + self.audio_delay_seconds = audio_delay_seconds + self.audio_silence_prefix_seconds = audio_silence_prefix_seconds + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_length(self) -> Optional[int]: + if self.chunk_length_s is None: + return None + else: + return int(self.chunk_length_s * self.sampling_rate) + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_stride(self) -> Optional[int]: + if self.chunk_length_s is None or self.overlap is None: + return None + else: + return max(1, int((1.0 - self.overlap) * self.chunk_length)) + + def __call__( + self, + raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. 
+ """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.shape[-1] != 2: + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + padded_inputs = None + input_values = BatchFeature({"input_values": raw_audio}) + if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: + if truncation: + max_length = min(array.shape[0] for array in raw_audio) + nb_step = int(np.floor(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + elif padding: + max_length = max(array.shape[0] for array in raw_audio) + nb_step = int(np.ceil(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + padding = "max_length" + else: + padded_inputs = input_values + + # normal padding on batch + if padded_inputs is None: + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=padding, + ) + + if padding: + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + # now let's pad left and right + pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate) + pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate) + padded_inputs["input_values"] = np.pad( + padded_inputs["input_values"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0.0, + ) + if padding: + padded_inputs["padding_mask"] = np.pad( + padded_inputs["padding_mask"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0, + ) + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return 
padded_inputs + + +__all__ = ["KyutaiSpeechToTextFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..77c636570d588f8388698d684d2525d1af58293e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -0,0 +1,1373 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import types +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationConfig, GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel +from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + # Ignore copy + def forward(self, x): + output = self._norm(x.float()) + output = output * self.weight.float() + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class 
KyutaiSpeechToTextFlexibleLinear(nn.Module):
+    def __init__(self, input_size, output_size, num_layers):
+        super().__init__()
+        # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size)
+        self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size))
+
+    def forward(self, x, layer_idx=None):
+        """
+        `KyutaiSpeechToTextFlexibleLinear` creates one linear layer per codebook. There are multiple ways to use it.
+        In the default case, `sequence_length=num_layers`, so each element of the sequence is multiplied by the
+        weights corresponding to its index in the sequence.
+
+        For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`.
+        If `layer_idx` is a single integer, every element of the sequence is multiplied by that single codebook's layer.
+        If `layer_idx` is a tensor of shape `(seq_length,)`, the i-th element of the input sequence is multiplied by
+        the corresponding layer `weight[layer_idx[i]]`.
+
+        Args:
+            x (`torch.FloatTensor`): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)`
+            layer_idx (`torch.Tensor`, *optional*):
+                Can be used to specify which codebook's layer(s) to use. If it's a tensor of shape `(seq_length,)`,
+                the i-th element of the input sequence is multiplied by the corresponding layer `weight[layer_idx[i]]`.
+        """
+
+        # Use torch.index_select to pick the weights for each position
+        # (codebooks, output_size, hidden_size)
+        selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight
+
+        # (1, codebooks, hidden_size, output_size)
+        selected_weights = selected_weights.transpose(1, 2)[None, :, :, :]
+
+        # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size)
+        # -> (batch_size, codebooks, 1, output_size)
+        x = torch.matmul(x[:, :, None, :], selected_weights)
+
+        # (batch_size, codebooks, output_size)
+        return x.squeeze(2)
+
+
+@auto_docstring
+class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel):
+    config: KyutaiSpeechToTextConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    main_input_name = "input_ids"
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, KyutaiSpeechToTextFlexibleLinear):
+            module.weight.data.normal_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, KyutaiSpeechToTextRMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+class KyutaiSpeechToTextConv1dPaddingCache:
+    """
+    Padding cache for KyutaiSpeechToTextConv1d causal convolutions in order to support streaming via cache padding.
+    See: https://huggingface.co/papers/2005.06720 & https://huggingface.co/papers/2204.07064
+
+    A padding cache is a list of cached partial hidden states for each convolution layer.
+    Hidden states are cached from the previous call to the KyutaiSpeechToTextConv1d forward pass, given the padding size.
+ """ + + def __init__( + self, + num_layers: int, + per_layer_padding: list[int], + per_layer_padding_mode: list[str], + per_layer_in_channels: list[int], + ): + # ensure correct number of layers for each arg + from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)} + + if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers: + raise ValueError( + f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`" + ) + elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode): + raise NotImplementedError( + "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode" + ) + + self.per_layer_padding = per_layer_padding + self.per_layer_padding_mode = per_layer_padding_mode + self.per_layer_in_channels = per_layer_in_channels + self.per_layer_is_init = [True] * num_layers + + self.padding_cache = [None] * num_layers + + def update(self, hidden_states: torch.Tensor, layer_idx: int): + """ + Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache. + + Parameters: + hidden_states (`torch.Tensor`): + The hidden states to be partially cached. + layer_idx (`int`): + The index of the layer to cache the states for. + Returns: + `torch.Tensor` or `None`, the current padding cache. + """ + batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device + padding = self.per_layer_padding[layer_idx] + padding_mode = self.per_layer_padding_mode[layer_idx] + in_channels = self.per_layer_in_channels[layer_idx] + + if self.padding_cache[layer_idx] is None: + if padding_mode == "constant": + current_cache = torch.zeros( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + elif padding_mode == "replicate": + current_cache = ( + torch.ones( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + * hidden_states[..., :1] + ) + else: + current_cache = self.padding_cache[layer_idx] + + # update the cache + if padding > 0: + padding_states = hidden_states[:, :, -padding:] + else: + padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device) + self.padding_cache[layer_idx] = padding_states + + return current_cache + + +class KyutaiSpeechToTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_tokens = nn.Embedding( + config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1, + config.hidden_size, + padding_idx=config.audio_pad_token_id, + ) + audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size + audio_tokens_offsets += config.vocab_size + audio_tokens_offsets = nn.functional.pad( + audio_tokens_offsets, (1, 0) + ) # pad one 0 to the left for the text token + self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False) + + def forward(self, input_ids): + input_ids = torch.where( + input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets + ) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.sum(dim=2) + return inputs_embeds + + +class KyutaiSpeechToTextLinear(nn.Module): + def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False): + super().__init__() + + self.use_flexible_linear = use_flexible_linear + + if not use_flexible_linear: + self.linear = 
nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = KyutaiSpeechToTextFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks) + + def forward(self, x, layer_idx=None): + if self.use_flexible_linear: + return self.linear(x, layer_idx) + else: + return self.linear(x) + + +class KyutaiSpeechToTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: KyutaiSpeechToTextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class KyutaiSpeechToTextGatingMLP(nn.Module): + def __init__(self, config, use_flexible_linear=False): + super().__init__() + + self.activation_fn = ACT2FN[config.hidden_act] + ffn_dim = config.ffn_dim + hidden_size = config.hidden_size + num_layers = config.num_codebooks if use_flexible_linear else 1 + if num_layers == 1: + self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) + else: + self.fc1 = KyutaiSpeechToTextFlexibleLinear(hidden_size, ffn_dim, num_layers) + self.fc2 = KyutaiSpeechToTextFlexibleLinear(ffn_dim // 2, hidden_size, num_layers) + + def forward(self, hidden_states: torch.Tensor, layer_idx: Optional[int] = None) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx) + + batch_size, sequence_length, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1) + hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :] + hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. 
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class KyutaiSpeechToTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: KyutaiSpeechToTextConfig, + layer_idx: Optional[int] = None, + use_flexible_linear=False, + use_rope=True, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.scaling = 1 / math.sqrt(self.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.k_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.v_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.o_proj = KyutaiSpeechToTextLinear( + self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear + ) + + # rotary embeddings are not used in the depth decoder + self.rotary_emb = None + if use_rope: + self.rope_theta = config.rope_theta + self.rotary_emb = KyutaiSpeechToTextRotaryEmbedding(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + 
if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->KyutaiSpeechToText +# TODO cyril: modular +class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention): + """ + KyutaiSpeechToText flash attention module. This module inherits from `KyutaiSpeechToTextAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if isinstance(past_key_values, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, 
sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (KyutaiSpeechToTextRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->KyutaiSpeechToText +# TODO cyril: modular +class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention): + """ + KyutaiSpeechToText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `KyutaiSpeechToTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from KyutaiSpeechToTextAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
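[Editor's note] The float32 fallback cascade in the flash-attention forward above is easier to follow in isolation. Below is a minimal sketch of the same resolution order; the function name and the `fallback` parameter are illustrative stand-ins (the real code reads `config._pre_quantization_dtype` or the `q_proj` weight dtype), not a transformers API:

```python
import torch

def resolve_flash_attn_dtype(query: torch.Tensor, fallback: torch.dtype = torch.float16) -> torch.dtype:
    # Nothing to do unless the inputs were silently upcast to float32
    # (e.g. by PEFT casting layer norms to float32 for training stability).
    if query.dtype != torch.float32:
        return query.dtype
    # mps has no autocast dtype query of its own on older torch builds.
    device_type = query.device.type if query.device.type != "mps" else "cpu"
    if torch.is_autocast_enabled():
        # torch.get_autocast_dtype is the modern API; older versions only
        # expose the GPU-specific variant.
        if hasattr(torch, "get_autocast_dtype"):
            return torch.get_autocast_dtype(device_type)
        return torch.get_autocast_gpu_dtype()
    # Stand-in for `config._pre_quantization_dtype` / the projection weight dtype.
    return fallback
```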
+ logger.warning_once( + "KyutaiSpeechToTextModel is using KyutaiSpeechToTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
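[Editor's note] `repeat_kv` is called in the SDPA forward above but defined elsewhere in the file. For reference, here is a self-contained version matching the expansion transformers uses for grouped-query attention, duplicating each key/value head `n_rep` times so K/V line up with the query heads:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    # -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand inserts a broadcast dimension without copying; reshape then
    # materializes the repeated heads next to their source head.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```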
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + return attn_output, None + + +KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = { + "eager": KyutaiSpeechToTextAttention, + "flash_attention_2": KyutaiSpeechToTextFlashAttention2, + "sdpa": KyutaiSpeechToTextSdpaAttention, +} + + +class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: KyutaiSpeechToTextConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True): + super().__init__() + self.hidden_size = config.hidden_size + self.use_flexible_linear = use_flexible_linear + + self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope + ) + + self.mlp = KyutaiSpeechToTextGatingMLP(config, use_flexible_linear) + self.input_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + self._attn_implementation = config._attn_implementation + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = ( + self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position) + ) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = KyutaiSpeechToTextEmbeddings(config) + self.layers = nn.ModuleList( + [ + KyutaiSpeechToTextDecoderLayer(config, layer_idx, use_flexible_linear=False) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = KyutaiSpeechToTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of KyutaiSpeechToText. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
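[Editor's note] As a concrete illustration of the `cache_position` bookkeeping in the model forward above, with made-up sizes: positions for the tokens of the current step continue from the number of tokens already sitting in the KV cache, and `position_ids` are derived from them:

```python
import torch

# Toy numbers: 7 prompt tokens already sit in the KV cache and a single
# new token is decoded this step.
past_seen_tokens, new_tokens = 7, 1

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # add the batch dimension

print(cache_position)  # tensor([7])
print(position_ids)    # tensor([[7]])
```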
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # StaticCache + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config: KyutaiSpeechToTextConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`KyutaiSpeechToTextConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
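[Editor's note] Before the full mask construction below (which additionally handles static caches and sliding windows), a toy version of the same expansion may help: a 2D padding mask becomes a 4D additive mask whose blocked positions hold the dtype minimum:

```python
import torch

def toy_4d_causal_mask(padding_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # padding_mask: (batch, seq) with 1 = real token, 0 = padding.
    bsz, seq = padding_mask.shape
    min_dtype = torch.finfo(dtype).min
    # Future positions (strict upper triangle) get the dtype minimum.
    causal = torch.full((seq, seq), min_dtype, dtype=dtype).triu(diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, seq, seq)
    # Padded key positions are masked out for every query position.
    causal = causal.masked_fill(padding_mask[:, None, None, :] == 0, min_dtype)
    return causal

mask = toy_4d_causal_mask(torch.tensor([[1, 1, 1], [0, 1, 1]]))
print(mask.shape)  # torch.Size([2, 1, 3, 3])
```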
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + is_static_sliding_cache = isinstance(past_key_values, StaticCache) and all(past_key_values.is_sliding) + if not is_static_sliding_cache or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@auto_docstring +class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _keep_in_fp32_modules_strict = ["codec_model"] + + def __init__(self, config): + super().__init__(config) + self.model = KyutaiSpeechToTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.codec_model = AutoModel.from_config(config.codec_config) + + # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None + # yet the codec_model needs a generation config to initialize it's cache for streaming inference + # we therefore initialize a generation config for the codec model + self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import torch + >>> from datasets import load_dataset, Audio + >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model_id = "kyutai/stt-2.6b-en-trfs" + + >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) + >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + + >>> ds = load_dataset( + ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ... ) + + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + >>> inputs = processor( + ... ds[0]["audio"]["array"], + ... ) + >>> inputs.to(torch_device) + + >>> output_tokens = model.generate(**inputs) + >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True)) + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _prepare_generation_config(self, *args, **kwargs): + generation_config, model_kwargs = super()._prepare_generation_config(*args, **kwargs) + # this should be passed to the model kwargs for the input preparation + model_kwargs["audio_window_size"] = ( + generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None + ) + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = super()._prepare_model_inputs( + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + audio_window_size = model_kwargs.get("audio_window_size", None) + if audio_window_size is None: + audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item() + model_kwargs["audio_window_size"] = audio_window_size + + batch_size = inputs.shape[0] + device = inputs.device + + # initialize audio tokens + model_kwargs["audio_tokens"] = torch.zeros( + (batch_size, audio_window_size, self.config.num_codebooks), + device=device, + dtype=torch.long, + ) + model_kwargs["current_window"] = ( + torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous() + ) + + # let's use generate's cache preparation to prepare the cache for 
the codec model + temporary_model_kwargs = {} + + # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin + # Add cache-related methods from GenerationMixin to codec model + cache_methods = [ + "_prepare_cache_for_generation", + "_get_cache", + ] + for method in cache_methods: + setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) + + setattr( + self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) + ) + + self.codec_model.generation_config.cache_implementation = "dynamic" + self.codec_model._prepare_cache_for_generation( + generation_config=self.codec_model.generation_config, + model_kwargs=temporary_model_kwargs, + generation_mode=None, + batch_size=batch_size, + max_cache_length=self.config.codec_config.sliding_window, + ) + + if "past_key_values" in temporary_model_kwargs: + model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"] + + # initialize the padding cache for the codec model + per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], [] + for layer_name in self.codec_model.encoder._mimiconv1d_layer_names: + per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total) + per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode) + per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels) + + # downsample layer + per_layer_padding.append(self.codec_model.downsample.padding_total) + per_layer_padding_mode.append(self.codec_model.downsample.pad_mode) + per_layer_in_channels.append(self.codec_model.downsample.in_channels) + + model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache( + num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1, + per_layer_padding=per_layer_padding, + per_layer_padding_mode=per_layer_padding_mode, + per_layer_in_channels=per_layer_in_channels, + ) + + return inputs, input_name, model_kwargs + + def prepare_inputs_for_generation( + self, + *args, + audio_tokens: Optional[torch.LongTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + audio_window_size: Optional[int] = None, + current_window: Optional[tuple[int, int]] = None, + encoder_past_key_values: Optional[Cache] = None, + padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + if input_values is not None: + cache_position = model_inputs["cache_position"] + start, end = current_window[0] + + # first cache position is for bos token, so we need to offset by -1 + if cache_position[-1] - 1 >= end: + # we need to encode the new audio tokens + with torch.no_grad(): + input_values_start_idx = start * self.config.frame_size + input_values_end_idx = (start + audio_window_size) * self.config.frame_size + current_input_values = input_values[..., input_values_start_idx:input_values_end_idx] + codec_model_output = self.codec_model.encode( + current_input_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + ) + new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2) + + audio_tokens.copy_(new_audio_tokens) + + start = end.clone() + end = end + audio_window_size + current_window.copy_( + torch.tensor([start, end], 
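+                    # `start`/`end` count codec frames (each spanning `config.frame_size` waveform samples); each re-encode advances the window by `audio_window_size` frames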
device=current_window.device).expand(current_window.shape[0], -1) + ) + + # first cache position is for bos token, so we need to offset by -1 + current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0) + current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :] + + current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id + + input_ids = model_inputs.pop("input_ids") + input_ids = torch.cat( + [input_ids.unsqueeze(2), current_audio_tokens], + dim=2, + ) + model_inputs["input_ids"] = input_ids + + return model_inputs + + # TODO: @eustlb, this should be standardized + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = super().from_pretrained(*args, **kwargs) + else: + model = super().from_pretrained(*args, **kwargs) + + # copy the `codec_`-prefixed generation config attributes over to the codec model's generation config + prefix = "codec_" + prefix_len = len(prefix) + codec_model_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs}) + + # remove the `codec_`-prefixed attributes from the model generation config + for attr in codec_model_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + # TODO: @eustlb, this should be standardized + def save_pretrained(self, *args, **kwargs): + prefix = "codec_" + codec_model_attrs = self.codec_model.generation_config.to_diff_dict() + codec_model_attrs.pop("transformers_version", None) + for attr, value in codec_model_attrs.items(): + setattr(self.generation_config, prefix + attr, value) + + super().save_pretrained(*args, **kwargs) + + def generate(self, *args, **kwargs): + r""" + This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information. + """ + max_new_tokens = kwargs.pop("max_new_tokens", None) + input_values = kwargs.get("input_values") + + # TODO: @eustlb, we should have per-batch-idx values + # here we do not use padding_mask to be aligned with what's done in the original codebase + max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size + + if max_new_tokens is None or max_new_tokens > max_audio_frames: + if max_new_tokens is not None: + logger.warning( + f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames}). " + f"Setting `max_new_tokens` to {max_audio_frames}." + ) + max_new_tokens = max_audio_frames + + return super().generate( + *args, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + +__all__ = [ + "KyutaiSpeechToTextPreTrainedModel", + "KyutaiSpeechToTextModel", + "KyutaiSpeechToTextForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..d3707d659e1e23b1b81cd4dc8c5f5fc6eb3377fe --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -0,0 +1,513 @@ +# coding=utf-8 +# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...cache_utils import Cache +from ...feature_extraction_utils import BatchFeature +from ...generation import GenerationConfig, GenerationMixin +from ...modeling_utils import PreTrainedModel +from ...utils import PaddingStrategy, TensorType, logging +from ..auto import AutoModel +from ..encodec.feature_extraction_encodec import EncodecFeatureExtractor +from ..llama.modeling_llama import LlamaForCausalLM +from ..mimi.modeling_mimi import MimiConv1dPaddingCache +from ..moshi.modeling_moshi import MoshiModel, MoshiPreTrainedModel + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor): + r""" + Constructs an KyutaiSpeechToText feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + chunk_length_s (`float`, *optional*): + If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. + overlap (`float`, *optional*): + Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following + formulae : `int((1.0 - self.overlap) * self.chunk_length)`. + audio_delay_seconds (`float`, *optional*, defaults to 0.0): + The delay in seconds to add after the audio (right padding). + audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0): + The silence prefix in seconds to add before the audio (left padding). + """ + + def __init__( + self, + audio_delay_seconds: Optional[float] = 0.0, + audio_silence_prefix_seconds: Optional[float] = 0.0, + **super_kwargs, + ): + super().__init__(**super_kwargs) + self.audio_delay_seconds = audio_delay_seconds + self.audio_silence_prefix_seconds = audio_silence_prefix_seconds + + def __call__( + self, + raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be processed. 
Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. 
Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.shape[-1] != 2: + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + padded_inputs = None + input_values = BatchFeature({"input_values": raw_audio}) + if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: + if truncation: + max_length = min(array.shape[0] for array in raw_audio) + nb_step = int(np.floor(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + elif padding: + max_length = max(array.shape[0] for array in raw_audio) + nb_step = int(np.ceil(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + padding = "max_length" + else: + padded_inputs = input_values + + # normal padding on batch + if padded_inputs is None: + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=padding, + ) + + if padding: + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + # now let's pad left and right + pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate) + pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate) + padded_inputs["input_values"] = np.pad( + padded_inputs["input_values"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0.0, + ) + if padding: + padded_inputs["padding_mask"] = np.pad( + padded_inputs["padding_mask"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0, + ) + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel): + pass + + +class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache): + pass + + +class KyutaiSpeechToTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_tokens = nn.Embedding( + config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1, + config.hidden_size, + padding_idx=config.audio_pad_token_id, + ) + audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size + audio_tokens_offsets += config.vocab_size + audio_tokens_offsets = nn.functional.pad( + 
audio_tokens_offsets, (1, 0) + ) # pad one 0 to the left for the text token + self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False) + + def forward(self, input_ids): + input_ids = torch.where( + input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets + ) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.sum(dim=2) + return inputs_embeds + + +class KyutaiSpeechToTextModel(MoshiModel): + def __init__(self, config): + super().__init__(config) + self.embed_tokens = KyutaiSpeechToTextEmbeddings(config) + + +class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin): + _keep_in_fp32_modules_strict = ["codec_model"] + + def __init__(self, config): + super().__init__(config) + self.codec_model = AutoModel.from_config(config.codec_config) + + # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None + # yet the codec_model needs a generation config to initialize it's cache for streaming inference + # we therefore initialize a generation config for the codec model + self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config) + + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import torch + >>> from datasets import load_dataset, Audio + >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model_id = "kyutai/stt-2.6b-en-trfs" + + >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) + >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + + >>> ds = load_dataset( + ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ... ) + + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + >>> inputs = processor( + ... ds[0]["audio"]["array"], + ... 
) + >>> inputs.to(torch_device) + + >>> output_tokens = model.generate(**inputs) + >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True)) + ```""" + super().forward(**super_kwargs) + + def _prepare_generation_config(self, *args, **kwargs): + generation_config, model_kwargs = GenerationMixin._prepare_generation_config(self, *args, **kwargs) + # this should be passed to the model kwargs for the input preparation + model_kwargs["audio_window_size"] = ( + generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None + ) + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs( + self, + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + audio_window_size = model_kwargs.get("audio_window_size", None) + if audio_window_size is None: + audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item() + model_kwargs["audio_window_size"] = audio_window_size + + batch_size = inputs.shape[0] + device = inputs.device + + # initialize audio tokens + model_kwargs["audio_tokens"] = torch.zeros( + (batch_size, audio_window_size, self.config.num_codebooks), + device=device, + dtype=torch.long, + ) + model_kwargs["current_window"] = ( + torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous() + ) + + # let's use generate's cache preparation to prepare the cache for the codec model + temporary_model_kwargs = {} + + # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin + # Add cache-related methods from GenerationMixin to codec model + cache_methods = [ + "_prepare_cache_for_generation", + "_get_cache", + ] + for method in cache_methods: + setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) + + setattr( + self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) + ) + + self.codec_model.generation_config.cache_implementation = "dynamic" + self.codec_model._prepare_cache_for_generation( + generation_config=self.codec_model.generation_config, + model_kwargs=temporary_model_kwargs, + generation_mode=None, + batch_size=batch_size, + max_cache_length=self.config.codec_config.sliding_window, + ) + + if "past_key_values" in temporary_model_kwargs: + model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"] + + # initialize the padding cache for the codec model + per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], [] + for layer_name in self.codec_model.encoder._mimiconv1d_layer_names: + per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total) + per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode) + per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels) + + # downsample layer + per_layer_padding.append(self.codec_model.downsample.padding_total) + per_layer_padding_mode.append(self.codec_model.downsample.pad_mode) + per_layer_in_channels.append(self.codec_model.downsample.in_channels) + + model_kwargs["padding_cache"] = 
KyutaiSpeechToTextConv1dPaddingCache( + num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1, + per_layer_padding=per_layer_padding, + per_layer_padding_mode=per_layer_padding_mode, + per_layer_in_channels=per_layer_in_channels, + ) + + return inputs, input_name, model_kwargs + + def prepare_inputs_for_generation( + self, + *args, + audio_tokens: Optional[torch.LongTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + audio_window_size: Optional[int] = None, + current_window: Optional[tuple[int, int]] = None, + encoder_past_key_values: Optional[Cache] = None, + padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None, + **kwargs, + ): + model_inputs = GenerationMixin.prepare_inputs_for_generation(self, *args, **kwargs) + + if input_values is not None: + cache_position = model_inputs["cache_position"] + start, end = current_window[0] + + # first cache position is for bos token, so we need to offset by -1 + if cache_position[-1] - 1 >= end: + # we need to encode the new audio tokens + with torch.no_grad(): + input_values_start_idx = start * self.config.frame_size + input_values_end_idx = (start + audio_window_size) * self.config.frame_size + current_input_values = input_values[..., input_values_start_idx:input_values_end_idx] + codec_model_output = self.codec_model.encode( + current_input_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + ) + new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2) + + audio_tokens.copy_(new_audio_tokens) + + start = end.clone() + end = end + audio_window_size + current_window.copy_( + torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1) + ) + + # first cache position is for bos token, so we need to offset by -1 + current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0) + current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :] + + current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id + + input_ids = model_inputs.pop("input_ids") + input_ids = torch.cat( + [input_ids.unsqueeze(2), current_audio_tokens], + dim=2, + ) + model_inputs["input_ids"] = input_ids + + return model_inputs + + # TODO: @eustlb, this should be standardized + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = PreTrainedModel.from_pretrained(*args, **kwargs) + else: + model = PreTrainedModel.from_pretrained(*args, **kwargs) + + # copy depth decoder generation conf attr to the depth decoder generation config + prefix = "codec_" + prefix_len = len(prefix) + codec_model_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs}) + + # remove the depth decoder generation conf attr from the model generation config + for attr in codec_model_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + # TODO: @eustlb, this should be standardized + def save_pretrained(self, *args, **kwargs): + prefix = "codec_" + codec_model_attrs = self.codec_model.generation_config.to_diff_dict() + codec_model_attrs.pop("transformers_version", None) + for attr, value in codec_model_attrs.items(): + setattr(self.generation_config, prefix + attr, 
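+            # e.g. the codec's `cache_implementation` is persisted as `codec_cache_implementation` on the main generation config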
value) + + PreTrainedModel.save_pretrained(self, *args, **kwargs) + + def generate(self, *args, **kwargs): + r""" + This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information. + """ + max_new_tokens = kwargs.pop("max_new_tokens", None) + input_values = kwargs.get("input_values") + + # TODO: @eustlb, we should have per-batch-idx values + # here we do not use padding_mask to be aligned with what's done in the original codebase + max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size + + if max_new_tokens is None or max_new_tokens > max_audio_frames: + if max_new_tokens is not None: + logger.warning( + f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames}). " + f"Setting `max_new_tokens` to {max_audio_frames}." + ) + max_new_tokens = max_audio_frames + + return GenerationMixin.generate( + self, + *args, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + +__all__ = [ + "KyutaiSpeechToTextPreTrainedModel", + "KyutaiSpeechToTextModel", + "KyutaiSpeechToTextForConditionalGeneration", + "KyutaiSpeechToTextFeatureExtractor", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..91a588857377431e605bf6b45041e8d7e7795c6c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...processing_utils import ProcessingKwargs, ProcessorMixin + + +class KyutaiSpeechToTextProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "audio_kwargs": { + "sampling_rate": 24000, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class KyutaiSpeechToTextProcessor(ProcessorMixin): + r""" + Constructs a Kyutai speech-to-text (Moshi-based ASR) processor which wraps [`KyutaiSpeechToTextFeatureExtractor`] and + [`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and + tokenizer functionalities. See the [`~KyutaiSpeechToTextProcessor.__call__`] for more + information.
+ """ + + feature_extractor_class = "KyutaiSpeechToTextFeatureExtractor" + tokenizer_class = "PreTrainedTokenizerFast" + valid_processor_kwargs = KyutaiSpeechToTextProcessorKwargs + + +__all__ = ["KyutaiSpeechToTextProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b68a523c0b0c362d9930f5bee492cea73f3937f0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_layoutlmv2 import * + from .feature_extraction_layoutlmv2 import * + from .image_processing_layoutlmv2 import * + from .image_processing_layoutlmv2_fast import * + from .modeling_layoutlmv2 import * + from .processing_layoutlmv2 import * + from .tokenization_layoutlmv2 import * + from .tokenization_layoutlmv2_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..b729ddbb1d429f3cfe464f303ec55ae391d498c3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LayoutLMv2 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import is_detectron2_available, logging + + +logger = logging.get_logger(__name__) + + +# soft dependency +if is_detectron2_available(): + import detectron2 + + +class LayoutLMv2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LayoutLMv2Model`]. It is used to instantiate an + LayoutLMv2 model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the LayoutLMv2 + [microsoft/layoutlmv2-base-uncased](https://huggingface.co/microsoft/layoutlmv2-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the LayoutLMv2 model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`LayoutLMv2Model`] or [`TFLayoutLMv2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv2Model`] or + [`TFLayoutLMv2Model`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + max_2d_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum value that the 2D position embedding might ever be used with. Typically set this to something + large just in case (e.g., 1024). + max_rel_pos (`int`, *optional*, defaults to 128): + The maximum number of relative positions to be used in the self-attention mechanism. + rel_pos_bins (`int`, *optional*, defaults to 32): + The number of relative position bins to be used in the self-attention mechanism. + fast_qkv (`bool`, *optional*, defaults to `True`): + Whether or not to use a single matrix for the queries, keys, values in the self-attention layers. + max_rel_2d_pos (`int`, *optional*, defaults to 256): + The maximum number of relative 2D positions in the self-attention mechanism. + rel_2d_pos_bins (`int`, *optional*, defaults to 64): + The number of 2D relative position bins in the self-attention mechanism. + image_feature_pool_shape (`list[int]`, *optional*, defaults to [7, 7, 256]): + The shape of the average-pooled feature map. + coordinate_size (`int`, *optional*, defaults to 128): + Dimension of the coordinate embeddings. 
+ shape_size (`int`, *optional*, defaults to 128): + Dimension of the width and height embeddings. + has_relative_attention_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use a relative attention bias in the self-attention mechanism. + has_spatial_attention_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use a spatial attention bias in the self-attention mechanism. + has_visual_segment_embedding (`bool`, *optional*, defaults to `False`): + Whether or not to add visual segment embeddings. + detectron2_config_args (`dict`, *optional*): + Dictionary containing the configuration arguments of the Detectron2 visual backbone. Refer to [this + file](https://github.com/microsoft/unilm/blob/master/layoutlmft/layoutlmft/models/layoutlmv2/detectron2_config.py) + for details regarding default values. + + Example: + + ```python + >>> from transformers import LayoutLMv2Config, LayoutLMv2Model + + >>> # Initializing a LayoutLMv2 microsoft/layoutlmv2-base-uncased style configuration + >>> configuration = LayoutLMv2Config() + + >>> # Initializing a model (with random weights) from the microsoft/layoutlmv2-base-uncased style configuration + >>> model = LayoutLMv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "layoutlmv2" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + max_2d_position_embeddings=1024, + max_rel_pos=128, + rel_pos_bins=32, + fast_qkv=True, + max_rel_2d_pos=256, + rel_2d_pos_bins=64, + convert_sync_batchnorm=True, + image_feature_pool_shape=[7, 7, 256], + coordinate_size=128, + shape_size=128, + has_relative_attention_bias=True, + has_spatial_attention_bias=True, + has_visual_segment_embedding=False, + detectron2_config_args=None, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + initializer_range=initializer_range, + layer_norm_eps=layer_norm_eps, + pad_token_id=pad_token_id, + **kwargs, + ) + self.max_2d_position_embeddings = max_2d_position_embeddings + self.max_rel_pos = max_rel_pos + self.rel_pos_bins = rel_pos_bins + self.fast_qkv = fast_qkv + self.max_rel_2d_pos = max_rel_2d_pos + self.rel_2d_pos_bins = rel_2d_pos_bins + self.convert_sync_batchnorm = convert_sync_batchnorm + self.image_feature_pool_shape = image_feature_pool_shape + self.coordinate_size = coordinate_size + self.shape_size = shape_size + self.has_relative_attention_bias = has_relative_attention_bias + self.has_spatial_attention_bias = has_spatial_attention_bias + self.has_visual_segment_embedding = has_visual_segment_embedding + self.detectron2_config_args = ( + detectron2_config_args if detectron2_config_args is not None else self.get_default_detectron2_config() + ) + + @classmethod + def get_default_detectron2_config(cls): + return { + "MODEL.MASK_ON": True, + "MODEL.PIXEL_STD": [57.375, 57.120, 58.395], + "MODEL.BACKBONE.NAME": 
"build_resnet_fpn_backbone", + "MODEL.FPN.IN_FEATURES": ["res2", "res3", "res4", "res5"], + "MODEL.ANCHOR_GENERATOR.SIZES": [[32], [64], [128], [256], [512]], + "MODEL.RPN.IN_FEATURES": ["p2", "p3", "p4", "p5", "p6"], + "MODEL.RPN.PRE_NMS_TOPK_TRAIN": 2000, + "MODEL.RPN.PRE_NMS_TOPK_TEST": 1000, + "MODEL.RPN.POST_NMS_TOPK_TRAIN": 1000, + "MODEL.POST_NMS_TOPK_TEST": 1000, + "MODEL.ROI_HEADS.NAME": "StandardROIHeads", + "MODEL.ROI_HEADS.NUM_CLASSES": 5, + "MODEL.ROI_HEADS.IN_FEATURES": ["p2", "p3", "p4", "p5"], + "MODEL.ROI_BOX_HEAD.NAME": "FastRCNNConvFCHead", + "MODEL.ROI_BOX_HEAD.NUM_FC": 2, + "MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION": 14, + "MODEL.ROI_MASK_HEAD.NAME": "MaskRCNNConvUpsampleHead", + "MODEL.ROI_MASK_HEAD.NUM_CONV": 4, + "MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION": 7, + "MODEL.RESNETS.DEPTH": 101, + "MODEL.RESNETS.SIZES": [[32], [64], [128], [256], [512]], + "MODEL.RESNETS.ASPECT_RATIOS": [[0.5, 1.0, 2.0]], + "MODEL.RESNETS.OUT_FEATURES": ["res2", "res3", "res4", "res5"], + "MODEL.RESNETS.NUM_GROUPS": 32, + "MODEL.RESNETS.WIDTH_PER_GROUP": 8, + "MODEL.RESNETS.STRIDE_IN_1X1": False, + } + + def get_detectron2_config(self): + detectron2_config = detectron2.config.get_cfg() + for k, v in self.detectron2_config_args.items(): + attributes = k.split(".") + to_set = detectron2_config + for attribute in attributes[:-1]: + to_set = getattr(to_set, attribute) + setattr(to_set, attributes[-1], v) + + return detectron2_config + + +__all__ = ["LayoutLMv2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..8c70e1ed643101401373ae29637d4d597873485e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for LayoutLMv2. +""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_layoutlmv2 import LayoutLMv2ImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class LayoutLMv2FeatureExtractor(LayoutLMv2ImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class LayoutLMv2FeatureExtractor is deprecated and will be removed in version 5 of Transformers." 
+ " Please use LayoutLMv2ImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["LayoutLMv2FeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..de2e7361a6d3f5abe3a9e4f6c0d864db62837717 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for LayoutLMv2.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_pytesseract_available, + is_vision_available, + logging, + requires_backends, +) +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + +# soft dependency +if is_pytesseract_available(): + import pytesseract + +logger = logging.get_logger(__name__) + + +def normalize_box(box, width, height): + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def apply_tesseract( + image: np.ndarray, + lang: Optional[str], + tesseract_config: Optional[str] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes.""" + tesseract_config = tesseract_config if tesseract_config is not None else "" + + # apply OCR + pil_image = to_pil_image(image, input_data_format=input_data_format) + image_width, image_height = pil_image.size + data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config) + words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] + + # filter empty words and corresponding coordinates + irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()] + words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] + left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices] + top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices] + width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices] + height = [coord for idx, coord in enumerate(height) if idx 
not in irrelevant_indices] + + # turn coordinates into (left, top, left+width, top+height) format + actual_boxes = [] + for x, y, w, h in zip(left, top, width, height): + actual_box = [x, y, x + w, y + h] + actual_boxes.append(actual_box) + + # finally, normalize the bounding boxes + normalized_boxes = [] + for box in actual_boxes: + normalized_boxes.append(normalize_box(box, image_width, image_height)) + + assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes" + + return words, normalized_boxes + + +@requires(backends=("vision",)) +class LayoutLMv2ImageProcessor(BaseImageProcessor): + r""" + Constructs a LayoutLMv2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be + overridden by `do_resize` in `preprocess`. + size (`dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in `preprocess`. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + apply_ocr (`bool`, *optional*, defaults to `True`): + Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by + `apply_ocr` in `preprocess`. + ocr_lang (`str`, *optional*): + The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is + used. Can be overridden by `ocr_lang` in `preprocess`. + tesseract_config (`str`, *optional*, defaults to `""`): + Any additional custom configuration flags that are forwarded to the `config` parameter when calling + Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + apply_ocr: bool = True, + ocr_lang: Optional[str] = None, + tesseract_config: Optional[str] = "", + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.apply_ocr = apply_ocr + self.ocr_lang = ocr_lang + self.tesseract_config = tesseract_config + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. 
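Editor's note: `apply_tesseract` above post-processes pytesseract's column-oriented output in three steps: drop empty words, convert (left, top, width, height) to corner format, and rescale onto the 0-1000 grid via `normalize_box`. A runnable sketch of the same flow on a hand-made `data` dict, so no OCR backend is needed (all numbers invented):

```python
# Sketch of the post-processing in `apply_tesseract`, run on a fake
# pytesseract.image_to_data(..., output_type="dict") result.
def normalize_box(box, width, height):
    return [int(1000 * box[0] / width), int(1000 * box[1] / height),
            int(1000 * box[2] / width), int(1000 * box[3] / height)]

data = {
    "text": ["Hello", "", "world"],   # empty strings are noise rows
    "left": [10, 0, 120], "top": [20, 0, 22],
    "width": [100, 0, 90], "height": [30, 0, 28],
}
width, height = 800, 600  # pixel size of the page image

keep = [i for i, w in enumerate(data["text"]) if w.strip()]
words = [data["text"][i] for i in keep]
boxes = [
    normalize_box(
        [data["left"][i], data["top"][i],
         data["left"][i] + data["width"][i],   # left + width  -> x1
         data["top"][i] + data["height"][i]],  # top  + height -> y1
        width, height,
    )
    for i in keep
]
print(words, boxes)  # ['Hello', 'world'] [[12, 33, 137, 83], [150, 36, 262, 83]]
```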
Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + apply_ocr: Optional[bool] = None, + ocr_lang: Optional[str] = None, + tesseract_config: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Desired size of the output image after resizing. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PIL.Image` resampling + filter. Only has an effect if `do_resize` is set to `True`. + apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`): + Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. + ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`): + The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is + used. + tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`): + Any additional custom configuration flags that are forwarded to the `config` parameter when calling + Tesseract. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr + ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang + tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if apply_ocr: + requires_backends(self, "pytesseract") + words_batch = [] + boxes_batch = [] + for image in images: + words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format) + words_batch.append(words) + boxes_batch.append(boxes) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + # flip color channels from RGB to BGR (as Detectron2 requires this) + images = [flip_channel_order(image, input_data_format=input_data_format) for image in images] + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + + if apply_ocr: + data["words"] = words_batch + data["boxes"] = boxes_batch + return data + + +__all__ = ["LayoutLMv2ImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..354bbe21c4dba00fcb1b373d916950e9f643b7ab --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
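Editor's note: one easily missed step in the slow `preprocess` above is that, after resizing, channels are flipped from RGB to BGR because the Detectron2 visual backbone expects BGR input. A tiny numpy sketch of the channels-last case:

```python
# Sketch of the RGB -> BGR flip done at the end of `preprocess`
# (what flip_channel_order does for a channels-last array).
import numpy as np

rgb = np.zeros((224, 224, 3), dtype=np.uint8)
rgb[..., 0] = 255            # a pure-red image
bgr = rgb[..., ::-1]         # reverse the channel axis
assert bgr[..., 2].max() == 255 and bgr[..., 0].max() == 0
```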
+"""Fast Image processor class for LayoutLMv2.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs +from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, + requires_backends, +) +from .image_processing_layoutlmv2 import apply_tesseract + + +logger = logging.get_logger(__name__) + + +class LayoutLMv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + Args: + apply_ocr (`bool`, *optional*, defaults to `True`): + Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by + the `apply_ocr` parameter in the `preprocess` method. + ocr_lang (`str`, *optional*): + The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is + used. Can be overridden by the `ocr_lang` parameter in the `preprocess` method. + tesseract_config (`str`, *optional*): + Any additional custom configuration flags that are forwarded to the `config` parameter when calling + Tesseract. For example: '--psm 6'. Can be overridden by the `tesseract_config` parameter in the + `preprocess` method. + """ + + apply_ocr: Optional[bool] + ocr_lang: Optional[str] + tesseract_config: Optional[str] + + +@auto_docstring +class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + size = {"height": 224, "width": 224} + rescale_factor = None + do_resize = True + apply_ocr = True + ocr_lang = None + tesseract_config = "" + valid_kwargs = LayoutLMv2FastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[LayoutLMv2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[LayoutLMv2FastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + apply_ocr: bool, + ocr_lang: Optional[str], + tesseract_config: Optional[str], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Tesseract OCR to get words + normalized bounding boxes + if apply_ocr: + requires_backends(self, "pytesseract") + words_batch = [] + boxes_batch = [] + for image in images: + if image.is_cuda: + logger.warning_once( + "apply_ocr can only be performed on cpu. Tensors will be transferred to cpu before processing." 
+ ) + words, boxes = apply_tesseract( + image.cpu(), ocr_lang, tesseract_config, input_data_format=ChannelDimension.FIRST + ) + words_batch.append(words) + boxes_batch.append(boxes) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # flip color channels from RGB to BGR (as Detectron2 requires this) + stacked_images = stacked_images.flip(1) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + if apply_ocr: + data["words"] = words_batch + data["boxes"] = boxes_batch + + return data + + +__all__ = ["LayoutLMv2ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..3f444fbb6b281b42477d730881e719586cfc3522 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -0,0 +1,1394 @@ +# coding=utf-8 +# Copyright 2021 Microsoft Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
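Editor's note: the fast processor's `_preprocess` above batches work by grouping same-shaped images, running each stacked group through resize/flip once, then restoring the original order. A simplified sketch of that bookkeeping (an illustrative helper, not the actual `group_images_by_shape`/`reorder_images` API):

```python
# Sketch of the group-by-shape trick: stack same-sized images so each
# transform runs once per shape, then restore the caller's order.
import torch

def process_grouped(images, fn):
    groups, index = {}, []
    for img in images:
        groups.setdefault(img.shape, []).append(img)
        index.append((img.shape, len(groups[img.shape]) - 1))
    processed = {shape: fn(torch.stack(batch)) for shape, batch in groups.items()}
    return [processed[shape][j] for shape, j in index]

imgs = [torch.rand(3, 4, 4), torch.rand(3, 8, 8), torch.rand(3, 4, 4)]
out = process_grouped(imgs, lambda batch: batch.flip(1))  # RGB -> BGR, like stacked_images.flip(1)
print([o.shape for o in out])  # original order and shapes preserved
```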
+"""PyTorch LayoutLMv2 model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward +from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends +from .configuration_layoutlmv2 import LayoutLMv2Config + + +# soft dependency +if is_detectron2_available(): + import detectron2 + from detectron2.modeling import META_ARCH_REGISTRY + + # This is needed as otherwise their overload will break sequential loading by overwriting buffer over and over. See + # https://github.com/facebookresearch/detectron2/blob/9604f5995cc628619f0e4fd913453b4d7d61db3f/detectron2/layers/batch_norm.py#L83-L86 + detectron2.layers.batch_norm.FrozenBatchNorm2d._load_from_state_dict = torch.nn.Module._load_from_state_dict + +logger = logging.get_logger(__name__) + + +class LayoutLMv2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) + self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) + self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) + self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def _calc_spatial_position_embeddings(self, bbox): + try: + left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) + upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) + right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) + lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) + except IndexError as e: + raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e + + h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1]) + w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0]) + + spatial_position_embeddings = torch.cat( + [ + left_position_embeddings, + upper_position_embeddings, + right_position_embeddings, + lower_position_embeddings, + h_position_embeddings, + w_position_embeddings, + ], + dim=-1, + ) + return spatial_position_embeddings + + +class LayoutLMv2SelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a 
multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.fast_qkv = config.fast_qkv + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.has_relative_attention_bias = config.has_relative_attention_bias + self.has_spatial_attention_bias = config.has_spatial_attention_bias + + if config.fast_qkv: + self.qkv_linear = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=False) + self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size)) + self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size)) + else: + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def compute_qkv(self, hidden_states): + if self.fast_qkv: + qkv = self.qkv_linear(hidden_states) + q, k, v = torch.chunk(qkv, 3, dim=-1) + if q.ndimension() == self.q_bias.ndimension(): + q = q + self.q_bias + v = v + self.v_bias + else: + _sz = (1,) * (q.ndimension() - 1) + (-1,) + q = q + self.q_bias.view(*_sz) + v = v + self.v_bias.view(*_sz) + else: + q = self.query(hidden_states) + k = self.key(hidden_states) + v = self.value(hidden_states) + return q, k, v + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + batch_size, seq_length, _ = hidden_states.shape + query, key, value = self.compute_qkv(hidden_states) + + # (B, L, H*D) -> (B, H, L, D) + query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + + query_layer = query_layer / math.sqrt(self.attention_head_size) + # [BSZ, NAT, L, L] + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + if self.has_relative_attention_bias: + attention_scores += rel_pos + if self.has_spatial_attention_bias: + attention_scores += rel_2d_pos + attention_scores = attention_scores.float().masked_fill_( + attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min + ) + attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
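Editor's note: the `fast_qkv` path above computes all three projections with a single bias-free matmul and keeps separate learned biases for queries and values only (keys get none). A minimal sketch of `compute_qkv` in the common case where shapes already line up:

```python
# Sketch of the fast_qkv projection: one Linear to 3*hidden, chunked
# into q/k/v, with learned biases added to q and v only.
import torch
from torch import nn

hidden = 768
qkv_linear = nn.Linear(hidden, 3 * hidden, bias=False)
q_bias = nn.Parameter(torch.zeros(1, 1, hidden))
v_bias = nn.Parameter(torch.zeros(1, 1, hidden))

x = torch.rand(2, 5, hidden)                     # (batch, seq, hidden)
q, k, v = torch.chunk(qkv_linear(x), 3, dim=-1)  # three (2, 5, 768) tensors
q, v = q + q_bias, v + v_bias                    # biases broadcast over batch and seq
print(q.shape, k.shape, v.shape)
```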
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class LayoutLMv2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LayoutLMv2SelfAttention(config) + self.output = LayoutLMv2SelfOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class LayoutLMv2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->LayoutLMv2 +class LayoutLMv2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM +class LayoutLMv2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LayoutLMv2Layer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = LayoutLMv2Attention(config) + self.intermediate = LayoutLMv2Intermediate(config) + self.output = LayoutLMv2Output(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + 
output_attentions=output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small + absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions + >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should + allow for more graceful generalization to longer sequences than the model has been trained on. + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + + ret = 0 + if bidirectional: + num_buckets //= 2 + ret += (relative_position > 0).long() * num_buckets + n = torch.abs(relative_position) + else: + n = torch.max(-relative_position, torch.zeros_like(relative_position)) + # now n is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + +class LayoutLMv2Encoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LayoutLMv2Layer(config) for _ in range(config.num_hidden_layers)]) + + self.has_relative_attention_bias = config.has_relative_attention_bias + self.has_spatial_attention_bias = config.has_spatial_attention_bias + + if self.has_relative_attention_bias: + self.rel_pos_bins = config.rel_pos_bins + self.max_rel_pos = config.max_rel_pos + self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False) + + if self.has_spatial_attention_bias: + self.max_rel_2d_pos = config.max_rel_2d_pos + self.rel_2d_pos_bins = config.rel_2d_pos_bins + self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False) + self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False) + + 
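Editor's note: a worked example for `relative_position_bucket` as defined above, condensed to the bidirectional case with the model's defaults (num_buckets=32, max_distance=128). Offsets with |rel| below 8 get their own exact bucket, larger offsets share log-spaced buckets, and the sign selects which half of the bucket range is used:

```python
# Condensed bidirectional-only restatement of relative_position_bucket,
# kept faithful to the definition above, plus a small worked example.
import math
import torch

def bucket(rel, num_buckets=32, max_distance=128):
    num_buckets //= 2                                  # half for each sign
    ret = (rel > 0).long() * num_buckets
    n = rel.abs()
    max_exact = num_buckets // 2                       # exact range: |rel| < 8
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    return ret + torch.where(n < max_exact, n, val_if_large)

rel = torch.tensor([-128, -16, -1, 0, 1, 16, 128])
print(bucket(rel))  # tensor([15, 10,  1,  0, 17, 26, 31])
```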
self.gradient_checkpointing = False + + def _calculate_1d_position_embeddings(self, position_ids): + rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) + rel_pos = relative_position_bucket( + rel_pos_mat, + num_buckets=self.rel_pos_bins, + max_distance=self.max_rel_pos, + ) + # Since this is a simple indexing operation that is independent of the input, + # no need to track gradients for this operation + # + # Without this no_grad context, training speed slows down significantly + with torch.no_grad(): + rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2) + rel_pos = rel_pos.contiguous() + return rel_pos + + def _calculate_2d_position_embeddings(self, bbox): + position_coord_x = bbox[:, :, 0] + position_coord_y = bbox[:, :, 3] + rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1) + rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1) + rel_pos_x = relative_position_bucket( + rel_pos_x_2d_mat, + num_buckets=self.rel_2d_pos_bins, + max_distance=self.max_rel_2d_pos, + ) + rel_pos_y = relative_position_bucket( + rel_pos_y_2d_mat, + num_buckets=self.rel_2d_pos_bins, + max_distance=self.max_rel_2d_pos, + ) + # Since this is a simple indexing operation that is independent of the input, + # no need to track gradients for this operation + # + # Without this no_grad context, training speed slows down significantly + with torch.no_grad(): + rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2) + rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2) + rel_pos_x = rel_pos_x.contiguous() + rel_pos_y = rel_pos_y.contiguous() + rel_2d_pos = rel_pos_x + rel_pos_y + return rel_2d_pos + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + bbox=None, + position_ids=None, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None + rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class LayoutLMv2PreTrainedModel(PreTrainedModel): + config: LayoutLMv2Config + base_model_prefix = "layoutlmv2" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, 
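Editor's note: `_calculate_1d_position_embeddings` above turns bucket ids into per-head attention biases by indexing the transposed weight of the bias-free `rel_pos_bias` linear layer, which is effectively an embedding lookup over the bins. A shape sketch:

```python
# Sketch of bucket-id -> per-head bias: indexing a Linear's transposed
# weight acts as an embedding table of shape (bins, heads).
import torch
from torch import nn

num_heads, rel_pos_bins, seq = 12, 32, 5
rel_pos_bias = nn.Linear(rel_pos_bins, num_heads, bias=False)  # weight: (heads, bins)

buckets = torch.randint(0, rel_pos_bins, (1, seq, seq))  # (batch, L, L) bucket ids
bias = rel_pos_bias.weight.t()[buckets]                  # (batch, L, L, heads)
bias = bias.permute(0, 3, 1, 2).contiguous()             # (batch, heads, L, L)
print(bias.shape)  # torch.Size([1, 12, 5, 5]), added to the attention scores
```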
std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, LayoutLMv2SelfAttention): + if self.config.fast_qkv: + module.q_bias.data.zero_() + module.v_bias.data.zero_() + elif isinstance(module, LayoutLMv2Model): + if hasattr(module, "visual_segment_embedding"): + module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range) + + +def my_convert_sync_batchnorm(module, process_group=None): + # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + return nn.modules.SyncBatchNorm.convert_sync_batchnorm(module, process_group) + module_output = module + if isinstance(module, detectron2.layers.FrozenBatchNorm2d): + module_output = torch.nn.SyncBatchNorm( + num_features=module.num_features, + eps=module.eps, + affine=True, + track_running_stats=True, + process_group=process_group, + ) + module_output.weight = torch.nn.Parameter(module.weight) + module_output.bias = torch.nn.Parameter(module.bias) + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=module.running_mean.device) + for name, child in module.named_children(): + module_output.add_module(name, my_convert_sync_batchnorm(child, process_group)) + del module + return module_output + + +class LayoutLMv2VisualBackbone(nn.Module): + def __init__(self, config): + super().__init__() + self.cfg = config.get_detectron2_config() + meta_arch = self.cfg.MODEL.META_ARCHITECTURE + model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg) + assert isinstance(model.backbone, detectron2.modeling.backbone.FPN) + self.backbone = model.backbone + + assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD) + num_channels = len(self.cfg.MODEL.PIXEL_MEAN) + self.register_buffer( + "pixel_mean", + torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1), + persistent=False, + ) + self.register_buffer( + "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False + ) + self.out_feature_key = "p2" + if torch.are_deterministic_algorithms_enabled(): + logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`") + input_shape = (224, 224) + backbone_stride = self.backbone.output_shape()[self.out_feature_key].stride + self.pool = nn.AvgPool2d( + ( + math.ceil(math.ceil(input_shape[0] / backbone_stride) / config.image_feature_pool_shape[0]), + math.ceil(math.ceil(input_shape[1] / backbone_stride) / config.image_feature_pool_shape[1]), + ) + ) + else: + self.pool = nn.AdaptiveAvgPool2d(config.image_feature_pool_shape[:2]) + if len(config.image_feature_pool_shape) == 2: + config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels) + assert self.backbone.output_shape()[self.out_feature_key].channels == config.image_feature_pool_shape[2] + + def forward(self, images): + images_input = ((images if torch.is_tensor(images) else images.tensor) - self.pixel_mean) / self.pixel_std + features = self.backbone(images_input) + features = 
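Editor's note: when deterministic algorithms are enabled, the visual backbone above swaps `AdaptiveAvgPool2d` for a fixed `AvgPool2d` whose kernel is derived from the assumed 224x224 input, the backbone stride at "p2", and the 7x7 target grid. The arithmetic, with an illustrative stride of 4 (typical for an FPN "p2" feature map, but read from `backbone.output_shape()` in the real code):

```python
# Worked version of the deterministic AvgPool2d kernel computation;
# the stride value is an assumption for illustration.
import math

input_hw, stride, pool_hw = (224, 224), 4, (7, 7)
feat = (math.ceil(input_hw[0] / stride), math.ceil(input_hw[1] / stride))
kernel = (math.ceil(feat[0] / pool_hw[0]), math.ceil(feat[1] / pool_hw[1]))
print(feat, kernel)  # (56, 56) (8, 8): a fixed 8x8 average pool yields the 7x7 grid
```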
features[self.out_feature_key] + features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous() + return features + + def synchronize_batch_norm(self): + if not ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > -1 + ): + raise RuntimeError("Make sure torch.distributed is set up properly.") + + self_rank = torch.distributed.get_rank() + node_size = torch.cuda.device_count() + world_size = torch.distributed.get_world_size() + if not (world_size % node_size == 0): + raise RuntimeError("Make sure the number of processes can be divided by the number of nodes") + + node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)] + sync_bn_groups = [ + torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size) + ] + node_rank = self_rank // node_size + + self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank]) + + +class LayoutLMv2Pooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class LayoutLMv2Model(LayoutLMv2PreTrainedModel): + def __init__(self, config): + requires_backends(self, "detectron2") + super().__init__(config) + self.config = config + self.has_visual_segment_embedding = config.has_visual_segment_embedding + self.embeddings = LayoutLMv2Embeddings(config) + + self.visual = LayoutLMv2VisualBackbone(config) + self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size) + if self.has_visual_segment_embedding: + self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0]) + self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.visual_dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = LayoutLMv2Encoder(config) + self.pooler = LayoutLMv2Pooler(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + if inputs_embeds is None: + inputs_embeds = self.embeddings.word_embeddings(input_ids) + position_embeddings = self.embeddings.position_embeddings(position_ids) + spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox) + token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings + embeddings = self.embeddings.LayerNorm(embeddings) 
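Editor's note: the visual branch above flattens the pooled (batch, channels, 7, 7) feature map into a 49-token sequence that is later concatenated after the text tokens. A quick shape walk:

```python
# Shape walk for the visual branch: pooled backbone features become a
# 49-token sequence; channels are illustrative.
import torch

features = torch.rand(2, 256, 7, 7)                      # (batch, channels, 7, 7)
tokens = features.flatten(start_dim=2).transpose(1, 2)   # (batch, 49, channels)
print(tokens.shape)  # torch.Size([2, 49, 256]), then visual_proj maps to hidden_size
```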
+ embeddings = self.embeddings.dropout(embeddings) + return embeddings + + def _calc_img_embeddings(self, image, bbox, position_ids): + visual_embeddings = self.visual_proj(self.visual(image)) + position_embeddings = self.embeddings.position_embeddings(position_ids) + spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox) + embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings + if self.has_visual_segment_embedding: + embeddings += self.visual_segment_embedding + embeddings = self.visual_LayerNorm(embeddings) + embeddings = self.visual_dropout(embeddings) + return embeddings + + def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape): + visual_bbox_x = torch.div( + torch.arange( + 0, + 1000 * (image_feature_pool_shape[1] + 1), + 1000, + device=device, + dtype=bbox.dtype, + ), + self.config.image_feature_pool_shape[1], + rounding_mode="floor", + ) + visual_bbox_y = torch.div( + torch.arange( + 0, + 1000 * (self.config.image_feature_pool_shape[0] + 1), + 1000, + device=device, + dtype=bbox.dtype, + ), + self.config.image_feature_pool_shape[0], + rounding_mode="floor", + ) + visual_bbox = torch.stack( + [ + visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + ], + dim=-1, + ).view(-1, bbox.size(-1)) + + visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1) + + return visual_bbox + + def _get_input_shape(self, input_ids=None, inputs_embeds=None): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + return input_ids.size() + elif inputs_embeds is not None: + return inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + bbox: Optional[torch.LongTensor] = None, + image: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + bbox (`torch.LongTensor` of shape `((batch_size, sequence_length), 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`): + Batch of document images. 
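Editor's note: `_calc_visual_bbox` above assigns every pooled image patch a bounding box on the same 0-1000 scale as the word boxes. The same construction for a 2x2 grid, small enough to print in full:

```python
# Sketch of _calc_visual_bbox for a tiny 2x2 pool grid: each patch gets
# a box tiling the 0-1000 page coordinate space.
import torch

pool_h = pool_w = 2
x = torch.arange(0, 1000 * (pool_w + 1), 1000) // pool_w  # tensor([0, 500, 1000])
y = torch.arange(0, 1000 * (pool_h + 1), 1000) // pool_h
visual_bbox = torch.stack(
    [
        x[:-1].repeat(pool_h, 1),                   # x0 for each cell
        y[:-1].repeat(pool_w, 1).transpose(0, 1),   # y0
        x[1:].repeat(pool_h, 1),                    # x1
        y[1:].repeat(pool_w, 1).transpose(0, 1),    # y1
    ],
    dim=-1,
).view(-1, 4)
print(visual_bbox)  # [[0,0,500,500], [500,0,1000,500], [0,500,500,1000], [500,500,1000,1000]]
```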
+ + Examples: + + ```python + >>> from transformers import AutoProcessor, LayoutLMv2Model, set_seed + >>> from PIL import Image + >>> import torch + >>> from datasets import load_dataset + + >>> set_seed(0) + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased") + >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased") + + + >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa") + >>> image = dataset["test"][0]["image"] + + >>> encoding = processor(image, return_tensors="pt") + + >>> outputs = model(**encoding) + >>> last_hidden_states = outputs.last_hidden_state + + >>> last_hidden_states.shape + torch.Size([1, 342, 768]) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = self._get_input_shape(input_ids, inputs_embeds) + device = input_ids.device if input_ids is not None else inputs_embeds.device + + visual_shape = list(input_shape) + visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1] + visual_shape = torch.Size(visual_shape) + # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur + final_shape = list(self._get_input_shape(input_ids, inputs_embeds)) + final_shape[1] += visual_shape[1] + final_shape = torch.Size(final_shape) + + visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape) + final_bbox = torch.cat([bbox, visual_bbox], dim=1) + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + + visual_attention_mask = torch.ones(visual_shape, device=device) + final_attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if position_ids is None: + seq_length = input_shape[1] + position_ids = self.embeddings.position_ids[:, :seq_length] + position_ids = position_ids.expand(input_shape) + + visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat( + input_shape[0], 1 + ) + final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1) + + if bbox is None: + bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) + + text_layout_emb = self._calc_text_embeddings( + input_ids=input_ids, + bbox=bbox, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + visual_emb = self._calc_img_embeddings( + image=image, + bbox=visual_bbox, + position_ids=visual_position_ids, + ) + final_emb = torch.cat([text_layout_emb, visual_emb], dim=1) + + extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2) + + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min + + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + head_mask = 
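Editor's note: the forward pass above converts the concatenated text+visual mask into an additive attention bias: reshape (batch, seq) to (batch, 1, 1, seq) and map keep/pad flags to 0 / large-negative, so padded keys vanish after the softmax. A minimal sketch:

```python
# Sketch of the extended attention mask construction used in forward().
import torch

mask = torch.tensor([[1, 1, 1, 0]])                 # (batch, seq): 1 = keep, 0 = pad
ext = mask[:, None, None, :].to(torch.float32)      # (batch, 1, 1, seq)
ext = (1.0 - ext) * torch.finfo(torch.float32).min  # 0 for keep, ~-3.4e38 for pad
print(ext)  # broadcasts over heads and query positions when added to scores
```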
head_mask.to(dtype=next(self.parameters()).dtype) + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + final_emb, + extended_attention_mask, + bbox=final_bbox, + position_ids=final_position_ids, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the + final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual + embeddings, e.g. for document image classification tasks such as the + [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset. + """ +) +class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.layoutlmv2 = LayoutLMv2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.layoutlmv2.embeddings.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + bbox: Optional[torch.LongTensor] = None, + image: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`): + Batch of document images. + token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Example: + + ```python + >>> from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed + >>> from PIL import Image + >>> import torch + >>> from datasets import load_dataset + + >>> set_seed(0) + + >>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True) + >>> data = next(iter(dataset)) + >>> image = data["image"].convert("RGB") + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased") + >>> model = LayoutLMv2ForSequenceClassification.from_pretrained( + ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes + ... ) + + >>> encoding = processor(image, return_tensors="pt") + >>> sequence_label = torch.tensor([data["label"]]) + + >>> outputs = model(**encoding, labels=sequence_label) + + >>> loss, logits = outputs.loss, outputs.logits + >>> predicted_idx = logits.argmax(dim=-1).item() + >>> predicted_answer = dataset.info.features["label"].names[4] + >>> predicted_idx, predicted_answer # results are not good without further fine-tuning + (7, 'advertisement') + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + visual_shape = list(input_shape) + visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1] + visual_shape = torch.Size(visual_shape) + final_shape = list(input_shape) + final_shape[1] += visual_shape[1] + final_shape = torch.Size(final_shape) + + visual_bbox = self.layoutlmv2._calc_visual_bbox( + self.config.image_feature_pool_shape, bbox, device, final_shape + ) + + visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat( + input_shape[0], 1 + ) + + initial_image_embeddings = self.layoutlmv2._calc_img_embeddings( + image=image, + bbox=visual_bbox, + position_ids=visual_position_ids, + ) + + outputs = self.layoutlmv2( + input_ids=input_ids, + bbox=bbox, + image=image, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + 
return_dict=return_dict, + ) + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + sequence_output, final_image_embeddings = outputs[0][:, :seq_length], outputs[0][:, seq_length:] + + cls_final_output = sequence_output[:, 0, :] + + # average-pool the visual embeddings + pooled_initial_image_embeddings = initial_image_embeddings.mean(dim=1) + pooled_final_image_embeddings = final_image_embeddings.mean(dim=1) + # concatenate with cls_final_output + sequence_output = torch.cat( + [cls_final_output, pooled_initial_image_embeddings, pooled_final_image_embeddings], dim=1 + ) + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden + states) e.g. for sequence labeling (information extraction) tasks such as + [FUNSD](https://guillaumejaume.github.io/FUNSD/), [SROIE](https://rrc.cvc.uab.es/?ch=13), + [CORD](https://github.com/clovaai/cord) and [Kleister-NDA](https://github.com/applicaai/kleister-nda). 
+ """ +) +class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.layoutlmv2 = LayoutLMv2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.layoutlmv2.embeddings.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + bbox: Optional[torch.LongTensor] = None, + image: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`): + Batch of document images. + token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + + Example: + + ```python + >>> from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed + >>> from PIL import Image + >>> from datasets import load_dataset + + >>> set_seed(0) + + >>> datasets = load_dataset("nielsr/funsd", split="test") + >>> labels = datasets.features["ner_tags"].feature.names + >>> id2label = {v: k for v, k in enumerate(labels)} + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr") + >>> model = LayoutLMv2ForTokenClassification.from_pretrained( + ... 
"microsoft/layoutlmv2-base-uncased", num_labels=len(labels) + ... ) + + >>> data = datasets[0] + >>> image = Image.open(data["image_path"]).convert("RGB") + >>> words = data["words"] + >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes + >>> word_labels = data["ner_tags"] + >>> encoding = processor( + ... image, + ... words, + ... boxes=boxes, + ... word_labels=word_labels, + ... padding="max_length", + ... truncation=True, + ... return_tensors="pt", + ... ) + + >>> outputs = model(**encoding) + >>> logits, loss = outputs.logits, outputs.loss + + >>> predicted_token_class_ids = logits.argmax(-1) + >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]] + >>> predicted_tokens_classes[:5] # results are not good without further fine-tuning + ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION'] + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.layoutlmv2( + input_ids=input_ids, + bbox=bbox, + image=image, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + # only take the text part of the output representations + sequence_output = outputs[0][:, :seq_length] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel): + def __init__(self, config, has_visual_segment_embedding=True): + r""" + has_visual_segment_embedding (`bool`, *optional*, defaults to `True`): + Whether or not to add visual segment embeddings. 
+ """ + super().__init__(config) + self.num_labels = config.num_labels + config.has_visual_segment_embedding = has_visual_segment_embedding + self.layoutlmv2 = LayoutLMv2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.layoutlmv2.embeddings.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + bbox: Optional[torch.LongTensor] = None, + image: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`): + Batch of document images. + token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + + Example: + + In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us + a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image). + + ```python + >>> from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed + >>> import torch + >>> from PIL import Image + >>> from datasets import load_dataset + + >>> set_seed(0) + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased") + >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased") + + >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa") + >>> image = dataset["test"][0]["image"] + >>> question = "When is coffee break?" 
+ >>> encoding = processor(image, question, return_tensors="pt") + + >>> outputs = model(**encoding) + >>> predicted_start_idx = outputs.start_logits.argmax(-1).item() + >>> predicted_end_idx = outputs.end_logits.argmax(-1).item() + >>> predicted_start_idx, predicted_end_idx + (30, 191) + + >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1] + >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens) + >>> predicted_answer # results are not good without further fine-tuning + '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from' + ``` + + ```python + >>> target_start_index = torch.tensor([7]) + >>> target_end_index = torch.tensor([14]) + >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index) + >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item() + >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item() + >>> predicted_answer_span_start, predicted_answer_span_end + (30, 191) + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.layoutlmv2( + input_ids=input_ids, + bbox=bbox, + image=image, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + # only take the text part of the output representations + sequence_output = outputs[0][:, :seq_length] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "LayoutLMv2ForQuestionAnswering", + "LayoutLMv2ForSequenceClassification", + "LayoutLMv2ForTokenClassification", + "LayoutLMv2Layer", + "LayoutLMv2Model", + "LayoutLMv2PreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..603cdf4df4e93e615fbb127bdf2057ae8d1d3b2c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for LayoutLMv2. +""" + +import warnings +from typing import Optional, Union + +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class LayoutLMv2Processor(ProcessorMixin): + r""" + Constructs a LayoutLMv2 processor which combines a LayoutLMv2 image processor and a LayoutLMv2 tokenizer into a + single processor. + + [`LayoutLMv2Processor`] offers all the functionalities you need to prepare data for the model. + + It first uses [`LayoutLMv2ImageProcessor`] to resize document images to a fixed size, and optionally applies OCR to + get words and normalized bounding boxes. These are then provided to [`LayoutLMv2Tokenizer`] or + [`LayoutLMv2TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`, + `attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned + into token-level `labels` for token classification tasks (such as FUNSD, CORD). + + Args: + image_processor (`LayoutLMv2ImageProcessor`, *optional*): + An instance of [`LayoutLMv2ImageProcessor`]. The image processor is a required input. + tokenizer (`LayoutLMv2Tokenizer` or `LayoutLMv2TokenizerFast`, *optional*): + An instance of [`LayoutLMv2Tokenizer`] or [`LayoutLMv2TokenizerFast`]. The tokenizer is a required input. 
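A minimal construction sketch for the processor described above (editorial, not part of this file; the checkpoint name follows the examples earlier in this diff, and `apply_ocr=True` additionally needs pytesseract available when the processor is called):

```python
from transformers import LayoutLMv2ImageProcessor, LayoutLMv2Processor, LayoutLMv2TokenizerFast

image_processor = LayoutLMv2ImageProcessor(apply_ocr=True)
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
processor = LayoutLMv2Processor(image_processor=image_processor, tokenizer=tokenizer)
```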
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "LayoutLMv2ImageProcessor" + tokenizer_class = ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None, + boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None, + word_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchEncoding: + """ + This method first forwards the `images` argument to [`~LayoutLMv2ImageProcessor.__call__`]. In case + [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and + bounding boxes along with the additional arguments to [`~LayoutLMv2Tokenizer.__call__`] and returns the output, + together with resized `images`. In case [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to + `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user along with the additional + arguments to [`~LayoutLMv2Tokenizer.__call__`] and returns the output, together with resized `images``. + + Please refer to the docstring of the above two methods for more information. + """ + # verify input + if self.image_processor.apply_ocr and (boxes is not None): + raise ValueError( + "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." + ) + + if self.image_processor.apply_ocr and (word_labels is not None): + raise ValueError( + "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." 
+ ) + + if return_overflowing_tokens is True and return_offsets_mapping is False: + raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") + + # first, apply the image processor + features = self.image_processor(images=images, return_tensors=return_tensors) + + # second, apply the tokenizer + if text is not None and self.image_processor.apply_ocr and text_pair is None: + if isinstance(text, str): + text = [text] # add batch dimension (as the image processor always adds a batch dimension) + text_pair = features["words"] + + encoded_inputs = self.tokenizer( + text=text if text is not None else features["words"], + text_pair=text_pair if text_pair is not None else None, + boxes=boxes if boxes is not None else features["boxes"], + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + **kwargs, + ) + + # add pixel values + images = features.pop("pixel_values") + if return_overflowing_tokens is True: + images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) + encoded_inputs["image"] = images + + return encoded_inputs + + def get_overflowing_images(self, images, overflow_to_sample_mapping): + # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image + images_with_overflow = [] + for sample_idx in overflow_to_sample_mapping: + images_with_overflow.append(images[sample_idx]) + + if len(images_with_overflow) != len(overflow_to_sample_mapping): + raise ValueError( + "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got" + f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}" + ) + + return images_with_overflow + + @property + def model_input_names(self): + return ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"] + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["LayoutLMv2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py new file mode 100644 index 0000000000000000000000000000000000000000..7d82b5cf41041b0024134c0f1a6294c0cace824c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py @@ -0,0 +1,1545 @@ +# coding=utf-8 +# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for LayoutLMv2."""
+
+import collections
+import os
+import sys
+import unicodedata
+from typing import Optional, Union
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...tokenization_utils_base import (
+    BatchEncoding,
+    EncodedInput,
+    PreTokenizedInput,
+    TextInput,
+    TextInputPair,
+    TruncationStrategy,
+)
+from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING = r"""
+            add_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to encode the sequences with the special tokens relative to their model.
+            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+                Activates and controls truncation. Accepts the following values:
+
+                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will
+                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
+                  sequences (or a batch of pairs) is provided.
+                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output a batch with sequence lengths
+                  greater than the model maximum admissible input size).
+            max_length (`int`, *optional*):
+                Controls the maximum length to use by one of the truncation/padding parameters.
+
+                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+                is required by one of the truncation/padding parameters. If the model has no specific maximum input
+                length (like XLNet), truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + +LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **bbox** -- List of bounding boxes to be fed to a model. + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified). 
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when `return_length=True`). +""" + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +table = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")) + + +def subfinder(mylist, pattern): + matches = [] + indices = [] + for idx, i in enumerate(range(len(mylist))): + if mylist[i] == pattern[0] and mylist[i : i + len(pattern)] == pattern: + matches.append(pattern) + indices.append(idx) + if matches: + return matches[0], indices[0] + else: + return None, 0 + + +class LayoutLMv2Tokenizer(PreTrainedTokenizer): + r""" + Construct a LayoutLMv2 tokenizer. Based on WordPiece. [`LayoutLMv2Tokenizer`] can be used to turn words, word-level + bounding boxes and optional word labels to token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`, and + optional `labels` (for token classification). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + [`LayoutLMv2Tokenizer`] runs end-to-end tokenization: punctuation splitting and wordpiece. It also turns the + word-level bounding boxes into token-level bounding boxes. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + cls_token_box=[0, 0, 0, 0], + sep_token_box=[1000, 1000, 1000, 1000], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + tokenize_chinese_chars=True, + strip_accents=None, + model_max_length: int = 512, + additional_special_tokens: Optional[list[str]] = None, + **kwargs, + ): + sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token + unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token + cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token + mask_token = AddedToken(mask_token, special=True) if isinstance(mask_token, str) else mask_token + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained" + " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + # additional properties + self.cls_token_box = cls_token_box + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + cls_token_box=cls_token_box, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + model_max_length=model_max_length, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
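A toy illustration of the two layouts built by this method (editorial; the ids are hypothetical stand-ins for `tokenizer.cls_token_id` and `tokenizer.sep_token_id`, using the conventional BERT-vocabulary values):

```python
cls_id, sep_id = 101, 102  # assumed BERT-style [CLS]/[SEP] ids
a, b = [7592, 2088], [2009, 2003]
single = [cls_id] + a + [sep_id]               # [CLS] X [SEP]
pair = [cls_id] + a + [sep_id] + b + [sep_id]  # [CLS] A [SEP] B [SEP]
```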
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None, + boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None, + word_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. 
Each sequence can be a string, a list of strings
+                (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+                words).
+            text_pair (`List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+                (pretokenized string).
+            boxes (`List[List[int]]`, `List[List[List[int]]]`):
+                Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+            word_labels (`List[int]`, `List[List[int]]`, *optional*):
+                Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+        """
+
+        # Input type checking for clearer error
+        def _is_valid_text_input(t):
+            if isinstance(t, str):
+                # Strings are fine
+                return True
+            elif isinstance(t, (list, tuple)):
+                # Lists are fine as long as they are...
+                if len(t) == 0:
+                    # ... empty
+                    return True
+                elif isinstance(t[0], str):
+                    # ... list of strings
+                    return True
+                elif isinstance(t[0], (list, tuple)):
+                    # ... list with an empty list or with a list of strings
+                    return len(t[0]) == 0 or isinstance(t[0][0], str)
+                else:
+                    return False
+            else:
+                return False
+
+        if text_pair is not None:
+            # in case text + text_pair are provided, text = questions, text_pair = words
+            if not _is_valid_text_input(text):
+                raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+            if not isinstance(text_pair, (list, tuple)):
+                raise ValueError(
+                    "Words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+        else:
+            # in case only text is provided => must be words
+            if not isinstance(text, (list, tuple)):
+                raise ValueError(
+                    "Words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+
+        if text_pair is not None:
+            is_batched = isinstance(text, (list, tuple))
+        else:
+            is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+        words = text if text_pair is None else text_pair
+        if boxes is None:
+            raise ValueError("You must provide corresponding bounding boxes")
+        if is_batched:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide words and boxes for an equal amount of examples")
+            for words_example, boxes_example in zip(words, boxes):
+                if len(words_example) != len(boxes_example):
+                    raise ValueError("You must provide as many words as there are bounding boxes")
+        else:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide as many words as there are bounding boxes")
+
+        if is_batched:
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+                    f" {len(text_pair)}."
+ ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[list[list[list[int]]]] = None, + word_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + 
return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[list[list[list[int]]]] = None, + word_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + batch_outputs = self._batch_prepare_for_model( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_text_or_text_pairs, + is_pair: Optional[bool] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens. 
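A hedged sketch of the moving-window behaviour this method manages (editorial; `tokenizer` stands for any instantiated LayoutLMv2 tokenizer, and the words/boxes are toy values on the 0-1000 scale):

```python
words = ["hello", "world"] * 20
boxes = [[10, 20, 30, 40]] * len(words)

enc = tokenizer(
    words,
    boxes=boxes,
    max_length=16,
    truncation=True,
    stride=4,  # number of tokens shared between the kept window and the overflow
    return_overflowing_tokens=True,
)
# Per the kwargs docstring earlier in this file, enc["overflowing_tokens"]
# holds the cut-off tokens prefixed with `stride` tokens of overlap taken
# from the end of the kept window.
```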
+ + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)): + batch_text_or_text_pair, boxes_example = example + outputs = self.prepare_for_model( + batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair, + batch_text_or_text_pair[1] if is_pair else None, + boxes_example, + word_labels=word_labels[idx] if word_labels is not None else None, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING) + def encode( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> list[int]: + encoded_inputs = self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[int]] = None, + 
add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + boxes=boxes, + text_pair=text_pair, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. 
" + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + return self.prepare_for_model( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens, + truncates sequences if overflowing while taking into account the special tokens and manages a moving window + (with user defined stride) for overflowing tokens. Please Note, for *text_pair* different than `None` and + *truncation_strategy = longest_first* or `True`, it is not possible to return overflowing tokens. Such a + combination of arguments will raise an error. + + Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into + token-level `labels`. The word label is used for the first token of the word, while remaining tokens are + labeled with -100, such that they will be ignored by the loss function. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). 
+ """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + tokens = [] + pair_tokens = [] + token_boxes = [] + pair_token_boxes = [] + labels = [] + + if text_pair is None: + if word_labels is None: + # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference) + for word, box in zip(text, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + else: + # CASE 2: token classification (training) + for word, box, label in zip(text, boxes, word_labels): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + if self.only_label_first_subword: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1)) + else: + labels.extend([label] * len(word_tokens)) + else: + # CASE 3: document visual question answering (inference) + # text = question + # text_pair = words + tokens = self.tokenize(text) + token_boxes = [self.pad_token_box for _ in range(len(tokens))] + + for word, box in zip(text_pair, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + pair_tokens.extend(word_tokens) + pair_token_boxes.extend([box] * len(word_tokens)) + + # Create ids + pair_ids + ids = self.convert_tokens_to_ids(tokens) + pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + + # Compute the total size of the returned encodings + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) = self.truncate_sequences( + ids, + token_boxes, + pair_ids=pair_ids, + pair_token_boxes=pair_token_boxes, + labels=labels, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." 
+ ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes + encoded_inputs["overflowing_labels"] = overflowing_labels + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box] + if pair_token_boxes: + pair_token_boxes = pair_token_boxes + [self.sep_token_box] + if labels: + labels = [self.pad_token_label] + labels + [self.pad_token_label] + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["bbox"] = token_boxes + pair_token_boxes + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + if labels: + encoded_inputs["labels"] = labels + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def truncate_sequences( + self, + ids: list[int], + token_boxes: list[list[int]], + pair_ids: Optional[list[int]] = None, + pair_token_boxes: Optional[list[list[int]]] = None, + labels: Optional[list[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> tuple[list[int], list[int], list[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + token_boxes (`List[List[int]]`): + Bounding boxes of the first sequence. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + pair_token_boxes (`List[List[int]]`, *optional*): + Bounding boxes of the second sequence. + labels (`List[int]`, *optional*): + Labels of the first sequence (for token classification tasks). + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. 
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair + of sequences (or a batch of pairs) is provided. + """ + if num_tokens_to_remove <= 0: + return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( + truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None + ): + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + overflowing_token_boxes = token_boxes[-window_len:] + overflowing_labels = labels[-window_len:] + ids = ids[:-num_tokens_to_remove] + token_boxes = token_boxes[:-num_tokens_to_remove] + labels = labels[:-num_tokens_to_remove] + else: + error_msg = ( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + ) + if truncation_strategy == TruncationStrategy.ONLY_FIRST: + error_msg = ( + error_msg + "Please select another truncation strategy than " + f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." + ) + logger.error(error_msg) + elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: + logger.warning( + "Be aware, overflowing tokens are not returned for the setting you have chosen," + f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " + "truncation strategy. So the returned list will always be empty even if some " + "tokens have been removed." 
+ ) + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + ids = ids[:-1] + token_boxes = token_boxes[:-1] + labels = labels[:-1] + else: + pair_ids = pair_ids[:-1] + pair_token_boxes = pair_token_boxes[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + overflowing_token_boxes = pair_token_boxes[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) + + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. 
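The `pad_to_multiple_of` computation a few lines up is plain ceiling arithmetic, and the right/left branches that follow pad `bbox` with `pad_token_box` and `labels` with `pad_token_label` in the same pass. A tiny standalone sketch of the rounding (the helper name is illustrative only, not part of the library):

    def round_up_to_multiple(max_length: int, multiple: int) -> int:
        # Mirrors the computation in _pad: bump max_length to the next multiple.
        if max_length % multiple != 0:
            max_length = ((max_length // multiple) + 1) * multiple
        return max_length

    assert round_up_to_multiple(60, 8) == 64  # 60 is rounded up to the next multiple of 8
    assert round_up_to_multiple(64, 8) == 64  # already a multiple, left unchanged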
+ if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + return encoded_inputs + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. 
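To make the options above concrete, here is a small sketch of `BasicTokenizer` behaviour with default settings plus one `never_split` entry; the strings are toy inputs and the expected outputs are shown as comments:

    basic = BasicTokenizer(do_lower_case=True, never_split=["[CLS]"])

    # Lower-cases, strips accents (strip_accents is None, so it follows lowercasing)
    # and splits punctuation into separate tokens.
    print(basic.tokenize("Héllo, World!"))  # ['hello', ',', 'world', '!']

    # Tokens in never_split are passed through unchanged, without lower-casing or splitting.
    print(basic.tokenize("[CLS] Hello"))    # ['[CLS]', 'hello']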
+ """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # 
despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["LayoutLMv2Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..8e324ee0b8fe971b06cd102eaa9930990a1f5b99 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py @@ -0,0 +1,789 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
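Before the fast tokenizer that follows, a toy run of the `WordpieceTokenizer` defined above shows the greedy longest-match-first behaviour; the four-entry vocabulary is made up for illustration, whereas real checkpoints ship a full vocab.txt:

    vocab = {"un", "##aff", "##able", "[UNK]"}  # only membership tests are needed, so a set works here
    wordpiece = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

    print(wordpiece.tokenize("unaffable"))  # ['un', '##aff', '##able']
    print(wordpiece.tokenize("banana"))     # ['[UNK]'], no prefix matches so the whole word maps to unk_token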
+""" +Fast tokenization class for LayoutLMv2. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus +and _encode_plus, in which the Rust tokenizer is used. +""" + +import json +from typing import Optional, Union + +from tokenizers import normalizers + +from ...tokenization_utils_base import ( + BatchEncoding, + EncodedInput, + PaddingStrategy, + PreTokenizedInput, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import add_end_docstrings, logging +from .tokenization_layoutlmv2 import ( + LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, + LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + LayoutLMv2Tokenizer, +) + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" LayoutLMv2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [CLS] token. + sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`): + The bounding box to use for the special [SEP] token. + pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [PAD] token. + pad_token_label (`int`, *optional*, defaults to -100): + The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's + CrossEntropyLoss. + only_label_first_subword (`bool`, *optional*, defaults to `True`): + Whether or not to only label the first subword, in case word labels are provided. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). 
+ strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original LayoutLMv2). + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = LayoutLMv2Tokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + cls_token_box=[0, 0, 0, 0], + sep_token_box=[1000, 1000, 1000, 1000], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + cls_token_box=cls_token_box, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + pre_tok_state.get("lowercase", do_lower_case) != do_lower_case + or pre_tok_state.get("strip_accents", strip_accents) != strip_accents + ): + pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) + pre_tok_state["lowercase"] = do_lower_case + pre_tok_state["strip_accents"] = strip_accents + self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) + + self.do_lower_case = do_lower_case + + # additional properties + self.cls_token_box = cls_token_box + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None, + boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None, + word_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). 
+ text_pair (`List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + boxes (`List[List[int]]`, `List[List[List[int]]]`): + Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale. + word_labels (`List[int]`, `List[List[int]]`, *optional*): + Word-level integer labels (for token classification tasks such as FUNSD, CORD). + """ + + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = words + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be words + if not isinstance(text, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + words = text if text_pair is None else text_pair + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") + if is_batched: + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") + for words_example, boxes_example in zip(words, boxes): + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") + else: + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." 
+ ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[list[list[list[int]]]] = None, + word_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + 
return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: + batched_input = [(text, pair)] if pair else [text] + encodings = self._tokenizer.encode_batch( + batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs + ) + + return encodings[0].tokens + + @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + word_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). 
+ """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + boxes=boxes, + text_pair=text_pair, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[list[list[list[int]]]] = None, + word_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})") + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + ) + + if is_pair: + batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs] + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs + ) + + # Convert encoding to dict + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=True + if word_labels is not None + else return_offsets_mapping, # we use offsets to create the labels + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from 
list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0]: + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + + # create the token boxes + token_boxes = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + token_boxes_example = [] + for id, sequence_id, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_encodings[batch_index].sequence_ids, + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if is_pair and sequence_id == 0: + token_boxes_example.append(self.pad_token_box) + else: + token_boxes_example.append(boxes[original_index][word_id]) + else: + if id == self.cls_token_id: + token_boxes_example.append(self.cls_token_box) + elif id == self.sep_token_id: + token_boxes_example.append(self.sep_token_box) + elif id == self.pad_token_id: + token_boxes_example.append(self.pad_token_box) + else: + raise ValueError("Id not recognized") + token_boxes.append(token_boxes_example) + + sanitized_tokens["bbox"] = token_boxes + + # optionally, create the labels + if word_labels is not None: + labels = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + labels_example = [] + for id, offset, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_tokens["offset_mapping"][batch_index], + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if self.only_label_first_subword: + if offset[0] == 0: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + else: + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + labels.append(labels_example) + + sanitized_tokens["labels"] = labels + # finally, remove offsets if the user didn't want them + if not return_offsets_mapping: + del sanitized_tokens["offset_mapping"] + + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[list[list[int]]] = None, + 
word_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # make it a batched input + # 2 options: + # 1) only text, in case text must be a list of str + # 2) text + text_pair, in which case text = str and text_pair a list of str + batched_input = [(text, text_pair)] if text_pair else [text] + batched_boxes = [boxes] + batched_word_labels = [word_labels] if word_labels is not None else None + batched_output = self._batch_encode_plus( + batched_input, + is_pair=bool(text_pair is not None), + boxes=batched_boxes, + word_labels=batched_word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: value[0] if len(value) > 0 and isinstance(value[0], list) else value + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. 
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + return encoded_inputs + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1: + output += token_ids_1 + [self.sep_token_id] + + return output + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["LayoutLMv2TokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/led/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..786ebd36d7b8cc3c7508d5282889a5dd941c41fe --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_led import * + from .modeling_led import * + from .modeling_tf_led import * + from .tokenization_led import * + from .tokenization_led_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/configuration_led.py b/venv/lib/python3.13/site-packages/transformers/models/led/configuration_led.py new file mode 100644 index 0000000000000000000000000000000000000000..57809df4aa881f9a4622ce0d1b110e93ff654c4f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/configuration_led.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LED model configuration""" + +from typing import Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class LEDConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LEDModel`]. 
It is used to instantiate an LED
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the LED
+    [allenai/led-base-16384](https://huggingface.co/allenai/led-base-16384) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the LED model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`LEDModel`] or [`TFLEDModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_encoder_position_embeddings (`int`, *optional*, defaults to 16384):
+            The maximum sequence length that the encoder might ever be used with.
+        max_decoder_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that the decoder might ever be used with.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models) + + Example: + + ```python + >>> from transformers import LEDModel, LEDConfig + + >>> # Initializing a LED allenai/led-base-16384 style configuration + >>> configuration = LEDConfig() + + >>> # Initializing a model from the allenai/led-base-16384 style configuration + >>> model = LEDModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "led" + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + "attention_probs_dropout_prob": "attention_dropout", + "initializer_range": "init_std", + } + + def __init__( + self, + vocab_size=50265, + max_encoder_position_embeddings=16384, + max_decoder_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + attention_window: Union[list[int], int] = 512, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_encoder_position_embeddings = max_encoder_position_embeddings + self.max_decoder_position_embeddings = max_decoder_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.attention_window = attention_window + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["LEDConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/modeling_led.py b/venv/lib/python3.13/site-packages/transformers/models/led/modeling_led.py new file mode 100644 index 0000000000000000000000000000000000000000..26d1321842e6b24041d82bc9913ae720f66db55f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/modeling_led.py @@ -0,0 +1,2527 @@ +# coding=utf-8 +# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LED model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_led import LEDConfig + + +logger = logging.get_logger(__name__) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _prepare_4d_attention_mask_inverted(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + expanded_attention_mask = inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + # make sure that global_attn_mask is positive + expanded_attention_mask = expanded_attention_mask * inverted_mask + + return expanded_attention_mask + + +class LEDLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. 
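A minimal sketch of the lookup this module performs (illustrative only, not part of the patched file): positions are absolute indices offset by `past_key_values_length`, so step *k* of incremental decoding looks up embedding *k*:

```python
import torch
from torch import nn

# Stand-in for the forward pass defined below; sizes are assumptions.
emb = nn.Embedding(num_embeddings=1024, embedding_dim=16)

def position_ids(input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
    # Positions continue from the cached length during generation.
    bsz, seq_len = input_ids_shape[:2]
    return torch.arange(past_key_values_length, past_key_values_length + seq_len)

print(position_ids(torch.Size([2, 5])))                            # tensor([0, 1, 2, 3, 4])
print(position_ids(torch.Size([2, 1]), past_key_values_length=5))  # tensor([5])
print(emb(position_ids(torch.Size([2, 5]))).shape)                 # torch.Size([5, 16])
```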
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings, embedding_dim) + + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +# Copied from transformers.models.longformer.modeling_longformer.LongformerSelfAttention with Longformer->LEDEncoder +class LEDEncoderSelfAttention(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + + self.query = nn.Linear(config.hidden_size, self.embed_dim) + self.key = nn.Linear(config.hidden_size, self.embed_dim) + self.value = nn.Linear(config.hidden_size, self.embed_dim) + + # separate projection layers for tokens with global attention + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + + self.dropout = config.attention_probs_dropout_prob + + self.layer_id = layer_id + attention_window = config.attention_window[self.layer_id] + assert attention_window % 2 == 0, ( + f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}" + ) + assert attention_window > 0, ( + f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" + ) + + self.one_sided_attn_window_size = attention_window // 2 + + self.config = config + + def forward( + self, + hidden_states, + attention_mask=None, + layer_head_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): + """ + [`LEDEncoderSelfAttention`] expects *len(hidden_states)* to be multiple of *attention_window*. Padding to + *attention_window* happens in [`LEDEncoderModel.forward`] to avoid redoing the padding on each layer. 
+ + The *attention_mask* is changed in [`LEDEncoderModel.forward`] from 0, 1, 2 to: + + - -10000: no attention + - 0: local attention + - +10000: global attention + """ + hidden_states = hidden_states.transpose(0, 1) + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert embed_dim == self.embed_dim, ( + f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + ) + + # normalize query + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None] + + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min + ) + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + assert list(attn_scores.size()) == [ + batch_size, + seq_len, + self.num_heads, + self.one_sided_attn_window_size * 2 + 1, + ], ( + f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}," + f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" + ) + + # compute local attention probs from global attention keys and contact over window dim + if is_global_attn: + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to local_attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) + + # free memory + del global_key_attn_scores + + attn_probs = nn.functional.softmax( + attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_heads,), ( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0) + attn_probs = attn_probs.type_as(attn_scores) + + # free 
memory + del attn_scores + + # apply dropout + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + ) + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] + ] + + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 + ) + # The attention weights for tokens with global attention are + # just filler values, they were never used to compute the output. + # Fill with 0 now, the correct values are in 'global_attn_probs'. + attn_probs[is_index_global_attn_nonzero] = 0 + + outputs = (attn_output.transpose(0, 1),) + + if output_attentions: + outputs += (attn_probs,) + + return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """pads rows and then flips rows and columns""" + hidden_states_padded = nn.functional.pad( + hidden_states_padded, padding + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """ + shift every row 1 step right, converting columns into diagonals. 
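The static example below is easier to verify with a small runnable version of the same shift (a hedged sketch; shapes are assumptions):

```python
import torch
import torch.nn.functional as F

x = torch.arange(16.0).view(1, 1, 4, 4)               # (heads, chunks, rows, hidden)
rows, hidden = 4, 4
padded = F.pad(x, (0, rows + 1)).view(1, 1, -1)[:, :, :-rows]
shifted = padded.view(1, 1, rows, hidden + rows)[:, :, :, :-1]
# Row r of `shifted` is row r of `x` moved r slots to the right,
# so the original columns now lie on diagonals.
print(shifted[0, 0])
```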
+ + Example: + + ```python + chunked_hidden_states: [ + 0.4983, + 2.6918, + -0.0071, + 1.0492, + -1.8348, + 0.7672, + 0.2986, + 0.0285, + -0.7584, + 0.4206, + -0.0405, + 0.1599, + 2.0514, + -1.1600, + 0.5372, + 0.2629, + ] + window_overlap = num_rows = 4 + ``` + + (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206, + -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ + total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() + chunked_hidden_states = nn.functional.pad( + chunked_hidden_states, (0, window_overlap + 1) + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, -1 + ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlap*window_overlap + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states + + @staticmethod + def _chunk(hidden_states, window_overlap, onnx_export: bool = False): + """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" + if not onnx_export: + # non-overlapping chunks of size = 2w + hidden_states = hidden_states.view( + hidden_states.size(0), + torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"), + window_overlap * 2, + hidden_states.size(2), + ) + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 + + chunk_stride = list(hidden_states.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) + + # When exporting to ONNX, use this separate logic + # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export + + # TODO replace this with + # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3) + # once `unfold` is supported + # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow + + chunk_size = [ + hidden_states.size(0), + torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1, + window_overlap * 2, + hidden_states.size(2), + ] + + overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device) + for chunk in range(chunk_size[1]): + overlapping_chunks[:, chunk, :, :] = hidden_states[ + :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, : + ] + return overlapping_chunks + + @staticmethod + def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor: + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) + beginning_mask = beginning_mask_2d[None, :, None, :] + ending_mask = beginning_mask.flip(dims=(1, 3)) + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] + beginning_mask = beginning_mask.expand(beginning_input.size()) + input_tensor[:, 
:affected_seq_len, :, : affected_seq_len + 1] = torch.full_like( + beginning_input, -float("inf") + ).where(beginning_mask.bool(), beginning_input) + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] + ending_mask = ending_mask.expand(ending_input.size()) + input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like( + ending_input, -float("inf") + ).where(ending_mask.bool(), ending_input) + + def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): + """ + Matrix multiplication of query and key tensors using with a sliding window attention pattern. This + implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained LEDEncoder) with an + overlap of size window_overlap + """ + batch_size, seq_len, num_heads, head_dim = query.size() + assert seq_len % (window_overlap * 2) == 0, ( + f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" + ) + assert query.size() == key.size() + + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False)) + key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False)) + + # matrix multiplication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply + + # convert diagonals into columns + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + diagonal_chunked_attention_scores, padding=(0, 0, 0, 1) + ) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
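To make that banded layout concrete, here is a hedged toy (w = window_overlap = 1; names are illustrative) mapping full pairwise scores into the `(seq_len, 2 * w + 1)` band this code assembles:

```python
import torch

w = 1
scores = torch.arange(16.0).view(4, 4)            # full (seq, seq) attention scores
banded = torch.full((4, 2 * w + 1), float("-inf"))
for i in range(4):                                # query position
    lo, hi = max(0, i - w), min(4, i + w + 1)     # keys inside the window
    banded[i, lo - i + w : hi - i + w] = scores[i, lo:hi]
# banded[1] -> tensor([4., 5., 6.]): keys 0, 1, 2 for query 1;
# column w always holds a token's attention to itself.
print(banded)
```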
+ + diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) + + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] + # - copying the lower triangle + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seq_len, 2 * window_overlap + 1 + ).transpose(2, 1) + + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores + + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + ): + """ + Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the + same shape as `attn_probs` + """ + batch_size, seq_len, num_heads, head_dim = value.size() + + assert seq_len % (window_overlap * 2) == 0 + assert attn_probs.size()[:3] == value.size()[:3] + assert attn_probs.size(3) == 2 * window_overlap + 1 + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( + batch_size * num_heads, + torch.div(seq_len, window_overlap, rounding_mode="trunc"), + window_overlap, + 2 * window_overlap + 1, + ) + + # group batch_size and num_heads dimensions into one + value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """compute global attn indices required throughout forward pass""" + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # indices of 
global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # helper variable + is_local_index_global_attn = torch.arange( + max_num_global_attn_indices, device=is_index_global_attn.device + ) < num_global_attn_indices.unsqueeze(dim=-1) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = key_vectors.shape[0] + + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) + attn_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : + ] = torch.finfo(attn_probs_from_global_key.dtype).min + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) + + return attn_probs_from_global_key + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = attn_probs.shape[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) + # get value vectors for global only + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] + + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + # compute attn output only global + attn_output_only_global = torch.matmul( + attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone() + ).transpose(1, 2) + + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + hidden_states, + max_num_global_attn_indices, + layer_head_mask, + is_local_index_global_attn_nonzero, + 
is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + ): + seq_len, batch_size = hidden_states.shape[:2] + + # prepare global hidden states + global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) + global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] + ] + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= math.sqrt(self.head_dim) + + # reshape + global_query_vectors_only_global = ( + global_query_vectors_only_global.contiguous() + .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_key_vectors = ( + global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + + # compute attn scores + global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) + + assert list(global_attn_scores.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + seq_len, + ], ( + "global_attn_scores have the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is" + f" {global_attn_scores.size()}." 
+ ) + + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + global_attn_scores = global_attn_scores.transpose(1, 2) + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : + ] = torch.finfo(global_attn_scores.dtype).min + global_attn_scores = global_attn_scores.transpose(1, 2) + + global_attn_scores = global_attn_scores.masked_fill( + is_index_masked[:, None, None, :], + torch.finfo(global_attn_scores.dtype).min, + ) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = nn.functional.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + # apply layer head masking + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_heads,), ( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view( + batch_size, self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs_float = global_attn_probs_float.view( + batch_size * self.num_heads, max_num_global_attn_indices, seq_len + ) + + global_attn_probs = nn.functional.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], ( + "global_attn_output tensor has the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" + f" {global_attn_output.size()}." 
+        )
+
+        global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+        global_attn_output = global_attn_output.view(
+            batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
+        )
+        return global_attn_output, global_attn_probs
+
+
+class LEDEncoderAttention(nn.Module):
+    def __init__(self, config, layer_id):
+        super().__init__()
+        self.longformer_self_attn = LEDEncoderSelfAttention(config, layer_id=layer_id)
+        self.output = nn.Linear(config.d_model, config.d_model)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        is_index_masked: Optional[torch.Tensor] = None,
+        is_index_global_attn: Optional[torch.Tensor] = None,
+        is_global_attn: Optional[bool] = None,
+        output_attentions: bool = False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        self_outputs = self.longformer_self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            is_index_masked=is_index_masked,
+            is_index_global_attn=is_index_global_attn,
+            is_global_attn=is_global_attn,
+            output_attentions=output_attentions,
+        )
+
+        attn_output = self.output(self_outputs[0])
+        outputs = (attn_output,) + self_outputs[1:]
+
+        return outputs
+
+
+class LEDDecoderAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: Optional[float] = 0.0,
+        is_decoder: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        if self.head_dim * num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {num_heads})."
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size 
{(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_values + + +class LEDEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LEDConfig, layer_id: int): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LEDEncoderAttention(config, layer_id) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + *(encoder_attention_heads,)*. 
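For orientation, the boolean helpers this layer consumes can be derived from the tri-state encoder attention mask (negative marks padding, zero local attention, positive global attention); a hedged sketch mirroring what the encoder computes before calling its layers:

```python
import torch

attention_mask = torch.tensor([[10000.0, 0.0, 0.0, -10000.0]])  # global, local, local, pad
is_index_masked = attention_mask < 0               # True where the token is padding
is_index_global_attn = attention_mask > 0          # True where the token attends globally
is_global_attn = bool(is_index_global_attn.any())  # any global token in the batch?
```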
+ """ + residual = hidden_states + attn_outputs = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) + hidden_states = attn_outputs[0] + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + return (hidden_states,) + attn_outputs[1:] + + +class LEDDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LEDConfig, layer_idx=None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = LEDDecoderAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = LEDDecoderAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape *(batch, seq_len, embed_dim)* + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. 
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + *(decoder_attention_heads,)*. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size *(decoder_attention_heads,)*. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`): Whether the base model outputs attentions. + This requires the attentions tensor to be reshaped in this function. + """ + residual = hidden_states + + # Self-Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class LEDClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring +class LEDPreTrainedModel(PreTrainedModel): + config: LEDConfig + base_model_prefix = "led" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for LEDEncoder's outputs, with potential hidden states, local and global attentions. + """ +) +# Copied from transformers.models.longformer.modeling_longformer.LongformerBaseModelOutput with Longformer->LEDEncoder +class LEDEncoderBaseModelOutput(ModelOutput): + r""" + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: torch.FloatTensor + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class LEDSeq2SeqModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. 
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for sequence-to-sequence language models outputs. + """ +) +class LEDSeq2SeqLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of sequence-to-sequence sentence classification models. + """ +) +class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of sequence-to-sequence question answering models. + """ +) +class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class LEDEncoder(LEDPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a + [`LEDEncoderLayer`]. + + Args: + config: LEDConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_encoder_position_embeddings + + if isinstance(config.attention_window, int): + if config.attention_window % 2 != 0: + raise ValueError("`config.attention_window` has to be an even value") + if config.attention_window <= 0: + raise ValueError("`config.attention_window` has to be positive") + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + else: + if len(config.attention_window) != config.num_hidden_layers: + raise ValueError( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. 
" + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = LEDLearnedPositionalEmbedding( + self.max_source_positions, + embed_dim, + ) + self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + + def _pad_to_window_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + inputs_embeds: torch.Tensor, + pad_token_id: int, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer self-attention.""" + # padding + attention_window = ( + self.config.attention_window + if isinstance(self.config.attention_window, int) + else max(self.config.attention_window) + ) + + if attention_window % 2 != 0: + raise ValueError(f"`attention_window` should be an even value. Given {attention_window}") + input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (attention_window - seq_len % attention_window) % attention_window + if padding_len > 0: + logger.warning_once( + f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " + f"`config.attention_window`: {attention_window}" + ) + if input_ids is not None: + input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = inputs_embeds.new_full( + (batch_size, padding_len), + self.config.pad_token_id, + dtype=torch.long, + ) + inputs_embeds_padding = self.embed_tokens(input_ids_padding) + inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + + attention_mask = nn.functional.pad( + attention_mask, (0, padding_len), value=False + ) # no attention on the padding tokens + + return padding_len, input_ids, attention_mask, inputs_embeds + + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention for the encoder. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is + important for task-specific finetuning because it makes the model more flexible at representing the + task. For example, for classification, the token should be given global attention. For QA, all + question tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # create default attention_mask + if attention_mask is None: + attention_mask = torch.ones(inputs_embeds.size()[:-1], device=inputs_embeds.device, dtype=torch.long) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + # pad input if necessary + padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + + # retrieve input_shape + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + + # convert attention_mask to float + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, seq_len]; 1 -> 0.0; 0 -> "-inf" + attention_mask = _prepare_4d_attention_mask_inverted(attention_mask, inputs_embeds.dtype)[:, 0, 0, :] + + # get masking tensors + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + is_global_attn = is_index_global_attn.flatten().any().item() + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_global_attentions = () if (output_attentions and is_global_attn) else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = torch.rand([]) + + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) + all_attentions = all_attentions + (layer_outputs[1].transpose(1, 2),) + + if is_global_attn: + # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn + all_global_attentions = all_global_attentions + (layer_outputs[2].transpose(2, 3),) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + # undo padding + if padding_len > 0: + # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) + hidden_states = hidden_states[:, :-padding_len] + if output_hidden_states: + encoder_states = tuple(state[:, :-padding_len] for state in encoder_states) + + if output_attentions: + all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions, all_global_attentions] if v is not None + ) + return LEDEncoderBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + global_attentions=all_global_attentions, + ) + + +class LEDDecoder(LEDPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`LEDDecoderLayer`] + + Args: + config: LEDConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_decoder_position_embeddings + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = LEDLearnedPositionalEmbedding( + self.max_target_positions, + config.d_model, + ) + self.layers = nn.ModuleList([LEDDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + global_attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. Tokens with + global attention attends to all other tokens, and all other tokens attend to them. This is important + for task-specific finetuning because it makes the model more flexible at representing the task. For + example, for classification, the token should be given global attention. For QA, all question + tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _create_4d_causal_attention_mask( + input_shape, inputs_embeds.dtype, inputs_embeds.device, past_key_values_length=past_key_values_length + ) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask_inverted( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_inverted( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + combined_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class LEDModel(LEDPreTrainedModel): + _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + + def __init__(self, config: LEDConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = LEDEncoder(config, self.shared) + self.decoder = LEDDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + global_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], LEDSeq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + + LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify + to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the + default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention for the encoder. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is + important for task-specific finetuning because it makes the model more flexible at representing the task. + For example, for classification, the token should be given global attention. For QA, all question + tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Using this like Bart, as LED is derived from it. So far + # No checkpoint on the hub exists that uses that in practice. 
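+        # When `decoder_input_ids`/`decoder_inputs_embeds` are not given, they are derived from
+        # `input_ids` by shifting one position to the right and prepending `decoder_start_token_id`
+        # (see the Bart reference below).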
+ # https://github.com/huggingface/transformers/blob/ac3cb660cad283163f7c73cad511124e845ca388/src/transformers/models/bart/modeling_bart.py#L1153 + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a LEDEncoderBaseModelOutput when return_dict=False + elif return_dict and not isinstance(encoder_outputs, LEDEncoderBaseModelOutput): + encoder_outputs = LEDEncoderBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + global_attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return LEDSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_global_attentions=encoder_outputs.global_attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LED Model with a language modeling head. Can be used for summarization. 
+ """ +) +class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin): + base_model_prefix = "led" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: LEDConfig): + super().__init__(config) + self.led = LEDModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.led.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.led.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.led.get_encoder() + + def get_decoder(self): + return self.led.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + global_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], LEDSeq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify + to your needs. 
See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the + default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention for the encoder. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is + important for task-specific finetuning because it makes the model more flexible at representing the task. + For example, for classification, the token should be given global attention. For QA, all question + tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example Summarization: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, LEDForConditionalGeneration + + >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384-arxiv") + + >>> ARTICLE_TO_SUMMARIZE = '''Transformers (Vaswani et al., 2017) have achieved state-of-the-art + ... results in a wide range of natural language tasks including generative language modeling + ... (Dai et al., 2019; Radford et al., 2019) and discriminative ... language understanding (Devlin et al., 2019). + ... This success is partly due to the self-attention component which enables the network to capture contextual + ... information from the entire sequence. While powerful, the memory and computational requirements of + ... self-attention grow quadratically with sequence length, making it infeasible (or very expensive) to + ... process long sequences. To address this limitation, we present Longformer, a modified Transformer + ... architecture with a self-attention operation that scales linearly with the sequence length, making it + ... versatile for processing long documents (Fig 1). This is an advantage for natural language tasks such as + ... long document classification, question answering (QA), and coreference resolution, where existing approaches + ... partition or shorten the long context into smaller sequences that fall within the typical 512 token limit + ... of BERT-style pretrained models. Such partitioning could potentially result in loss of important + ... cross-partition information, and to mitigate this problem, existing methods often rely on complex + ... architectures to address such interactions. On the other hand, our proposed Longformer is able to build + ... 
contextual representations of the entire context using multiple layers of attention, reducing the need for
+        ...     task-specific architectures.'''
+        >>> inputs = tokenizer.encode(ARTICLE_TO_SUMMARIZE, return_tensors="pt")
+
+        >>> # Global attention on the first token (cf. Beltagy et al. 2020)
+        >>> global_attention_mask = torch.zeros_like(inputs)
+        >>> global_attention_mask[:, 0] = 1
+
+        >>> # Generate Summary
+        >>> summary_ids = model.generate(inputs, global_attention_mask=global_attention_mask, num_beams=3, max_length=32)
+        >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
+        ```
+
+        Example Conditional generation:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LEDForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
+        >>> TXT = "My friends are <mask> but they eat too many carbs."
+
+        >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
+        >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+
+        >>> prediction = model.generate(input_ids)[0]
+        >>> print(tokenizer.decode(prediction, skip_special_tokens=True))
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.led(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return LEDSeq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+            encoder_global_attentions=outputs.encoder_global_attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+
+@auto_docstring(
+    custom_intro="""
+    LED model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
+    tasks.
+ """ +) +class LEDForSequenceClassification(LEDPreTrainedModel): + _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + + def __init__(self, config: LEDConfig, **kwargs): + warnings.warn( + "The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of" + " Transformers. No actual method were provided in the original paper on how to perform" + " sequence classification.", + FutureWarning, + ) + super().__init__(config, **kwargs) + self.led = LEDModel(config) + self.classification_head = LEDClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + global_attention_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], LEDSeq2SeqSequenceClassifierOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify + to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the + default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention for the encoder. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is + important for task-specific finetuning because it makes the model more flexible at representing the task. + For example, for classification, the token should be given global attention. 
For QA, all question + tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.led( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return LEDSeq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + encoder_global_attentions=outputs.encoder_global_attentions, + ) + + +@auto_docstring +class LEDForQuestionAnswering(LEDPreTrainedModel): + 
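    # Span-extraction head: `qa_outputs` projects the decoder's last hidden states into
+    # start/end logits in `forward` below
+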
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.led = LEDModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + global_attention_mask: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], LEDSeq2SeqQuestionAnsweringModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify + to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the + default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention for the encoder. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is + important for task-specific finetuning because it makes the model more flexible at representing the task. + For example, for classification, the token should be given global attention. For QA, all question + tokens should also have global attention. Please refer to the [Longformer + paper](https://huggingface.co/papers/2004.05150) for more details. 
Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.led( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return LEDSeq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + encoder_global_attentions=outputs.encoder_global_attentions, + ) + + +__all__ = [ + "LEDForConditionalGeneration", + "LEDForQuestionAnswering", + "LEDForSequenceClassification", + "LEDModel", + "LEDPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/modeling_tf_led.py b/venv/lib/python3.13/site-packages/transformers/models/led/modeling_tf_led.py new file mode 100644 index 0000000000000000000000000000000000000000..f499ffac30c9f8324a3d16c1bb4a0ad9638337bc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/modeling_tf_led.py @@ -0,0 +1,2663 @@ +# coding=utf-8 +# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 LED model.""" + +from __future__ import annotations + +import random +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions + +# Public API +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_led import LEDConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "allenai/led-base-16384" +_CONFIG_FOR_DOC = "LEDConfig" + + +LARGE_NEGATIVE = -1e8 + + +# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + start_tokens = tf.fill( + (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) + ) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE + mask_cond = tf.range(shape_list(mask)[-1]) + + mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
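+
+    Illustrative example (editorial note, not in the upstream docstring): a padding mask
+    ``[[1.0, 1.0, 0.0]]`` with ``tgt_len=2`` yields an additive bias of shape ``[1, 1, 2, 3]`` whose
+    first two source columns are ``0.0`` and whose padded column is ``LARGE_NEGATIVE``, so the padded
+    position contributes almost nothing once the bias is added to the attention scores before softmax.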
+ """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFLEDLearnedPositionalEmbedding(keras.layers.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + + def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + """Input is expected to be of size [bsz x seqlen].""" + seq_len = input_shape[1] + position_ids = tf.range(seq_len, delta=1, name="range") + position_ids += past_key_values_length + + return super().call(tf.cast(position_ids, dtype=tf.int32)) + + +# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerSelfAttention with TFLongformer->TFLEDEncoder +class TFLEDEncoderSelfAttention(keras.layers.Layer): + def __init__(self, config, layer_id, **kwargs): + super().__init__(**kwargs) + self.config = config + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads}" + ) + + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + self.query = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query", + ) + self.key = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key", + ) + self.value = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value", + ) + + # separate projection layers for tokens with global attention + self.query_global = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query_global", + ) + self.key_global = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key_global", + ) + self.value_global = keras.layers.Dense( + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value_global", + ) + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.layer_id = layer_id + attention_window = config.attention_window[self.layer_id] + + assert attention_window % 2 == 0, ( + f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}" + ) + assert attention_window > 0, ( + f"`attention_window` for layer {self.layer_id} has to be positive. 
Given {attention_window}" + ) + + self.one_sided_attn_window_size = attention_window // 2 + + def build(self, input_shape=None): + if not self.built: + with tf.name_scope("query_global"): + self.query_global.build((self.config.hidden_size,)) + with tf.name_scope("key_global"): + self.key_global.build((self.config.hidden_size,)) + with tf.name_scope("value_global"): + self.value_global.build((self.config.hidden_size,)) + + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "query_global", None) is not None: + with tf.name_scope(self.query_global.name): + self.query_global.build([None, None, self.config.hidden_size]) + if getattr(self, "key_global", None) is not None: + with tf.name_scope(self.key_global.name): + self.key_global.build([None, None, self.config.hidden_size]) + if getattr(self, "value_global", None) is not None: + with tf.name_scope(self.value_global.name): + self.value_global.build([None, None, self.config.hidden_size]) + + def call( + self, + inputs, + training=False, + ): + """ + LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to + *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer. + + The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to: + + - -10000: no attention + - 0: local attention + - +10000: global attention + """ + # retrieve input args + ( + hidden_states, + attention_mask, + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) = inputs + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + batch_size, seq_len, embed_dim = shape_list(hidden_states) + + tf.debugging.assert_equal( + embed_dim, + self.embed_dim, + message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", + ) + + # normalize query + query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) + query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = attention_mask != 0 + # cast to fp32/fp16 then replace 1's with -inf + float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE + + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + tf.ones(shape_list(attention_mask)), + float_mask, + self.one_sided_attn_window_size, + ) + + # pad local attention probs + attn_scores += diagonal_mask + + tf.debugging.assert_equal( + shape_list(attn_scores), + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + message=( + f"attn_probs should be of size 
({batch_size}, {seq_len}, {self.num_heads}," + f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}" + ), + ) + + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + + # this function is only relevant for global attention + if is_global_attn: + attn_scores = self._concat_with_global_key_attn_probs( + attn_scores=attn_scores, + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + + attn_probs = stable_softmax(attn_scores, axis=-1) + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + if is_global_attn: + masked_index = tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ) + else: + masked_index = tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ) + attn_probs = tf.where( + masked_index, + tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), + attn_probs, + ) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs + + # apply dropout + attn_probs = self.dropout(attn_probs, training=training) + value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + + # if global attention, compute sum of global and local attn + + if is_global_attn: + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + tf.debugging.assert_equal( + shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" + ) + + attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) + + # compute value for global attention and overwrite to attention output + if is_global_attn: + attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( + attn_output=attn_output, + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + 
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + training=training, + ) + else: + # Leave attn_output unchanged + global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) + + # make sure that local attention probabilities are set to 0 for indices of global attn + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + if is_global_attn: + masked_global_attn_index = tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ) + else: + masked_global_attn_index = tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ) + attn_probs = tf.where( + masked_global_attn_index, + tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), + attn_probs, + ) + + outputs = (attn_output, attn_probs, global_attn_probs) + + return outputs + + def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): + """ + Matrix multiplication of query and key tensors using with a sliding window attention pattern. This + implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an + overlap of size window_overlap + """ + batch_size, seq_len, num_heads, head_dim = shape_list(query) + + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", + ) + tf.debugging.assert_equal( + shape_list(query), + shape_list(key), + message=( + f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:" + f" {shape_list(key)}" + ), + ) + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = tf.reshape( + tf.transpose(query, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) + key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) + + # matrix multiplication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) + chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply + + # convert diagonals into columns + paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
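+        # Editorial note (added commentary, not from the upstream file): concretely, with
+        # window_overlap = 2 every query position gets 2 * 2 + 1 = 5 score columns: columns 0 and 1
+        # for the two preceding tokens, column 2 for the token itself, and columns 3 and 4 for the
+        # two following tokens.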
+ + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + # TODO: This code is most likely not very efficient and should be improved + diagonal_attn_scores_up_triang = tf.concat( + [ + diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], + diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], + ], + axis=1, + ) + + # - copying the lower triangle + diagonal_attn_scores_low_triang = tf.concat( + [ + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), + diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], + ], + axis=1, + ) + diagonal_attn_scores_first_chunk = tf.concat( + [ + tf.roll( + diagonal_chunked_attention_scores, + shift=[1, window_overlap], + axis=[2, 3], + )[:, :, :window_overlap, :window_overlap], + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), + ], + axis=1, + ) + first_chunk_mask = ( + tf.tile( + tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None], + (batch_size * num_heads, 1, window_overlap, window_overlap), + ) + < 1 + ) + diagonal_attn_scores_low_triang = tf.where( + first_chunk_mask, + diagonal_attn_scores_first_chunk, + diagonal_attn_scores_low_triang, + ) + + # merging upper and lower triangle + diagonal_attention_scores = tf.concat( + [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1 + ) + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = tf.transpose( + tf.reshape( + diagonal_attention_scores, + (batch_size, num_heads, seq_len, 2 * window_overlap + 1), + ), + (0, 2, 1, 3), + ) + + diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + + return diagonal_attention_scores + + @staticmethod + def _mask_invalid_locations(input_tensor, window_overlap): + # create correct upper triangle bool mask + mask_2d_upper = tf.reverse( + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), + axis=[0], + ) + + # pad to full matrix + padding = tf.convert_to_tensor( + [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] + ) + + # create lower mask + mask_2d = tf.pad(mask_2d_upper, padding) + + # combine with upper mask + mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) + + # broadcast to full matrix + mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) + + # inf tensor used for masking + inf_tensor = -float("inf") * tf.ones_like(input_tensor) + + # mask + input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) + + return input_tensor + + def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): + """ + Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
Returned tensor will be of the + same shape as `attn_probs` + """ + + batch_size, seq_len, num_heads, head_dim = shape_list(value) + + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[:3], + shape_list(value)[:3], + message="value and attn_probs must have same dims (except head_dim)", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[3], + 2 * window_overlap + 1, + message="attn_probs last dim has to be 2 * window_overlap + 1", + ) + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attn_probs = tf.reshape( + tf.transpose(attn_probs, (0, 2, 1, 3)), + ( + batch_size * num_heads, + seq_len // window_overlap, + window_overlap, + 2 * window_overlap + 1, + ), + ) + + # group batch_size and num_heads dimensions into one + value = tf.reshape( + tf.transpose(value, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) + padded_value = tf.pad(value, paddings, constant_values=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + frame_size = 3 * window_overlap * head_dim + frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count + chunked_value = tf.signal.frame( + tf.reshape(padded_value, (batch_size * num_heads, -1)), + frame_size, + frame_hop_size, + ) + chunked_value = tf.reshape( + chunked_value, + (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), + ) + + tf.debugging.assert_equal( + shape_list(chunked_value), + [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], + message="Chunked value has the wrong shape", + ) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) + context = tf.transpose( + tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), + (0, 2, 1, 3), + ) + + return context + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): + """pads rows and then flips rows and columns""" + hidden_states_padded = tf.pad( + hidden_states_padded, paddings + ) # padding value is not important because it will be overwritten + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """ + shift every row 1 step right, converting columns into diagonals. 
+ + Example: + + ```python + chunked_hidden_states: [ + 0.4983, + 2.6918, + -0.0071, + 1.0492, + -1.8348, + 0.7672, + 0.2986, + 0.0285, + -0.7584, + 0.4206, + -0.0405, + 0.1599, + 2.0514, + -1.1600, + 0.5372, + 0.2629, + ] + window_overlap = num_rows = 4 + ``` + + (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206, + -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ + total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) + paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) + chunked_hidden_states = tf.pad( + chunked_hidden_states, paddings + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + chunked_hidden_states = tf.reshape( + chunked_hidden_states, (total_num_heads, num_chunks, -1) + ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap + chunked_hidden_states = tf.reshape( + chunked_hidden_states, + (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), + ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + + return chunked_hidden_states + + @staticmethod + def _chunk(hidden_states, window_overlap): + """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" + batch_size, seq_length, hidden_dim = shape_list(hidden_states) + num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1 + + # define frame size and frame stride (similar to convolution) + frame_hop_size = window_overlap * hidden_dim + frame_size = 2 * frame_hop_size + hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) + + # chunk with overlap + chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) + + tf.debugging.assert_equal( + shape_list(chunked_hidden_states), + [batch_size, num_output_chunks, frame_size], + message=( + "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension" + f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}." 
+ ), + ) + + chunked_hidden_states = tf.reshape( + chunked_hidden_states, + (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), + ) + + return chunked_hidden_states + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """compute global attn indices required throughout forward pass""" + # helper variable + num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) + num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) + + # max number of global attn indices in batch + max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) + + # indices of global attn + is_index_global_attn_nonzero = tf.where(is_index_global_attn) + + # helper variable + is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( + num_global_attn_indices, axis=-1 + ) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) + + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + attn_scores, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = shape_list(key_vectors)[0] + + # select global key vectors + global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) + + # create only global key vectors + key_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_key_vectors, + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) + + # (batch_size, max_num_global_attn_indices, seq_len, num_heads) + attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(attn_probs_from_global_key_trans)[-2:] + ) + mask = tf.ones(mask_shape) * -10000.0 + mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) + + # scatter mask + attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( + attn_probs_from_global_key_trans, + is_local_index_no_global_attn_nonzero, + mask, + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) + + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1) + + return attn_scores + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = shape_list(attn_probs)[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices] + + # select global value vectors + global_value_vectors = tf.gather_nd(value_vectors, 
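+            # Editorial note: `tf.gather_nd` picks the value vectors at the positions flagged as
+            # global; they are scattered below into a dense (batch, max_num_global_attn_indices,
+            # num_heads, head_dim) block so a single einsum can attend to them.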
is_index_global_attn_nonzero) + + # create only global value vectors + value_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_value_vectors, + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), + ) + + # compute attn output only global + attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) + + # reshape attn probs + attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + attn_output, + hidden_states, + max_num_global_attn_indices, + layer_head_mask, + is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + training, + ): + batch_size, seq_len = shape_list(hidden_states)[:2] + + # prepare global hidden states + global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) + global_attn_hidden_states = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_attn_hidden_states, + shape=(batch_size, max_num_global_attn_indices, self.embed_dim), + ) + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= tf.math.sqrt( + tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) + ) + global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) + global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) + global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) + + # compute attn scores + global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(global_attn_scores), + [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], + message=( + "global_attn_scores have the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is" + f" {shape_list(global_attn_scores)}." 
+ ), + ) + + global_attn_scores = tf.reshape( + global_attn_scores, + (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), + ) + global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(global_attn_scores_trans)[-2:] + ) + global_attn_mask = tf.ones(mask_shape) * -10000.0 + global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) + + # scatter mask + global_attn_scores_trans = tf.tensor_scatter_nd_update( + global_attn_scores_trans, + is_local_index_no_global_attn_nonzero, + global_attn_mask, + ) + global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) + + # mask global attn scores + attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) + global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) + global_attn_scores = tf.reshape( + global_attn_scores, + (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), + ) + + # compute global attn probs + global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1) + + # apply layer head masking + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + global_attn_probs_float = tf.reshape( + global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + ) + + # dropout + global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) + + # global attn output + global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) + + tf.debugging.assert_equal( + shape_list(global_attn_output), + [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], + message=( + "global_attn_output tensor has the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" + f" {shape_list(global_attn_output)}." 
+ ), + ) + + global_attn_output = tf.reshape( + global_attn_output, + (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), + ) + + # get only non zero global attn output + nonzero_global_attn_output = tf.gather_nd( + tf.transpose(global_attn_output, (0, 2, 1, 3)), + is_local_index_global_attn_nonzero, + ) + nonzero_global_attn_output = tf.reshape( + nonzero_global_attn_output, + (shape_list(is_local_index_global_attn_nonzero)[0], -1), + ) + + # overwrite values with global attention + attn_output = tf.tensor_scatter_nd_update( + attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output + ) + + global_attn_probs = tf.reshape( + global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + + return attn_output, global_attn_probs + + def reshape_and_transpose(self, vector, batch_size): + return tf.reshape( + tf.transpose( + tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), + (0, 2, 1, 3), + ), + (batch_size * self.num_heads, -1, self.head_dim), + ) + + +class TFLEDEncoderAttention(keras.layers.Layer): + def __init__(self, config, layer_id, **kwargs): + super().__init__(**kwargs) + self.longformer_self_attn = TFLEDEncoderSelfAttention(config, layer_id=layer_id, name="longformer_self_attn") + self.output_dense = keras.layers.Dense(config.d_model, use_bias=True, name="output") + self.config = config + + def call(self, inputs, training=False): + ( + hidden_states, + attention_mask, + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) = inputs + + self_outputs = self.longformer_self_attn( + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], + training=training, + ) + + attention_output = self.output_dense(self_outputs[0], training=training) + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "longformer_self_attn", None) is not None: + with tf.name_scope(self.longformer_self_attn.name): + self.longformer_self_attn.build(None) + if getattr(self, "output_dense", None) is not None: + with tf.name_scope(self.output_dense.name): + self.output_dense.build([None, None, self.config.d_model]) + + +class TFLEDDecoderAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + 
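+        # Editorial note: when provided, the cached self-attention tensors in `past_key_value` are
+        # expected in shape (batch, num_heads, past_seq_len, head_dim), matching `_shape` below.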
attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training=False, + ) -> tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast( + attention_mask, dtype=attn_weights.dtype + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, 
(bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +class TFLEDEncoderLayer(keras.layers.Layer): + def __init__(self, config: LEDConfig, layer_id: int, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFLEDEncoderAttention(config, layer_id, name="self_attn") + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + layer_head_mask: tf.Tensor, + is_index_masked: tf.Tensor, + is_index_global_attn: tf.Tensor, + is_global_attn: bool, + training=False, + ): + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(config.encoder_attention_heads,)*. 
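+
+        Editorial note (not in the upstream docstring): the layer uses a post-layer-norm residual
+        layout: Longformer-style self-attention, then add & normalize, then a two-layer feed-forward
+        block (`fc1`, activation, `fc2`), then add & normalize again, so the output keeps the input
+        shape *(batch, seq_len, embed_dim)*.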
+ """ + residual = hidden_states + layer_outputs = self.self_attn( + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], + training=training, + ) + + hidden_states = layer_outputs[0] + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return (hidden_states,) + layer_outputs[1:] + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFLEDDecoderLayer(keras.layers.Layer): + def __init__(self, config: LEDConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFLEDDecoderAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFLEDDecoderAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states, + attention_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + encoder_layer_head_mask: tf.Tensor | None = None, + past_key_value: tuple[tf.Tensor] | None = None, + training=False, + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, 
embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape *(batch, seq_len, embed_dim)* + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(config.encoder_attention_heads,)*. + encoder_layer_head_mask (`tf.Tensor`): mask for encoder attention heads in a given layer of + size *(config.encoder_attention_heads,)*. + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + + # Self-Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with 
tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFLEDPreTrainedModel(TFPreTrainedModel): + config_class = LEDConfig + base_model_prefix = "led" + + @property + def input_signature(self): + sig = super().input_signature + sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask") + return sig + + +@dataclass +# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder +class TFLEDEncoderBaseModelOutput(ModelOutput): + """ + Base class for Longformer's outputs, with potential hidden states, local and global attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. 
+ + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLEDSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor, ...] | None = None + decoder_attentions: tuple[tf.Tensor, ...] | None = None + cross_attentions: tuple[tf.Tensor, ...] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor, ...] | None = None + encoder_attentions: tuple[tf.Tensor, ...] | None = None + encoder_global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLEDSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language model outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor, ...] | None = None + decoder_attentions: tuple[tf.Tensor, ...] | None = None + cross_attentions: tuple[tf.Tensor, ...] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor, ...] | None = None + encoder_attentions: tuple[tf.Tensor, ...] | None = None + encoder_global_attentions: tuple[tf.Tensor, ...] | None = None + + +LED_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! 
If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`LEDConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +LED_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`LEDTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + A default mask that ignores pad tokens will be made automatically. It is not recommended to set this for + most use cases. + head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tf.Tensor`, *optional*): + Sequence of hidden states at the output of the last layer of the encoder, of shape 
+ `(batch_size, sequence_length, hidden_size)`. Used in the cross-attention of the decoder. + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFLEDEncoder(keras.layers.Layer): + config_class = LEDConfig + """ + Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a + [`TFLEDEncoderLayer`]. + + Args: + config: LEDConfig + """ + + def __init__(self, config: LEDConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.dropout = keras.layers.Dropout(config.dropout) + if config.encoder_layerdrop > 0: + logger.warning("Layerdrop is currently disabled in TFLED models.") + self.layerdrop = 0.0 + self.padding_idx = config.pad_token_id + + if isinstance(config.attention_window, int): + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + else: + assert len(config.attention_window) == config.num_hidden_layers, ( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. 
" + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + self.attention_window = config.attention_window + self.embed_tokens = embed_tokens + self.embed_positions = TFLEDLearnedPositionalEmbedding( + config.max_encoder_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)] + self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") + self.embed_dim = config.d_model + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + global_attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + """ + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype) + + padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + pad_token_id=self.padding_idx, + ) + + input_shape = shape_list(attention_mask) + # is index masked or global attention + is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1) + is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + embed_pos = self.embed_positions(input_shape) + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + # check attention mask and invert + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask)[:, 0, 0, :] + attention_mask = attention_mask[:, :, None, None] + + encoder_states = () if output_hidden_states else None + all_attentions = all_global_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}." 
+ ), + ) + + # encoder layers + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len) + encoder_states = encoder_states + (hidden_states_to_add,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + layer_outputs = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) + all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) + + # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn + all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),) + + # undo padding + # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) + hidden_states = self.compute_hidden_states(hidden_states, padding_len) + + # undo padding + if output_attentions: + all_attentions = ( + tuple(state[:, :, :-padding_len, :] for state in all_attentions) if padding_len > 0 else all_attentions + ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFLEDEncoderBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + global_attentions=all_global_attentions, + ) + + @tf.function + def compute_hidden_states(self, hidden_states, padding_len): + return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + + def _pad_to_window_size( + self, + input_ids, + attention_mask, + inputs_embeds, + pad_token_id, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) + ) + + assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" + + input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) + batch_size, seq_len = input_shape[:2] + padding_len = (attention_window - seq_len % attention_window) % attention_window + + if padding_len > 0: + logger.warning_once( + f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " + f"`config.attention_window`: {attention_window}" + ) + + paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) + + if input_ids is not None: + input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) + + if inputs_embeds is not None: + if padding_len > 0: + input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id) + inputs_embeds_padding = self.embed_tokens(input_ids_padding) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + + attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens + + return ( + padding_len, + input_ids, + attention_mask, + inputs_embeds, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layernorm_embedding", None) is not None: + with tf.name_scope(self.layernorm_embedding.name): + self.layernorm_embedding.build([None, None, self.embed_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFLEDDecoder(keras.layers.Layer): + config_class = LEDConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFLEDDecoderLayer`] + + Args: + config: LEDConfig + embed_tokens: output embedding + """ + + def __init__(self, config: LEDConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.embed_tokens = embed_tokens + if config.decoder_layerdrop > 0: + logger.warning("Layerdrop is currently disabled in TFLED models.") + self.layerdrop = 0.0 + self.embed_positions = TFLEDLearnedPositionalEmbedding( + config.max_decoder_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") + + self.dropout = keras.layers.Dropout(config.dropout) + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
[What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. If `past_key_values` are used, the user can optionally input only the last + `decoder_input_ids` (those that don't have their past key value states given to this model) of shape + `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) + + if attention_mask is not None and input_shape[-1] > 1: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) + + hidden_states = self.layernorm_embedding(hidden_states + positions) + hidden_states = self.dropout(hidden_states, training=training) + + # decoder layers + all_hidden_states = () + all_self_attns = () + all_cross_attentions = () + present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}." 
+ ), + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + all_cross_attentions += (layer_cross_attn,) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + else: + all_hidden_states = None + + all_self_attns = all_self_attns if output_attentions else None + all_cross_attentions = all_cross_attentions if output_attentions else None + + present_key_values = present_key_values if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layernorm_embedding", None) is not None: + with tf.name_scope(self.layernorm_embedding.name): + self.layernorm_embedding.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFLEDMainLayer(keras.layers.Layer): + config_class = LEDConfig + + def __init__(self, config: LEDConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), + name="led.shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "led.shared" + + self.encoder = TFLEDEncoder(config, self.shared, name="encoder") + self.decoder = TFLEDDecoder(config, self.shared, name="decoder") + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs: tuple | TFLEDEncoderBaseModelOutput | None = None, + global_attention_mask=None, + past_key_values=None, + inputs_embeds=None, + 
decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput): + encoder_outputs = TFLEDEncoderBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFLEDSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_global_attentions=encoder_outputs.global_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
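+ # (Editor's note, illustrative; grounded in the two comments above.) The string passed below is
+ # "led.shared/led.shared/": because it ends with "/", tf.name_scope applies it from the root
+ # namespace, so the shared embedding's weights get the same names whether TFLEDModel or
+ # TFLEDForConditionalGeneration builds them, keeping checkpoint keys stable across both wrappers.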
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare LED Model outputting raw hidden-states without any specific head on top.", + LED_START_DOCSTRING, +) +class TFLEDModel(TFLEDPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.led = TFLEDMainLayer(config, name="led") + + def get_encoder(self): + return self.led.encoder + + def get_decoder(self): + return self.led.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFLEDSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + decoder_head_mask: tf.Tensor | None = None, + encoder_outputs: tf.Tensor | None = None, + global_attention_mask: tf.Tensor | None = None, + past_key_values: tuple[tuple[tf.Tensor]] | None = None, + inputs_embeds: tf.Tensor | None = None, + decoder_inputs_embeds: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> tuple[tf.Tensor] | TFLEDSeq2SeqModelOutput: + outputs = self.led( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None + + return TFLEDSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + encoder_global_attentions=enc_g_attns, + ) + + 
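+ # (Editor's note) Minimal usage sketch, not part of the original file. The checkpoint name and the
+ # "global attention on the first token" choice are assumptions here, following the usual LED recipe:
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
+ #     model = TFLEDModel.from_pretrained("allenai/led-base-16384")
+ #     inputs = tokenizer("A very long document ...", return_tensors="tf")
+ #     # 0 = local (windowed) attention, 1 = global attention; give the first token global attention
+ #     global_attention_mask = tf.concat(
+ #         [tf.ones_like(inputs["input_ids"][:, :1]), tf.zeros_like(inputs["input_ids"][:, 1:])], axis=-1
+ #     )
+ #     outputs = model(
+ #         inputs["input_ids"],
+ #         attention_mask=inputs["attention_mask"],
+ #         decoder_input_ids=inputs["input_ids"][:, :1],  # the bare model does not create decoder inputs itself
+ #         global_attention_mask=global_attention_mask,
+ #     )
+ #     last_hidden_state = outputs.last_hidden_state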
def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "led", None) is not None: + with tf.name_scope(self.led.name): + self.led.build(None) + + +# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer +class BiasLayer(keras.layers.Layer): + """ + Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, + so all weights have to be registered in a layer. + """ + + def __init__(self, shape, initializer, trainable, name, **kwargs): + super().__init__(name=name, **kwargs) + # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of + # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: + # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 + self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) + + def call(self, x): + return x + self.bias + + +@add_start_docstrings( + "The LED Model with a language modeling head. Can be used for summarization.", + LED_START_DOCSTRING, +) +class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + r"led.encoder.embed_tokens.weight", + r"led.decoder.embed_tokens.weight", + ] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.led = TFLEDMainLayer(config, name="led") + self.use_cache = config.use_cache + # `final_logits_bias` is registered as a buffer in pytorch, so not trainable for the sake of consistency. + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False + ) + + # TODO (Joao): investigate why LED has numerical issues in XLA generate + self.supports_xla_generation = False + + def get_decoder(self): + return self.led.decoder + + def get_encoder(self): + return self.led.encoder + + def get_bias(self): + return {"final_logits_bias": self.bias_layer.bias} + + def set_bias(self, value): + # Replaces the existing layers containing bias for correct (de)serialization. 
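+ # (Editor's note) A fresh BiasLayer is built rather than assigning into the existing variable
+ # because the incoming bias may have a different vocab size, and a tf.Variable's shape is fixed
+ # once it has been created; the new values are then copied in below.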
vocab_size = value["final_logits_bias"].shape[-1] + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False + ) + self.bias_layer.bias.assign(value["final_logits_bias"]) + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + @unpack_inputs + @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: TFLEDEncoderBaseModelOutput | None = None, + global_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> tuple[tf.Tensor] | TFLEDSeq2SeqLMOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TFLEDForConditionalGeneration + >>> import tensorflow as tf + + >>> mname = "allenai/led-base-16384" + >>> tokenizer = AutoTokenizer.from_pretrained(mname) + >>> TXT = "My friends are <mask> but they eat too many carbs." 
+ >>> model = TFLEDForConditionalGeneration.from_pretrained(mname) + >>> batch = tokenizer([TXT], return_tensors="tf") + >>> logits = model(inputs=batch.input_ids).logits + >>> probs = tf.nn.softmax(logits[0]) + >>> # probs[5] is associated with the mask token + ```""" + + if labels is not None: + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.led( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + lm_logits = tf.matmul(outputs[0], self.led.shared.weights, transpose_b=True) + lm_logits = self.bias_layer(lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return TFLEDSeq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, # index 1 of d outputs + decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs + decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs + encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs + encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out + encoder_attentions=outputs.encoder_attentions, # 2 of e out + encoder_global_attentions=outputs.encoder_global_attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None + + return TFLEDSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + encoder_global_attentions=enc_g_attns, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, 
# encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def hf_compute_loss(self, labels, logits): + """CrossEntropyLoss that ignores pad tokens""" + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + melted_labels = tf.reshape(labels, (-1,)) + active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) + reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) + labels = tf.boolean_mask(melted_labels, active_loss) + return loss_fn(labels, reduced_logits) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_loss = loss_fn(tf.nn.relu(labels), logits) + # make sure only non-padding labels affect the loss + loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * loss_mask + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) + return tf.reshape(reduced_masked_loss, (1,)) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "led", None) is not None: + with tf.name_scope(self.led.name): + self.led.build(None) + if getattr(self, "bias_layer", None) is not None: + with tf.name_scope(self.bias_layer.name): + self.bias_layer.build(None) + + +__all__ = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led.py b/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led.py new file mode 100644 index 0000000000000000000000000000000000000000..d110ac30d969e1958a872e8bdc186df0822a25b5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
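+# (Editor's note, illustrative sketch; not part of the original file.) The tokenizer below implements
+# byte-level BPE: raw text is first re-encoded through the bytes_to_unicode() table so every UTF-8 byte
+# becomes a printable character, the BPE merges then operate on those characters, and byte_decoder
+# inverts the mapping at decode time. Roughly:
+#
+#     byte_encoder = bytes_to_unicode()
+#     token = "".join(byte_encoder[b] for b in "café".encode("utf-8"))  # "cafÃ©"
+#     # ... BPE merges split/join `token` according to the ranks loaded from merges.txt ...
+#     byte_decoder = {v: k for k, v in byte_encoder.items()}
+#     assert bytearray(byte_decoder[c] for c in token).decode("utf-8") == "café"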
+"""Tokenization classes for LED.""" + +import json +import os +from functools import lru_cache +from typing import Optional, Union + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_base import BatchEncoding, EncodedInput +from ...utils import PaddingStrategy, logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} + +# See all LED models at https://huggingface.co/models?filter=LED + + +@lru_cache +# Copied from transformers.models.bart.tokenization_bart.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.bart.tokenization_bart.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class LEDTokenizer(PreTrainedTokenizer): + """ + Constructs a LED tokenizer, which is smilar to the ROBERTa tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import LEDTokenizer + + >>> tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 
+ + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (BART tokenizer detect beginning of words by the preceding space). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.__init__ + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.vocab_size + def vocab_size(self): + return len(self.encoder) + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.get_vocab + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.bpe + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._tokenize + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.convert_tokens_to_string + def 
convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.build_inputs_with_special_tokens with BART->LED + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A LED sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.create_token_type_ids_from_sequences with BART->LED + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. LED does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.prepare_for_tokenization + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + encoded_inputs = super()._pad( + encoded_inputs=encoded_inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_attention_mask and "global_attention_mask" in encoded_inputs: + required_input = encoded_inputs[self.model_input_names[0]] + # `global_attention_mask` need to have the same length as other (sequential) inputs. 
+ needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input) + + if needs_to_be_padded: + difference = len(required_input) - len(encoded_inputs["global_attention_mask"]) + + if self.padding_side == "right": + # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend` + encoded_inputs["global_attention_mask"] = ( + encoded_inputs["global_attention_mask"] + [-1] * difference + ) + elif self.padding_side == "left": + encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[ + "global_attention_mask" + ] + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs + + +__all__ = ["LEDTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led_fast.py b/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..baea10f23516fef914d70fef87bd3d17bce19baa --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/led/tokenization_led_fast.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for LED.""" + +import json +from typing import Optional, Union + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, logging +from .tokenization_led import LEDTokenizer + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class LEDTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" LED tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer, + using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import LEDTokenizerFast + + >>> tokenizer = LEDTokenizerFast.from_pretrained("allenai/led-base-16384") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. 
Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (LED tokenizer detect beginning of words by the preceding space). + trim_offsets (`bool`, *optional*, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = LEDTokenizer + model_input_names = ["input_ids", "attention_mask"] + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.__init__ + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + trim_offsets=True, + **kwargs, + ): + # we have to specify that this tokens is special otherwise adding it will reset the normalized flag to `False` in `add_special_tokens` + mask_token = ( + AddedToken(mask_token, lstrip=True, normalized=True, special=True) + if isinstance(mask_token, str) + else mask_token + ) + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + **kwargs, + ) + + # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + @property + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.mask_token with BART->LED + def mask_token(self) -> str: + """ + `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not + having been set. + + LED tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily + comprise the space before the **. + """ + if self._mask_token is None: + if self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on LED. + """ + # Mask token behave like a normal word, i.e. 
include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._batch_encode_plus + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + if is_split_into_words and not self.add_prefix_space: + raise ValueError( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._batch_encode_plus(*args, **kwargs) + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._encode_plus + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + if is_split_into_words and not self.add_prefix_space: + raise ValueError( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._encode_plus(*args, **kwargs) + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.create_token_type_ids_from_sequences with BART->LED + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. LED does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.led.tokenization_led.LEDTokenizer._pad + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + encoded_inputs = super()._pad( + encoded_inputs=encoded_inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_attention_mask and "global_attention_mask" in encoded_inputs: + required_input = encoded_inputs[self.model_input_names[0]] + # `global_attention_mask` need to have the same length as other (sequential) inputs. + needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input) + + if needs_to_be_padded: + difference = len(required_input) - len(encoded_inputs["global_attention_mask"]) + + if self.padding_side == "right": + # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend` + encoded_inputs["global_attention_mask"] = ( + encoded_inputs["global_attention_mask"] + [-1] * difference + ) + elif self.padding_side == "left": + encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[ + "global_attention_mask" + ] + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs + + +__all__ = ["LEDTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/levit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ae097b66a7f23a97cc24eda5ae80051bcd475c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_levit import * + from .feature_extraction_levit import * + from .image_processing_levit import * + from .image_processing_levit_fast import * + from .modeling_levit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/configuration_levit.py b/venv/lib/python3.13/site-packages/transformers/models/levit/configuration_levit.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d63ed8e37177e0779fe3d22b027a69f1f80838 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/configuration_levit.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LeViT model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class LevitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LevitModel`]. It is used to instantiate a LeViT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LeViT + [facebook/levit-128S](https://huggingface.co/facebook/levit-128S) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size of the input image. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the initial convolution layers of patch embedding. + stride (`int`, *optional*, defaults to 2): + The stride size for the initial convolution layers of patch embedding. + padding (`int`, *optional*, defaults to 1): + The padding size for the initial convolution layers of patch embedding. + patch_size (`int`, *optional*, defaults to 16): + The patch size for embeddings. + hidden_sizes (`list[int]`, *optional*, defaults to `[128, 256, 384]`): + Dimension of each of the encoder blocks. + num_attention_heads (`list[int]`, *optional*, defaults to `[4, 8, 12]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + depths (`list[int]`, *optional*, defaults to `[4, 4, 4]`): + The number of layers in each encoder block. 
+ key_dim (`list[int]`, *optional*, defaults to `[16, 16, 16]`): + The size of key in each of the encoder blocks. + drop_path_rate (`int`, *optional*, defaults to 0): + The dropout probability for stochastic depths, used in the blocks of the Transformer encoder. + mlp_ratios (`list[int]`, *optional*, defaults to `[2, 2, 2]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + attention_ratios (`list[int]`, *optional*, defaults to `[2, 2, 2]`): + Ratio of the size of the output dimension compared to input dimension of attention layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import LevitConfig, LevitModel + + >>> # Initializing a LeViT levit-128S style configuration + >>> configuration = LevitConfig() + + >>> # Initializing a model (with random weights) from the levit-128S style configuration + >>> model = LevitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "levit" + + def __init__( + self, + image_size=224, + num_channels=3, + kernel_size=3, + stride=2, + padding=1, + patch_size=16, + hidden_sizes=[128, 256, 384], + num_attention_heads=[4, 8, 12], + depths=[4, 4, 4], + key_dim=[16, 16, 16], + drop_path_rate=0, + mlp_ratio=[2, 2, 2], + attention_ratio=[2, 2, 2], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.image_size = image_size + self.num_channels = num_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.hidden_sizes = hidden_sizes + self.num_attention_heads = num_attention_heads + self.depths = depths + self.key_dim = key_dim + self.drop_path_rate = drop_path_rate + self.patch_size = patch_size + self.attention_ratio = attention_ratio + self.mlp_ratio = mlp_ratio + self.initializer_range = initializer_range + self.down_ops = [ + ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2], + ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2], + ] + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class LevitOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["LevitConfig", "LevitOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/feature_extraction_levit.py b/venv/lib/python3.13/site-packages/transformers/models/levit/feature_extraction_levit.py new file mode 100644 index 0000000000000000000000000000000000000000..d634239b24500fd279d41101188937548df4bfa2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/feature_extraction_levit.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
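`LevitOnnxConfig` above only declares the exported graph inputs and a validation tolerance, and `LevitConfig` derives its two `down_ops` subsampling stages from `key_dim` and `hidden_sizes`. A quick inspection sketch, assuming this vendored copy of `transformers` is importable:

```python
from transformers.models.levit.configuration_levit import LevitConfig, LevitOnnxConfig

config = LevitConfig()  # levit-128S style defaults
onnx_config = LevitOnnxConfig(config)

print(onnx_config.inputs)               # {"pixel_values": {0: "batch", 1: "num_channels", 2: "height", 3: "width"}}
print(onnx_config.atol_for_validation)  # 1e-4
print(config.down_ops)                  # the two "Subsample" stages derived in __init__ above
```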
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for LeViT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_levit import LevitImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class LevitFeatureExtractor(LevitImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class LevitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use LevitImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["LevitFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit.py b/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf03b39e4b9d97440c5e284e16b13e026f4bd25 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit.py @@ -0,0 +1,312 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for LeViT.""" + +from collections.abc import Iterable +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class LevitImageProcessor(BaseImageProcessor): + r""" + Constructs a LeViT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Wwhether to resize the shortest edge of the input to int(256/224 *`size`). Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`): + Size of the output image after resizing. If size is a dict with keys "width" and "height", the image will + be resized to `(size["height"], size["width"])`. 
If size is a dict with key "shortest_edge", the shortest + edge value `c` is rescaled to `int(c * (256/224))`. The smaller edge of the image will be matched to this + value i.e, if height > width, then image will be rescaled to `(size["shortest_edge"] * height / width, + size["shortest_edge"])`. Can be overridden by the `size` parameter in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether or not to center crop the input to `(crop_size["height"], crop_size["width"])`. Can be overridden + by the `do_center_crop` parameter in the `preprocess` method. + crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired image size after `center_crop`. Can be overridden by the `crop_size` parameter in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + image_mean (`list[int]`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`list[int]`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN, + image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. + + If size is a dict with keys "width" and "height", the image will be resized to `(size["height"], + size["width"])`. + + If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`. + The smaller edge of the image will be matched to this value i.e, if height > width, then image will be rescaled + to `(size["shortest_edge"] * height / width, size["shortest_edge"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image after resizing. If size is a dict with keys "width" and "height", the image + will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value + `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value + i.e, if height > width, then image will be rescaled to (size * height / width, size). + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size_dict = get_size_dict(size, default_to_square=False) + # size_dict is a dict with either keys "height" and "width" or "shortest_edge" + if "shortest_edge" in size: + shortest_edge = int((256 / 224) * size["shortest_edge"]) + output_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format + ) + size_dict = {"height": output_size[0], "width": output_size[1]} + if "height" not in size_dict or "width" not in size_dict: + raise ValueError( + f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. 
Got {size_dict.keys()}" + ) + return resize( + image, + size=(size_dict["height"], size_dict["width"]), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, Iterable[float]]] = None, + image_std: Optional[Union[float, Iterable[float]]] = None, + return_tensors: Optional[TensorType] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocess an image or batch of images to be used as input to a LeViT model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after resizing. If size is a dict with keys "width" and "height", the image + will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value + `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value + i.e, if height > width, then image will be rescaled to (size * height / width, size). + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the output image after center cropping. Crops images to (crop_size["height"], + crop_size["width"]). + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values by `rescaling_factor` - typical to values between 0 and 1. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Factor to rescale the image pixel values by. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image pixel values by `image_mean` and `image_std`. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to normalize the image pixel values by. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to normalize the image pixel values by. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [self.resize(image, size, resample, input_data_format=input_data_format) for image in images] + + if do_center_crop: + images = [self.center_crop(image, crop_size, input_data_format=input_data_format) for image in images] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["LevitImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..ae30194288fa173067c344f6856eb604aa630251 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/image_processing_levit_fast.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for LeViT.""" + +from typing import Optional + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_transforms import ( + ChannelDimension, + get_resize_output_image_size, +) +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from ...utils import auto_docstring + + +@auto_docstring +class LevitImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = None + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize an image. + + If size is a dict with keys "width" and "height", the image will be resized to `(size["height"], + size["width"])`. + + If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`. + The smaller edge of the image will be matched to this value i.e, if height > width, then image will be rescaled + to `(size["shortest_edge"] * height / width, size["shortest_edge"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the output image after resizing. 
If size is a dict with keys "width" and "height", the image + will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value + `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value + i.e, if height > width, then image will be rescaled to (size * height / width, size). + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BICUBIC`): + Resampling filter to use when resiizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BICUBIC + if size.shortest_edge: + shortest_edge = int((256 / 224) * size["shortest_edge"]) + new_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST + ) + elif size.height and size.width: + new_size = (size.height, size.width) + else: + raise ValueError( + f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size.keys()} {size.keys()}." + ) + return F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + + +__all__ = ["LevitImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/levit/modeling_levit.py b/venv/lib/python3.13/site-packages/transformers/models/levit/modeling_levit.py new file mode 100644 index 0000000000000000000000000000000000000000..3deca07e24009f3ed13b8afdc50c8edc0d8cf9ca --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/levit/modeling_levit.py @@ -0,0 +1,657 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LeViT model.""" + +import itertools +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, + ModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from .configuration_levit import LevitConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`LevitForImageClassificationWithTeacher`]. + """ +) +class LevitForImageClassificationWithTeacherOutput(ModelOutput): + r""" + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores as the average of the `cls_logits` and `distillation_logits`. + cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). 
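Taken together, the slow and fast LeViT image processors above implement the same pipeline: resize the shortest edge to `int(256/224 * shortest_edge)`, center-crop to 224x224, rescale by 1/255, and normalize with ImageNet statistics. A small usage sketch, assuming Pillow is installed and using a random image in place of real data:

```python
import numpy as np
from PIL import Image

from transformers import LevitImageProcessor

processor = LevitImageProcessor()  # defaults: shortest_edge=224 (resized to 256), then a 224x224 crop

image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
batch = processor(image, return_tensors="np")

# One image, three channels, center-cropped to the default 224x224.
print(batch["pixel_values"].shape)  # (1, 3, 224, 224)
```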
+ """ + + logits: Optional[torch.FloatTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + distillation_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +class LevitConvEmbeddings(nn.Module): + """ + LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1 + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False + ) + self.batch_norm = nn.BatchNorm2d(out_channels) + + def forward(self, embeddings): + embeddings = self.convolution(embeddings) + embeddings = self.batch_norm(embeddings) + return embeddings + + +class LevitPatchEmbeddings(nn.Module): + """ + LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple + `LevitConvEmbeddings`. + """ + + def __init__(self, config): + super().__init__() + self.embedding_layer_1 = LevitConvEmbeddings( + config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding + ) + self.activation_layer_1 = nn.Hardswish() + + self.embedding_layer_2 = LevitConvEmbeddings( + config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding + ) + self.activation_layer_2 = nn.Hardswish() + + self.embedding_layer_3 = LevitConvEmbeddings( + config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding + ) + self.activation_layer_3 = nn.Hardswish() + + self.embedding_layer_4 = LevitConvEmbeddings( + config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding + ) + self.num_channels = config.num_channels + + def forward(self, pixel_values): + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) + embeddings = self.embedding_layer_1(pixel_values) + embeddings = self.activation_layer_1(embeddings) + embeddings = self.embedding_layer_2(embeddings) + embeddings = self.activation_layer_2(embeddings) + embeddings = self.embedding_layer_3(embeddings) + embeddings = self.activation_layer_3(embeddings) + embeddings = self.embedding_layer_4(embeddings) + return embeddings.flatten(2).transpose(1, 2) + + +class MLPLayerWithBN(nn.Module): + def __init__(self, input_dim, output_dim, bn_weight_init=1): + super().__init__() + self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False) + self.batch_norm = nn.BatchNorm1d(output_dim) + + def forward(self, hidden_state): + hidden_state = self.linear(hidden_state) + hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state) + return hidden_state + + +class LevitSubsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, hidden_state): + batch_size, _, channels = hidden_state.shape + hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[ + :, :: self.stride, :: self.stride + ].reshape(batch_size, -1, channels) + return hidden_state + + +class LevitAttention(nn.Module): + def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution): + super().__init__() + self.num_attention_heads = num_attention_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2 + self.out_dim_projection = attention_ratio * key_dim * num_attention_heads + + self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values) + self.activation = nn.Hardswish() + self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0) + + points = list(itertools.product(range(resolution), range(resolution))) + len_points = len(points) + attention_offsets, indices = {}, [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + indices.append(attention_offsets[offset]) + + self.attention_bias_cache = {} + self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets))) + self.register_buffer( + "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False + ) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device): + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, hidden_state): + batch_size, seq_length, _ = hidden_state.shape + queries_keys_values = self.queries_keys_values(hidden_state) + query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split( + [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3 + ) + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 
1, 3) + + attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device) + attention = attention.softmax(dim=-1) + hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection) + hidden_state = self.projection(self.activation(hidden_state)) + return hidden_state + + +class LevitAttentionSubsample(nn.Module): + def __init__( + self, + input_dim, + output_dim, + key_dim, + num_attention_heads, + attention_ratio, + stride, + resolution_in, + resolution_out, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads + self.out_dim_projection = attention_ratio * key_dim * num_attention_heads + self.resolution_out = resolution_out + # resolution_in is the initial resolution, resolution_out is final resolution after downsampling + self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values) + self.queries_subsample = LevitSubsample(stride, resolution_in) + self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads) + self.activation = nn.Hardswish() + self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim) + + self.attention_bias_cache = {} + + points = list(itertools.product(range(resolution_in), range(resolution_in))) + points_ = list(itertools.product(range(resolution_out), range(resolution_out))) + len_points, len_points_ = len(points), len(points_) + attention_offsets, indices = {}, [] + for p1 in points_: + for p2 in points: + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + indices.append(attention_offsets[offset]) + + self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets))) + self.register_buffer( + "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False + ) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device): + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, hidden_state): + batch_size, seq_length, _ = hidden_state.shape + key, value = ( + self.keys_values(hidden_state) + .view(batch_size, seq_length, self.num_attention_heads, -1) + .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3) + ) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + query = self.queries(self.queries_subsample(hidden_state)) + query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute( + 0, 2, 1, 3 + ) + + attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device) + attention = attention.softmax(dim=-1) + hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection) + hidden_state = self.projection(self.activation(hidden_state)) + return hidden_state + + 
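The relative-position bias machinery in `LevitAttention` (and its subsampling variant) is easy to misread from the constructor alone. The snippet below is a minimal, illustrative sketch, not part of the diff, that replays the same offset-indexing logic on an arbitrary 3x3 grid with 4 heads: every (query, key) pair maps to one of a small set of (|dx|, |dy|) offsets, and a single learnable bias per head is shared by all pairs with the same offset.

```python
import itertools

import torch

# Illustrative sketch of the bias table built in LevitAttention.__init__ above.
resolution = 3  # arbitrary tiny feature map for the example
points = list(itertools.product(range(resolution), range(resolution)))
attention_offsets, indices = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        indices.append(attention_offsets[offset])

# A 3x3 grid has only 9 distinct (|dx|, |dy|) offsets, shared across all 81 (query, key) pairs.
print(len(attention_offsets))  # 9
bias_idxs = torch.LongTensor(indices).view(len(points), len(points))
print(bias_idxs.shape)  # torch.Size([9, 9])

# One learnable bias per head and per distinct offset is broadcast to every pair at lookup time.
num_heads = 4
attention_biases = torch.zeros(num_heads, len(attention_offsets))
print(attention_biases[:, bias_idxs].shape)  # torch.Size([4, 9, 9])
```

This is why the parameter count of the bias table grows with the number of distinct offsets rather than with the square of the sequence length, and why `get_attention_biases` can cache the expanded table per device at inference time.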
+class LevitMLPLayer(nn.Module): + """ + MLP Layer with `2X` expansion in contrast to ViT with `4X`. + """ + + def __init__(self, input_dim, hidden_dim): + super().__init__() + self.linear_up = MLPLayerWithBN(input_dim, hidden_dim) + self.activation = nn.Hardswish() + self.linear_down = MLPLayerWithBN(hidden_dim, input_dim) + + def forward(self, hidden_state): + hidden_state = self.linear_up(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.linear_down(hidden_state) + return hidden_state + + +class LevitResidualLayer(nn.Module): + """ + Residual Block for LeViT + """ + + def __init__(self, module, drop_rate): + super().__init__() + self.module = module + self.drop_rate = drop_rate + + def forward(self, hidden_state): + if self.training and self.drop_rate > 0: + rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device) + rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach() + hidden_state = hidden_state + self.module(hidden_state) * rnd + return hidden_state + else: + hidden_state = hidden_state + self.module(hidden_state) + return hidden_state + + +class LevitStage(nn.Module): + """ + LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers. + """ + + def __init__( + self, + config, + idx, + hidden_sizes, + key_dim, + depths, + num_attention_heads, + attention_ratio, + mlp_ratio, + down_ops, + resolution_in, + ): + super().__init__() + self.layers = [] + self.config = config + self.resolution_in = resolution_in + # resolution_in is the initial resolution, resolution_out is final resolution after downsampling + for _ in range(depths): + self.layers.append( + LevitResidualLayer( + LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in), + self.config.drop_path_rate, + ) + ) + if mlp_ratio > 0: + hidden_dim = hidden_sizes * mlp_ratio + self.layers.append( + LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate) + ) + + if down_ops[0] == "Subsample": + self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1 + self.layers.append( + LevitAttentionSubsample( + *self.config.hidden_sizes[idx : idx + 2], + key_dim=down_ops[1], + num_attention_heads=down_ops[2], + attention_ratio=down_ops[3], + stride=down_ops[5], + resolution_in=resolution_in, + resolution_out=self.resolution_out, + ) + ) + self.resolution_in = self.resolution_out + if down_ops[4] > 0: + hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4] + self.layers.append( + LevitResidualLayer( + LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate + ) + ) + + self.layers = nn.ModuleList(self.layers) + + def get_resolution(self): + return self.resolution_in + + def forward(self, hidden_state): + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class LevitEncoder(nn.Module): + """ + LeViT Encoder consisting of multiple `LevitStage` stages. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + resolution = self.config.image_size // self.config.patch_size + self.stages = [] + self.config.down_ops.append([""]) + + for stage_idx in range(len(config.depths)): + stage = LevitStage( + config, + stage_idx, + config.hidden_sizes[stage_idx], + config.key_dim[stage_idx], + config.depths[stage_idx], + config.num_attention_heads[stage_idx], + config.attention_ratio[stage_idx], + config.mlp_ratio[stage_idx], + config.down_ops[stage_idx], + resolution, + ) + resolution = stage.get_resolution() + self.stages.append(stage) + + self.stages = nn.ModuleList(self.stages) + + def forward(self, hidden_state, output_hidden_states=False, return_dict=True): + all_hidden_states = () if output_hidden_states else None + + for stage in self.stages: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + hidden_state = stage(hidden_state) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states) + + +class LevitClassificationLayer(nn.Module): + """ + LeViT Classification Layer + """ + + def __init__(self, input_dim, output_dim): + super().__init__() + self.batch_norm = nn.BatchNorm1d(input_dim) + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, hidden_state): + hidden_state = self.batch_norm(hidden_state) + logits = self.linear(hidden_state) + return logits + + +@auto_docstring +class LevitPreTrainedModel(PreTrainedModel): + config: LevitConfig + base_model_prefix = "levit" + main_input_name = "pixel_values" + _no_split_modules = ["LevitResidualLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class LevitModel(LevitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.patch_embeddings = LevitPatchEmbeddings(config) + self.encoder = LevitEncoder(config) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embeddings = self.patch_embeddings(pixel_values) + encoder_outputs = self.encoder( + embeddings, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + + # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes) + pooled_output = last_hidden_state.mean(dim=1) + 
+ if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class LevitForImageClassification(LevitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.num_labels = config.num_labels + self.levit = LevitModel(config) + + # Classifier head + self.classifier = ( + LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels) + if config.num_labels > 0 + else torch.nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + sequence_output = outputs[0] + sequence_output = sequence_output.mean(1) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and + a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning:: + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. 
+ """ +) +class LevitForImageClassificationWithTeacher(LevitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.num_labels = config.num_labels + self.levit = LevitModel(config) + + # Classifier head + self.classifier = ( + LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels) + if config.num_labels > 0 + else torch.nn.Identity() + ) + self.classifier_distill = ( + LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels) + if config.num_labels > 0 + else torch.nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LevitForImageClassificationWithTeacherOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + sequence_output = outputs[0] + sequence_output = sequence_output.mean(1) + cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output) + logits = (cls_logits + distill_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distill_logits) + outputs[2:] + return output + + return LevitForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distill_logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = [ + "LevitForImageClassification", + "LevitForImageClassificationWithTeacher", + "LevitModel", + "LevitPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0357ffbaa69bd379fb48e884ed4d9a13f943d9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_lfm2_vl import * + from .image_processing_lfm2_vl_fast import * + from .modeling_lfm2_vl import * + from .processing_lfm2_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/configuration_lfm2_vl.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/configuration_lfm2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..1378fbe6dc8c54fe8e6aa05d8b3ee473c93d4c46 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/configuration_lfm2_vl.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LFM2-VL model.""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class Lfm2VlConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Lfm2VlForConditionalGeneration`]. It is used to instantiate an + Lfm2Vl model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Lfm2-VL-1.6B. + + e.g. [LiquidAI/LFM2-VL-1.6B](https://huggingface.co/LiquidAI/LFM2-VL-1.6B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`AutoConfig | dict`, *optional*, defaults to `Siglip2ImageConfig`): + The config object or dictionary of the vision backbone. + text_config (`AutoConfig | dict`, *optional*, defaults to `Lfm2Config`): + The config object or dictionary of the text backbone. + image_token_id (`int`, *optional*, defaults to 396): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + projector_hidden_size (`int`, *optional*, defaults to 2560): + The hidden size of the multimodal projector. + projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + downsample_factor (`int`, *optional*, defaults to 2): + The downsample_factor factor of the vision backbone. 
+ """ + + model_type = "lfm2-vl" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_id=396, + projector_hidden_act="gelu", + projector_hidden_size=2560, + projector_bias=True, + downsample_factor=2, + **kwargs, + ): + self.image_token_id = image_token_id + self.projector_hidden_act = projector_hidden_act + self.projector_hidden_size = projector_hidden_size + self.projector_bias = projector_bias + self.downsample_factor = downsample_factor + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "siglip2_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["siglip2_vision_model"]() + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "lfm2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["lfm2"]() + + self.vision_config = vision_config + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["Lfm2VlConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..4081c86e108ac161bee3e0cf30a7c9cb52a46ec1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py @@ -0,0 +1,541 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from functools import lru_cache +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import ( + Unpack, +) +from ...utils import ( + TensorType, + auto_docstring, + logging, +) + + +logger = logging.get_logger(__name__) + + +def round_by_factor(number: float, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + """Find the closest aspect ratio from target_ratios to match the input aspect ratio. + + Args: + aspect_ratio: The aspect ratio to match (width/height). 
+ target_ratios: List of possible aspect ratios as tuples of (width, height) integers. + width: Original image width in pixels. + height: Original image height in pixels. + image_size: Base size for calculating target area. + + Returns: + tuple[int, int]: The best matching ratio as (width, height) integers. + """ + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + + # update best ratio if we found a closer match + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + # if equally close, prefer the ratio that better matches the original image area + elif ratio_diff == best_ratio_diff: + target_area = image_size * image_size * ratio[0] * ratio[1] + if area > 0.5 * target_area: + best_ratio = ratio + + return best_ratio + + +# copied from Siglip2ImageProcessor +@lru_cache(maxsize=256) +def get_image_size_for_max_num_patches( + image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 +) -> tuple[int, int]: + """ + Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. + + Args: + image_height (`int`): + Original image height. + image_width (`int`): + Original image width. + patch_size (`int`): + Patch size for processing. + max_num_patches (`int`): + Maximum number of patches. + eps (`float`): + Small threshold for binary search. + + Returns: + Tuple: (target_height, target_width) + """ + + def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int: + scaled_size = size * scale + scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size + scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch + return int(scaled_size) + + # Binary search for optimal scale + scale_min, scale_max = eps / 10, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + num_patches = (target_height / patch_size) * (target_width / patch_size) + + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + return target_height, target_width + + +def convert_image_to_patches(images: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 3D array image of shape (image_height, image_width, num_channels) into 2D array of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + batch_size, num_channels, image_height, image_width = images.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = images.reshape( + batch_size, num_channels, num_patches_height, patch_size, num_patches_width, patch_size + ) + patched_image = patched_image.permute(0, 2, 4, 3, 5, 1) + patched_image = patched_image.reshape(batch_size, num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim( + images: "torch.Tensor", target_length: int, pad_value: int = 0 +) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the array along the first dimension. 
+ """ + current_length = images.shape[1] + padding_length = target_length - current_length + pixel_mask = torch.ones((target_length,), dtype=torch.int32) + if padding_length > 0: + paddings = (0, 0, 0, padding_length, 0, 0) + images = torch.nn.functional.pad(images, paddings, mode="constant", value=pad_value) + pixel_mask[-padding_length:] = 0 + return images, pixel_mask + + +class Lfm2VlFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + downsample_factor (`int`, *optional*, defaults to `2`): + The downsampling factor for images used when resizing the image. + """ + + downsample_factor: Optional[int] + do_image_splitting: Optional[bool] + min_tiles: Optional[int] + max_tiles: Optional[int] + use_thumbnail: Optional[bool] + min_image_tokens: Optional[int] + max_image_tokens: Optional[int] + encoder_patch_size: Optional[int] + tile_size: Optional[int] + max_pixels_tolerance: Optional[float] + do_pad: Optional[bool] + return_row_col_info: Optional[bool] + + +@auto_docstring +class Lfm2VlImageProcessorFast(BaseImageProcessorFast): + downsample_factor = 2 + do_image_splitting = True + min_tiles = 2 + max_tiles = 10 + use_thumbnail = True + min_image_tokens = 64 + max_image_tokens = 256 + encoder_patch_size = 16 + tile_size = 512 + max_pixels_tolerance = 2.0 + do_resize = True + size = {"height": 512, "width": 512} + resample = PILImageResampling.BILINEAR + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + do_pad = True + return_row_col_info = False + image_mean = IMAGENET_STANDARD_STD + image_std = IMAGENET_STANDARD_MEAN + valid_kwargs = Lfm2VlFastImageProcessorKwargs + model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"] + + def __init__(self, **kwargs: Unpack[Lfm2VlFastImageProcessorKwargs]): + super().__init__(**kwargs) + + max_thumbnail_image_patches = self.max_image_tokens * self.downsample_factor**2 + tile_size_patches = (self.tile_size // self.encoder_patch_size) ** 2 if self.do_image_splitting else 0 + self.max_num_patches = max( + max_thumbnail_image_patches, + tile_size_patches, + ) + + @lru_cache(maxsize=256) + def _target_ratios(self, min_tiles: int, max_tiles: int) -> list[tuple[int, int]]: + ratios = [ + (w, h) + for n in range(min_tiles, max_tiles + 1) + for w in range(1, n + 1) + for h in range(1, n + 1) + if min_tiles <= w * h <= max_tiles + ] + return sorted(set(ratios), key=lambda x: x[0] * x[1]) + + def _get_grid_layout( + self, + height: int, + width: int, + min_tiles: int, + max_tiles: int, + tile_size: int, + ) -> tuple[int, int]: + aspect_ratio = width / height + target_ratios = self._target_ratios(min_tiles, max_tiles) + + # find best matching grid configuration + grid_width, grid_height = find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, tile_size) + + target_width = tile_size * grid_width + target_height = tile_size * grid_height + total_patches = grid_width * grid_height + + return grid_width, grid_height, target_width, target_height, total_patches + + def crop_image_to_patches( + self, + image: "torch.Tensor", + min_tiles: int, + max_tiles: int, + tile_size: int, + use_thumbnail: bool, + thumbnail_size: tuple[int], + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + """ + Processes a high resolution image into patches. + This method splits a high resolution image into a grid of smaller patches while trying to maintain + the original aspect ratio. It finds the optimal grid configuration within the specified tile constraints. 
+ """ + batch_size, num_channels, height, width = image.shape + grid_width, grid_height, target_width, target_height, total_patches = self._get_grid_layout( + height, width, min_tiles=min_tiles, max_tiles=max_tiles, tile_size=tile_size + ) + resized_image = F.resize( + image, (target_height, target_width), interpolation=interpolation, antialias=antialias + ) + + # split the image into patches + processed_images = ( + resized_image.unfold(2, size=tile_size, step=tile_size) + .unfold(3, size=tile_size, step=tile_size) + .contiguous() + .view(batch_size, num_channels, -1, tile_size, tile_size) + .permute(2, 0, 1, 3, 4) + .reshape(batch_size, -1, num_channels, tile_size, tile_size) + ) + + # Re-order processed images to a nested image structure, so it can be reordered back correctly + # Note that the images can't be stacked because the thumbnail image is of bigger size than patches + # Each image in sublist will be of shape (1, C, H, W) + processed_images = list(processed_images) + + if use_thumbnail and grid_width * grid_height != 1: + total_patches += 1 + thumbnail_image = F.resize(image, thumbnail_size, interpolation=interpolation, antialias=antialias) + for i in range(batch_size): + processed_images[i] = list(processed_images[i]) + list(thumbnail_image[i][None, ...]) + + return processed_images, grid_width, grid_height + + # Adapted from Qwen-VL with minor differences + def smart_resize( + self, + height: int, + width: int, + downsample_factor: int, + min_image_tokens: int, + max_image_tokens: int, + encoder_patch_size: int, + ) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'encoder_patch_size' * 'downsample_factor'. + This ensures no padding is needed in the downsampling step. + 2. The total number of pixels is within the range ['smart_resize_min_pixels', 'smart_resize_max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + total_factor = encoder_patch_size * downsample_factor + smart_resize_min_pixels = min_image_tokens * encoder_patch_size**2 * downsample_factor**2 + smart_resize_max_pixels = max_image_tokens * encoder_patch_size**2 * downsample_factor**2 + + h_bar = max(total_factor, round_by_factor(height, total_factor)) + w_bar = max(total_factor, round_by_factor(width, total_factor)) + + if h_bar * w_bar > smart_resize_max_pixels: + beta = math.sqrt((height * width) / smart_resize_max_pixels) + math.floor(height / beta / total_factor) * total_factor + h_bar = max(total_factor, math.floor(height / beta / total_factor) * total_factor) + w_bar = max(total_factor, math.floor(width / beta / total_factor) * total_factor) + elif h_bar * w_bar < smart_resize_min_pixels: + beta = math.sqrt(smart_resize_min_pixels / (height * width)) + h_bar = math.ceil(height * beta / total_factor) * total_factor + w_bar = math.ceil(width * beta / total_factor) * total_factor + + return w_bar, h_bar + + def _is_image_too_large( + self, + height: int, + width: int, + max_image_tokens: int, + encoder_patch_size: int, + downsample_factor: int, + max_pixels_tolerance: float, + ) -> bool: + """Check if the image is too large to be processed as one tile.""" + total_factor = encoder_patch_size * downsample_factor + + h_bar = max(encoder_patch_size, round_by_factor(height, total_factor)) + w_bar = max(encoder_patch_size, round_by_factor(width, total_factor)) + return h_bar * w_bar > max_image_tokens * encoder_patch_size**2 * downsample_factor**2 * max_pixels_tolerance + + def resize_and_split( + self, + images: "torch.Tensor", + downsample_factor: int, + min_tiles: int, + max_tiles: int, + use_thumbnail: bool, + min_image_tokens: int, + max_image_tokens: int, + encoder_patch_size: int, + tile_size: int, + max_pixels_tolerance: float, + interpolation: "F.InterpolationMode", + ) -> "torch.Tensor": + batch_size, _, height, width = images.shape + do_image_splitting = not min_tiles == max_tiles == 1 + is_image_large = self._is_image_too_large( + height=height, + width=width, + max_image_tokens=max_image_tokens, + encoder_patch_size=encoder_patch_size, + downsample_factor=downsample_factor, + max_pixels_tolerance=max_pixels_tolerance, + ) + + new_width, new_height = self.smart_resize( + height=height, + width=width, + downsample_factor=downsample_factor, + min_image_tokens=min_image_tokens, + max_image_tokens=max_image_tokens, + encoder_patch_size=encoder_patch_size, + ) + + # Big image will be cropped into patches and small images are just resized + if is_image_large and do_image_splitting: + images, num_rows, num_cols = self.crop_image_to_patches( + images, + min_tiles=min_tiles, + max_tiles=max_tiles, + tile_size=tile_size, + thumbnail_size=(new_height, new_width), + use_thumbnail=use_thumbnail, + interpolation=interpolation, + ) + else: + num_rows = num_cols = 1 + images = F.resize(images, (new_height, new_width), interpolation=interpolation) + # Make a list and treat it as single crop per image so it can be re-grouped back correctly + images = [[image] for image in images] + + num_rows = [num_rows] * batch_size + num_cols = [num_cols] * batch_size + image_sizes = [[new_height, new_width]] * batch_size + return images, num_rows, num_cols, image_sizes + + def _preprocess( + self, + images: ImageInput, + size: SizeDict, + interpolation: "F.InterpolationMode", + do_resize: bool, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, list[float]], + image_std: Union[float, list[float]], + 
downsample_factor: int, + do_image_splitting: bool, + min_tiles: int, + max_tiles: int, + use_thumbnail: bool, + min_image_tokens: int, + max_image_tokens: int, + encoder_patch_size: int, + tile_size: int, + max_pixels_tolerance: float, + return_tensors: Union[str, TensorType], + disable_grouping: bool, + do_pad: bool, + return_row_col_info: bool, + **kwargs, + ) -> BatchFeature: + if not do_image_splitting: + min_tiles = 1 + max_tiles = 1 + logger.debug( + "Image splitting is disabled, setting min_tiles and max_tiles to 1. Set do_image_splitting=True to enable splitting." + ) + + if do_image_splitting and min_tiles > max_tiles: + raise ValueError("min_tiles must be less than or equal to max_tiles") + + max_thumbnail_image_patches = max_image_tokens * downsample_factor**2 + tile_size_patches = (tile_size // encoder_patch_size) ** 2 if do_image_splitting else 0 + max_num_patches = max( + max_thumbnail_image_patches, + tile_size_patches, + ) + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + resized_image_sizes = {} + rows_grouped, cols_grouped = {}, {} + for shape, stacked_images in grouped_images.items(): + num_rows = [1] * stacked_images.shape[0] + num_cols = [1] * stacked_images.shape[0] + height, width = stacked_images.shape[-2:] + image_sizes = [[height, width]] * stacked_images.shape[0] + do_resize = True + + if do_resize: + stacked_images, num_rows, num_cols, image_sizes = self.resize_and_split( + stacked_images, + downsample_factor=downsample_factor, + min_tiles=min_tiles, + max_tiles=max_tiles, + use_thumbnail=use_thumbnail, + min_image_tokens=min_image_tokens, + max_image_tokens=max_image_tokens, + encoder_patch_size=encoder_patch_size, + tile_size=tile_size, + max_pixels_tolerance=max_pixels_tolerance, + interpolation=interpolation, + ) + + rows_grouped[shape] = num_rows + cols_grouped[shape] = num_cols + resized_image_sizes[shape] = image_sizes + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + batch_rows = reorder_images(rows_grouped, grouped_images_index) + batch_cols = reorder_images(cols_grouped, grouped_images_index) + resized_image_sizes = reorder_images(resized_image_sizes, grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape( + resized_images, disable_grouping=disable_grouping, is_nested=True + ) + + processed_images_grouped = {} + processed_masks, processed_spatial_shapes = {}, {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + batch_size, *_, height, width = stacked_images.shape + num_patches_height = height // encoder_patch_size + num_patches_width = width // encoder_patch_size + + stacked_images = convert_image_to_patches(stacked_images, encoder_patch_size) + processed_spatial_shapes[shape] = [[num_patches_height, num_patches_width]] * batch_size + + if do_pad: + stacked_images, pixel_mask = pad_along_first_dim(stacked_images, max_num_patches) + processed_masks[shape] = [pixel_mask] * batch_size + + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True) + data = {"pixel_values": torch.cat([torch.stack(images) for images in processed_images])} + + if do_pad: + processed_masks = 
reorder_images(processed_masks, grouped_images_index, is_nested=True) + processed_spatial_shapes = reorder_images(processed_spatial_shapes, grouped_images_index, is_nested=True) + processed_masks = torch.cat([torch.stack(masks) for masks in processed_masks]) + processed_spatial_shapes = torch.cat( + [torch.tensor(spatial_shape) for spatial_shape in processed_spatial_shapes] + ) + data.update({"pixel_attention_mask": processed_masks, "spatial_shapes": processed_spatial_shapes}) + + if return_row_col_info: + data["image_rows"] = batch_rows + data["image_cols"] = batch_cols + data["image_sizes"] = resized_image_sizes + + encoding = BatchFeature(data=data, tensor_type=return_tensors) + return encoding + + +__all__ = ["Lfm2VlImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modeling_lfm2_vl.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modeling_lfm2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..deee35394ee1fc64e7e10e6f96a4a9a892158b64 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modeling_lfm2_vl.py @@ -0,0 +1,497 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lfm2_vl/modular_lfm2_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lfm2_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
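Before the modeling code, it may help to see what the `_preprocess` method above actually emits. The following is a minimal sketch (illustrative only; all sizes are made up) of the patch flattening done by `convert_image_to_patches` followed by the padding done by `pad_along_first_dim`, which together produce `pixel_values` and `pixel_attention_mask`.

```python
import torch

encoder_patch_size, max_num_patches = 16, 1024
images = torch.rand(1, 3, 256, 192)  # batch already resized to multiples of the patch size

# Flatten each image into (num_patches, patch_size * patch_size * channels) rows,
# mirroring convert_image_to_patches.
b, c, h, w = images.shape
nph, npw = h // encoder_patch_size, w // encoder_patch_size
patches = images.reshape(b, c, nph, encoder_patch_size, npw, encoder_patch_size)
patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(b, nph * npw, -1)
print(patches.shape)  # torch.Size([1, 192, 768]): 16 * 12 patches of 16*16*3 values each

# Pad the patch dimension up to max_num_patches and mark which rows are real,
# mirroring pad_along_first_dim.
padding = max_num_patches - patches.shape[1]
pixel_values = torch.nn.functional.pad(patches, (0, 0, 0, padding))
pixel_attention_mask = torch.ones(max_num_patches, dtype=torch.int32)
pixel_attention_mask[-padding:] = 0
print(pixel_values.shape, int(pixel_attention_mask.sum()))  # torch.Size([1, 1024, 768]) 192
```

The vision tower then consumes these flattened, padded patch rows together with `spatial_shapes`, which records how many patch rows and columns each image actually contributed.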
+ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ..auto import AutoModel +from .configuration_lfm2_vl import Lfm2VlConfig + + +class Lfm2VlMultiModalProjector(nn.Module): + def __init__(self, config: Lfm2VlConfig): + super().__init__() + in_channels = config.vision_config.hidden_size * (config.downsample_factor**2) + self.factor = config.downsample_factor + self.layer_norm = nn.LayerNorm(in_channels) + self.linear_1 = nn.Linear( + in_channels, + config.projector_hidden_size, + bias=config.projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + + def forward(self, image_features: torch.Tensor): + image_features = self.pixel_unshuffle(image_features) + image_features = self.layer_norm(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_unshuffle(self, hidden_states: torch.Tensor): + batch_size, width, height, channels = hidden_states.size() + hidden_states = hidden_states.reshape(batch_size, width, height // self.factor, channels * self.factor) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape( + batch_size, height // self.factor, width // self.factor, channels * self.factor**2 + ) + hidden_states = hidden_states.permute(0, 2, 1, 3) + return hidden_states + + +@auto_docstring +class Lfm2VlPreTrainedModel(PreTrainedModel): + config: Lfm2VlConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_flex_attn = True + _supports_attention_backend = True + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Lfm2Vl causal language model (or autoregressive) outputs. + """ +) +class Lfm2VlCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Lfm2Vl outputs, with hidden states and attentions. + """ +) +class Lfm2VlModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring( + custom_intro=""" + The Lfm2Vl model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class Lfm2VlModel(Lfm2VlPreTrainedModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: Lfm2VlConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = Lfm2VlMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + spatial_shapes: torch.Tensor, + pixel_attention_mask: torch.Tensor, + **kwargs, + ) -> list[torch.Tensor]: + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`): + The pixel attention mask of the input images. + Returns: + image_features (`list[torch.Tensor]`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + image_outputs = self.vision_tower( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + ).last_hidden_state + + img_feature_lengths = pixel_attention_mask.sum(dim=1) + image_features = [] + + for img_idx in range(image_outputs.size(0)): + feature = image_outputs[img_idx] + # unpad the image representation + feature = feature[: img_feature_lengths[img_idx], :].unsqueeze(0) + + # reshape to original height and width + feature_org_h, feature_org_w = spatial_shapes[img_idx] + feature = feature.reshape(1, feature_org_h, feature_org_w, -1) + + # project the image representation + img_embedding = self.multi_modal_projector(feature) + + # flatten here to handle variable length in naflex + img_embedding = img_embedding.reshape(-1, img_embedding.size(-1)) + image_features.append(img_embedding) + + return image_features + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Lfm2VlModelOutputWithPast]: + r""" + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`, *optional*): + The pixel attention mask of the input images. 
+ """ + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + ) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_features=image_features, + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Lfm2VlModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The LFM2_VL model which consists of a vision backbone and a language model. + """ +) +class Lfm2VlForConditionalGeneration(Lfm2VlPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Lfm2VlConfig): + super().__init__(config) + self.model = Lfm2VlModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + spatial_shapes: torch.Tensor, + pixel_attention_mask: torch.Tensor, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + **kwargs, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Lfm2VlCausalLMOutputWithPast]: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, 
height, width)`, *optional*): + The input image tensors. + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`, *optional*): + The pixel attention mask of the input images. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + >>> from transformers.image_utils import load_image + + >>> model = AutoModelForImageTextToText.from_pretrained( + ... "LiquidAI/LFM2-VL-1.6B", + ... ) + >>> processor = AutoProcessor.from_pretrained( + ... "LiquidAI/LFM2-VL-1.6B", + ... ) + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = load_image(url) + + >>> conversation = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image", "image": image}, + ... {"type": "text", "text": "What is in this image?"}, + ... ], + ... }, + ... ] + + >>> inputs = processor.apply_chat_template( + ... conversation, + ... add_generation_prompt=True, + ... tokenize=True, + ... return_dict=True, + ... return_tensors="pt" + ... ) + + >>> # Generate + >>> outputs = model.generate(**inputs, max_new_tokens=45) + >>> processor.batch_decode(outputs, skip_special_tokens=True)[0] + 'This image depicts a vibrant street scene in what appears to be a Chinatown or similar cultural area. The focal point is a large red stop sign with white lettering, mounted on a pole.' 
+ ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + return Lfm2VlCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["Lfm2VlForConditionalGeneration", "Lfm2VlPreTrainedModel", "Lfm2VlModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modular_lfm2_vl.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modular_lfm2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..68367464c3cf75a093c0469a3e44c65e4a3bbd27 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/modular_lfm2_vl.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
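Before the modular source that follows, two generation-time behaviors defined above are worth pinning down: `logits_to_keep` controls how many trailing positions are projected through `lm_head`, and `prepare_inputs_for_generation` only forwards `pixel_values` on the prefill step (`cache_position[0] == 0`). A minimal, standalone sketch of the slicing semantics (toy shapes, not library code):

```python
import torch

# Toy hidden states: (batch, seq_len, hidden). Assumed sizes for illustration.
hidden_states = torch.randn(2, 10, 16)

# Decode step: only the last position's logits are needed.
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

# Loss computation: logits_to_keep == 0 keeps the full sequence,
# because slice(-0, None) == slice(0, None).
logits_to_keep = 0
slice_indices = slice(-logits_to_keep, None)
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 10, 16])
```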
+"""PyTorch Lfm2-VL model.""" + +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, +) +from .configuration_lfm2_vl import Lfm2VlConfig + + +logger = logging.get_logger(__name__) + + +class Lfm2VlMultiModalProjector(nn.Module): + def __init__(self, config: Lfm2VlConfig): + super().__init__() + in_channels = config.vision_config.hidden_size * (config.downsample_factor**2) + self.factor = config.downsample_factor + self.layer_norm = nn.LayerNorm(in_channels) + self.linear_1 = nn.Linear( + in_channels, + config.projector_hidden_size, + bias=config.projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + + def forward(self, image_features: torch.Tensor): + image_features = self.pixel_unshuffle(image_features) + image_features = self.layer_norm(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_unshuffle(self, hidden_states: torch.Tensor): + batch_size, width, height, channels = hidden_states.size() + hidden_states = hidden_states.reshape(batch_size, width, height // self.factor, channels * self.factor) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape( + batch_size, height // self.factor, width // self.factor, channels * self.factor**2 + ) + hidden_states = hidden_states.permute(0, 2, 1, 3) + return hidden_states + + +class Lfm2VlPreTrainedModel(LlavaPreTrainedModel): + _can_compile_fullgraph = False + + +class Lfm2VlCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +class Lfm2VlModelOutputWithPast(LlavaModelOutputWithPast): + pass + + +class Lfm2VlModel(LlavaModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: Lfm2VlConfig): + super().__init__(config) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + spatial_shapes: torch.Tensor, + pixel_attention_mask: torch.Tensor, + **kwargs, + ) -> list[torch.Tensor]: + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`): + The pixel attention mask of the input images. + Returns: + image_features (`list[torch.Tensor]`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + image_outputs = self.vision_tower( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + ).last_hidden_state + + img_feature_lengths = pixel_attention_mask.sum(dim=1) + image_features = [] + + for img_idx in range(image_outputs.size(0)): + feature = image_outputs[img_idx] + # unpad the image representation + feature = feature[: img_feature_lengths[img_idx], :].unsqueeze(0) + + # reshape to original height and width + feature_org_h, feature_org_w = spatial_shapes[img_idx] + feature = feature.reshape(1, feature_org_h, feature_org_w, -1) + + # project the image representation + img_embedding = self.multi_modal_projector(feature) + + # flatten here to handle variable length in naflex + img_embedding = img_embedding.reshape(-1, img_embedding.size(-1)) + image_features.append(img_embedding) + + return image_features + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Lfm2VlModelOutputWithPast]: + r""" + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`, *optional*): + The pixel attention mask of the input images. 
+ """ + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + ) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_features=image_features, + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Lfm2VlModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class Lfm2VlForConditionalGeneration(LlavaForConditionalGeneration): + _checkpoint_conversion_mapping = {} + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + spatial_shapes: torch.Tensor, + pixel_attention_mask: torch.Tensor, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + **kwargs, + ) + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Lfm2VlCausalLMOutputWithPast]: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`, *optional*): + The input image tensors. + spatial_shapes (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + The spatial shapes of the input images. + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, height, width)`, *optional*): + The pixel attention mask of the input images. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + >>> from transformers.image_utils import load_image + + >>> model = AutoModelForImageTextToText.from_pretrained( + ... "LiquidAI/LFM2-VL-1.6B", + ... ) + >>> processor = AutoProcessor.from_pretrained( + ... 
"LiquidAI/LFM2-VL-1.6B", + ... ) + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = load_image(url) + + >>> conversation = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image", "image": image}, + ... {"type": "text", "text": "What is in this image?"}, + ... ], + ... }, + ... ] + + >>> inputs = processor.apply_chat_template( + ... conversation, + ... add_generation_prompt=True, + ... tokenize=True, + ... return_dict=True, + ... return_tensors="pt" + ... ) + + >>> # Generate + >>> outputs = model.generate(**inputs, max_new_tokens=45) + >>> processor.batch_decode(outputs, skip_special_tokens=True)[0] + 'This image depicts a vibrant street scene in what appears to be a Chinatown or similar cultural area. The focal point is a large red stop sign with white lettering, mounted on a pole.' + ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + return Lfm2VlCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + +__all__ = ["Lfm2VlForConditionalGeneration", "Lfm2VlPreTrainedModel", "Lfm2VlModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/processing_lfm2_vl.py b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/processing_lfm2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..12f289c266a14ae1ab84e07f748b6a84c72ea5be --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lfm2_vl/processing_lfm2_vl.py @@ -0,0 +1,269 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, make_nested_list_of_images +from ...processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from ...tokenization_utils_base import BatchEncoding, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Lfm2VlImagesKwargs(ImagesKwargs, total=False): + downsample_factor: Optional[int] + do_image_splitting: Optional[bool] + min_tiles: Optional[int] + max_tiles: Optional[int] + use_thumbnail: Optional[bool] + min_image_tokens: Optional[int] + max_image_tokens: Optional[int] + encoder_patch_size: Optional[int] + tile_size: Optional[int] + max_pixels_tolerance: Optional[float] + patch_size: Optional[int] + do_pad: Optional[bool] + return_row_col_info: Optional[bool] + + +class Lfm2VlProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Lfm2VlImagesKwargs + + _defaults = { + "images_kwargs": { + "return_row_col_info": True, + }, + "text_kwargs": { + "use_image_special_tokens": True, + "add_special_tokens": False, + "padding": False, + "is_split_into_words": False, + }, + } + + +class Lfm2VlProcessor(ProcessorMixin): + r""" + Constructs a Lfm2Vl processor which wraps a Lfm2Tokenizer tokenizer and Lfm2VlImageProcessor into a single processor. + + [`Lfm2VlProcessor`] offers all the functionalities of [`Lfm2ImageProcessor`] and [`Lfm2Tokenizer`]. + + Args: + image_processor (`Lfm2VlImageProcessor`): + An instance of [`Lfm2VlImageProcessor`]. The image processor is a required input. + tokenizer (`PreTrainedTokenizerBase`): + An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input. + chat_template (`str`, *optional*): + A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. + use_image_special_tokens (`bool`, *optional*, defaults to `True`): + Whether to use image special tokens or not when processing. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "Lfm2VlImageProcessorFast" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor, + tokenizer, + chat_template: Optional[str] = None, + use_image_special_tokens: Optional[bool] = True, + **kwargs, + ): + self.image_token = tokenizer.image_token + self.image_token_id = tokenizer.image_token_id + self.use_image_special_tokens = use_image_special_tokens + self.image_start_token = tokenizer.image_start_token + self.image_end_token = tokenizer.image_end_token + self.image_thumbnail_token = tokenizer.image_thumbnail + super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs) + + def __call__( + self, + images: Optional[Union[ImageInput, list[ImageInput], list[list[ImageInput]]]] = None, + text: Optional[Union[TextInput, list[TextInput]]] = None, + **kwargs: Unpack[Lfm2VlProcessorKwargs], + ) -> BatchEncoding: + """ + Processes the input prompts and returns a BatchFeature. + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`, *optional*): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. If it is of type `list[ImageInput]`, it is assumed that this is for a single prompt, i.e. of batch size 1. + text (`TextInput`, *optional*): + The sequence or batch of sequences to be encoded.
+ Wherever an image token, `<image>`, is encountered, it is expanded to a proper sequence of image tokens. + return_tensors (`Optional[str, TensorType]`, *optional*): + If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more + information. + """ + if text is None and images is None: + raise ValueError("You must provide one of `text` or `images`.") + + if images is not None and text is None: + raise ValueError( + "You must provide `text` when `images` is provided. Minimal text consists of a single image token." + ) + + output_kwargs = self._merge_kwargs( + Lfm2VlProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + n_images_in_text = [sample.count(self.image_token) for sample in text] + if sum(n_images_in_text) > 0 and images is None: + raise ValueError(f"We detected {sum(n_images_in_text)} image tokens in the text but no images were passed") + + inputs = {} + use_image_special_tokens = output_kwargs["text_kwargs"].pop("use_image_special_tokens") + + if images is not None: + images = self.image_processor.fetch_images(images) + batched_images = make_nested_list_of_images(images) + vision_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) + + n_images_in_images = [len(sublist) for sublist in batched_images] + if n_images_in_images != n_images_in_text: + raise ValueError( + f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." + ) + + text = self.expand_text_with_placeholders( + text, + batched_images, + image_rows=vision_inputs.pop("image_rows"), + image_cols=vision_inputs.pop("image_cols"), + image_sizes=vision_inputs.pop("image_sizes"), + use_image_special_tokens=use_image_special_tokens, + **output_kwargs["images_kwargs"], + ) + inputs.update(vision_inputs) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + inputs.update(text_inputs) + + return BatchFeature(inputs, tensor_type=return_tensors) + + def expand_text_with_placeholders( + self, + text: list[str], + images: list[list[ImageInput]], + image_rows: list[list[int]], + image_cols: list[list[int]], + image_sizes: list[list[int]], + use_image_special_tokens: bool, + **images_kwargs, + ): + prompt_strings = [] + + image_data = iter(zip(*[image_rows, image_cols, image_sizes])) + for sample_text, sample_images in zip(text, images): + split_sample = sample_text.split(self.image_token) + sample_text_with_image_tokens = "" + for i, image in enumerate(sample_images): + sample_text_with_image_tokens += split_sample[i] + if use_image_special_tokens: + sample_text_with_image_tokens += self.image_start_token + + rows, cols, image_size = next(image_data) + num_thumbnail_tokens, num_tokens_per_tile = self._get_image_num_tokens(image_size, **images_kwargs) + + if rows > 1 or cols > 1: + for row in range(rows): + for col in range(cols): + if use_image_special_tokens: + sample_text_with_image_tokens += f"<|img_row_{row + 1}_col_{col + 1}|>" + sample_text_with_image_tokens += self.image_token * num_tokens_per_tile + + if num_thumbnail_tokens > 0: + if use_image_special_tokens: + sample_text_with_image_tokens += self.image_thumbnail_token + sample_text_with_image_tokens += self.image_token *
num_thumbnail_tokens + else: + sample_text_with_image_tokens += self.image_token * num_thumbnail_tokens + + if use_image_special_tokens: + sample_text_with_image_tokens += self.image_end_token + + sample_text_with_image_tokens += split_sample[i + 1] + prompt_strings.append(sample_text_with_image_tokens) + + return prompt_strings + + def _get_image_num_tokens(self, image_size: list[int], **images_kwargs) -> tuple[int, int]: + tile_size = images_kwargs.get("tile_size", self.image_processor.tile_size) + downsample_factor = images_kwargs.get("downsample_factor", self.image_processor.downsample_factor) + encoder_patch_size = images_kwargs.get("encoder_patch_size", self.image_processor.encoder_patch_size) + use_thumbnail = images_kwargs.get("use_thumbnail", self.image_processor.use_thumbnail) + + thumbnail_tokens = 0 + if use_thumbnail: + image_height, image_width = image_size + num_patches_height = image_height // encoder_patch_size + num_patches_width = image_width // encoder_patch_size + dwn_num_patches_height = math.ceil(num_patches_height / downsample_factor) + dwn_num_patches_width = math.ceil(num_patches_width / downsample_factor) + thumbnail_tokens = dwn_num_patches_height * dwn_num_patches_width + + num_patches_tile = tile_size // encoder_patch_size + dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor) + tile_tokens = dwn_num_patches_tile * dwn_num_patches_tile + + return thumbnail_tokens, tile_tokens + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Lfm2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs) + return batched_decode_output + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Lfm2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + decode_output = self.tokenizer.decode(*args, **kwargs) + return decode_output + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + + # LFM2-VL has no dedicated tokenizer class and uses the Base class with default model input names + tokenizer_input_names = [name for name in tokenizer_input_names if name != "token_type_ids"] + return list(tokenizer_input_names + image_processor_input_names) + + +__all__ = ["Lfm2VlProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lightglue/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/lightglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..190e4e4329419f92da2141c2027f8212d6178e6f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lightglue/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
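The placeholder expansion above relies on `_get_image_num_tokens`: every tile contributes a fixed square number of image tokens, while the optional thumbnail contributes a count derived from its own size. A back-of-the-envelope sketch of that arithmetic (the tile size, patch size, downsample factor, and thumbnail dimensions below are assumptions for illustration, not the checked-in defaults):

```python
import math

tile_size, encoder_patch_size, downsample_factor = 512, 16, 2
image_height, image_width = 768, 1024  # assumed thumbnail size

# Per-tile tokens: patches along one edge, downsampled, then squared.
num_patches_tile = tile_size // encoder_patch_size           # 32
dwn_tile = math.ceil(num_patches_tile / downsample_factor)   # 16
tile_tokens = dwn_tile * dwn_tile                            # 256

# Thumbnail tokens: same idea per axis, then multiplied.
dwn_h = math.ceil((image_height // encoder_patch_size) / downsample_factor)  # 24
dwn_w = math.ceil((image_width // encoder_patch_size) / downsample_factor)   # 32
thumbnail_tokens = dwn_h * dwn_w                             # 768
print(tile_tokens, thumbnail_tokens)
```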
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_lightglue import * + from .image_processing_lightglue import * + from .modeling_lightglue import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/lightglue/configuration_lightglue.py b/venv/lib/python3.13/site-packages/transformers/models/lightglue/configuration_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..90e8d41b45156932278872ddd61432e18330db7c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lightglue/configuration_lightglue.py @@ -0,0 +1,155 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig +from ..superpoint import SuperPointConfig + + +class LightGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to + instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LightGlue + [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + descriptor_dim (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + num_hidden_layers (`int`, *optional*, defaults to 9): + The number of self and cross attention layers. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the multi-head attention. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + depth_confidence (`float`, *optional*, defaults to 0.95): + The confidence threshold used to perform early stopping. + width_confidence (`float`, *optional*, defaults to 0.99): + The confidence threshold used to prune points. + filter_threshold (`float`, *optional*, defaults to 0.1): + The confidence threshold used to filter matches. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function to be used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to trust remote code when using other models than SuperPoint as keypoint detector. + + Examples: + ```python + >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching + + >>> # Initializing a LightGlue style configuration + >>> configuration = LightGlueConfig() + + >>> # Initializing a model from the LightGlue style configuration + >>> model = LightGlueForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightglue" + sub_configs = {"keypoint_detector_config": AutoConfig} + + def __init__( + self, + keypoint_detector_config: SuperPointConfig = None, + descriptor_dim: int = 256, + num_hidden_layers: int = 9, + num_attention_heads: int = 4, + num_key_value_heads=None, + depth_confidence: float = 0.95, + width_confidence: float = 0.99, + filter_threshold: float = 0.1, + initializer_range: float = 0.02, + hidden_act: str = "gelu", + attention_dropout=0.0, + attention_bias=True, + trust_remote_code: bool = False, + **kwargs, + ): + # LightGlue can be used with other models than SuperPoint as keypoint detector + # We provide the trust_remote_code argument to allow the use of other models + # that are not registered in the CONFIG_MAPPING dictionary (for example DISK) + self.trust_remote_code = trust_remote_code + + if descriptor_dim % num_attention_heads != 0: + raise ValueError("descriptor_dim must be divisible by num_attention_heads") + + self.descriptor_dim = descriptor_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.initializer_range = initializer_range + + # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention + # See
https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 + if isinstance(keypoint_detector_config, dict): + keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint") + if keypoint_detector_config["model_type"] not in CONFIG_MAPPING: + keypoint_detector_config = AutoConfig.from_pretrained( + keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code + ) + else: + keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( + **keypoint_detector_config, attn_implementation="eager" + ) + + if keypoint_detector_config is None: + keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") + + self.keypoint_detector_config = keypoint_detector_config + + self.hidden_size = descriptor_dim + self.intermediate_size = descriptor_dim * 2 + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + super().__init__(**kwargs) + + +__all__ = ["LightGlueConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lightglue/image_processing_lightglue.py b/venv/lib/python3.13/site-packages/transformers/models/lightglue/image_processing_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..400475b76c77e6f078f53ad31fe4dcffde5bef8e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lightglue/image_processing_lightglue.py @@ -0,0 +1,526 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
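The `num_key_value_heads` docstring in `LightGlueConfig` above describes converting a multi-head checkpoint to GQA by mean-pooling the original key/value heads within each group. A minimal tensor sketch of that pooling, with toy shapes assumed:

```python
import torch

# Assumed toy sizes: 4 attention heads pooled into 2 key/value groups.
num_attention_heads, num_key_value_heads, seq, head_dim = 4, 2, 5, 8

k = torch.randn(num_attention_heads, seq, head_dim)  # per-head keys (MHA)
group = num_attention_heads // num_key_value_heads   # heads per kv group

# Mean-pool each group of original heads into one shared key head.
k_gqa = k.reshape(num_key_value_heads, group, seq, head_dim).mean(dim=1)
print(k_gqa.shape)  # torch.Size([2, 5, 8]): one pooled key head per group
```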
+import warnings +from typing import Optional, Union + +import numpy as np +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_type, + infer_channel_dimension_format, + is_pil_image, + is_scaled_image, + is_valid_image, + is_vision_available, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_matplotlib_available, logging, requires_backends +from ...utils.import_utils import requires +from .modeling_lightglue import LightGlueKeypointMatchingOutput + + +if is_vision_available(): + import PIL + from PIL import Image, ImageDraw + +logger = logging.get_logger(__name__) + + +def is_grayscale( + image: np.ndarray, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True + return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]) + elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True + return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) + + +def convert_to_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to grayscale format using the NTSC formula. Only supports numpy and PIL Image. TODO: support torch + and tensorflow grayscale conversion. + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in: + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (Image): + The image to convert. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + requires_backends(convert_to_grayscale, ["vision"]) + + if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image + if input_data_format == ChannelDimension.FIRST: + gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...]
* 0.1140 + gray_image = np.stack([gray_image] * 3, axis=0) + elif input_data_format == ChannelDimension.LAST: + gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=-1) + return gray_image + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("L") + return image + + +def validate_and_format_image_pairs(images: ImageInput): + error_message = ( + "Input images must be one of the following:", + " - A pair of PIL images.", + " - A pair of 3D arrays.", + " - A list of pairs of PIL images.", + " - A list of pairs of 3D arrays.", + ) + + def _is_valid_image(image): + """image is a PIL Image or a 3D array.""" + return is_pil_image(image) or ( + is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3 + ) + + if isinstance(images, list): + if len(images) == 2 and all((_is_valid_image(image)) for image in images): + return images + if all( + isinstance(image_pair, list) + and len(image_pair) == 2 + and all(_is_valid_image(image) for image in image_pair) + for image_pair in images + ): + return [image for image_pair in images for image in image_pair] + raise ValueError(error_message) + + +@requires(backends=("torch",)) +class LightGlueImageProcessor(BaseImageProcessor): + r""" + Constructs a LightGlue image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden + by `do_resize` in the `preprocess` method. + size (`dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`): + Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to + `True`. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize.
+ size (`dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image pairs to preprocess. Expects either a list of 2 images or a list of lists of 2 images, with + pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set + `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied, as a dictionary of the form + `{"height": int, "width": int}`. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values to the `[0, 1]` range. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + # Validate and convert the input images into a flattened list of images for all subsequent processing steps. + images = validate_and_format_image_pairs(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_grayscale: + image = convert_to_grayscale(image, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + all_images.append(image) + + # Convert back the flattened list of images into a list of pairs of images. 
+ image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)] + + data = {"pixel_values": image_pairs} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_matching( + self, + outputs: LightGlueKeypointMatchingOutput, + target_sizes: Union[TensorType, list[tuple]], + threshold: float = 0.0, + ) -> list[dict[str, torch.Tensor]]: + """ + Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + Args: + outputs ([`KeypointMatchingOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` or `list[tuple[tuple[int, int]]]`, *optional*): + Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`tuple[int, int]`) containing the + target size `(height, width)` of each image in the batch. This must be the original image size (before + any processing). + threshold (`float`, *optional*, defaults to 0.0): + Threshold to filter out the matches with low scores. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image + of the pair, the matching scores and the matching indices. + """ + if outputs.mask.shape[0] != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + if not all(len(target_size) == 2 for target_size in target_sizes): + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if isinstance(target_sizes, list): + image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_pair_sizes = target_sizes + + keypoints = outputs.keypoints.clone() + keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2) + keypoints = keypoints.to(torch.int32) + + results = [] + for mask_pair, keypoints_pair, matches, scores in zip( + outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0] + ): + mask0 = mask_pair[0] > 0 + mask1 = mask_pair[1] > 0 + keypoints0 = keypoints_pair[0][mask0] + keypoints1 = keypoints_pair[1][mask1] + matches0 = matches[mask0] + scores0 = scores[mask0] + + # Filter out matches with low scores + valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1) + + matched_keypoints0 = keypoints0[valid_matches] + matched_keypoints1 = keypoints1[matches0[valid_matches]] + matching_scores = scores0[valid_matches] + + results.append( + { + "keypoints0": matched_keypoints0, + "keypoints1": matched_keypoints1, + "matching_scores": matching_scores, + } + ) + + return results + + def visualize_keypoint_matching( + self, + images: ImageInput, + keypoint_matching_output: list[dict[str, torch.Tensor]], + ) -> list["Image.Image"]: + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 + images or a list of list of 2 images list with pixel values ranging from 0 to 255. 
+ keypoint_matching_output (`list[dict[str, torch.Tensor]]`): + A post-processed keypoint matching output. + + Returns: + `list[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected + keypoints as well as the matching between them. + """ + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + results = [] + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8) + plot_image[:height0, :width0] = image_pair[0] + plot_image[:height1, width0:] = image_pair[1] + + plot_image_pil = Image.fromarray(plot_image) + draw = ImageDraw.Draw(plot_image_pil) + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + color = self._get_color(matching_score) + draw.line( + (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y), + fill=color, + width=3, + ) + draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black") + draw.ellipse( + (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2), + fill="black", + ) + + results.append(plot_image_pil) + return results + + def _get_color(self, score): + """Maps a score to a color.""" + r = int(255 * (1 - score)) + g = int(255 * score) + b = 0 + return (r, g, b) + + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires + matplotlib to be installed. + + .. deprecated:: + `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or + a list of lists of 2 images, with pixel values ranging from 0 to 255. + keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]): + Raw outputs of the model. + """ + warnings.warn( + "`plot_keypoint_matching` is deprecated and will be removed in transformers v. 
" + "Use `visualize_keypoint_matching` instead.", + FutureWarning, + ) + + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2) + plt.show() + + +__all__ = ["LightGlueImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lightglue/modeling_lightglue.py b/venv/lib/python3.13/site-packages/transformers/models/lightglue/modeling_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9faa3e4e0439750057ff2e1cab19dff16c9867 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lightglue/modeling_lightglue.py @@ -0,0 +1,920 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import can_return_tuple +from ..auto.modeling_auto import AutoModelForKeypointDetection +from .configuration_lightglue import LightGlueConfig + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching, + the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the + batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask + tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint + matching information. + """ +) +class LightGlueKeypointMatchingOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Pruning mask indicating which keypoints are removed and at which layer. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching + information. 
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)` returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True` + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)` returned when `output_attentions=True` is passed or when + `config.output_attentions=True` + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + prune: Optional[torch.IntTensor] = None + mask: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LightGlueAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + 
self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] + similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + 
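# Sigmoid converts the per-keypoint matchability logit into a probability in [0, 1], which the keypoint pruning step consumes. +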
matchability = self.matchability(descriptors)
+ matchability = nn.functional.sigmoid(matchability).squeeze(-1)
+ return matchability
+
+
+class LightGlueTokenConfidenceLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+
+ self.token = nn.Linear(config.descriptor_dim, 1)
+
+ def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
+ token = self.token(descriptors.detach())
+ token = nn.functional.sigmoid(token).squeeze(-1)
+ return token
+
+
+@auto_docstring
+class LightGluePreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: LightGlueConfig
+ base_model_prefix = "lightglue"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = False
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
+ """obtain matches from a score matrix [B x M+1 x N+1]"""
+ batch_size, _, _ = scores.shape
+ # For each keypoint, get the best match
+ max0 = scores[:, :-1, :-1].max(2)
+ max1 = scores[:, :-1, :-1].max(1)
+ matches0 = max0.indices
+ matches1 = max1.indices
+
+ # Mutual check for matches
+ indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
+ indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
+ mutual0 = indices0 == matches1.gather(1, matches0)
+ mutual1 = indices1 == matches0.gather(1, matches1)
+
+ # Get matching scores and filter based on mutual check and thresholding
+ max0 = max0.values.exp()
+ zero = max0.new_tensor(0)
+ matching_scores0 = torch.where(mutual0, max0, zero)
+ matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
+ valid0 = mutual0 & (matching_scores0 > threshold)
+ valid1 = mutual1 & valid0.gather(1, matches1)
+
+ # Filter matches based on mutual check and thresholding of scores
+ matches0 = torch.where(valid0, matches0, -1)
+ matches1 = torch.where(valid1, matches1, -1)
+ matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
+ matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
+
+ return matches, matching_scores
+
+
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ Normalize keypoint locations based on the image shape
+
+ Args:
+ keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
+ Keypoint locations in (x, y) format.
+ height (`int`):
+ Image height.
+ width (`int`):
+ Image width.
+
+ Returns:
+ Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
+ """
+ size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
+ shift = size / 2
+ scale = size.max(-1).values / 2
+ keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
+ return keypoints
+
+
+@auto_docstring(
+ custom_intro="""
+ LightGlue model taking images as inputs and outputting the matches between them.
+ """
+)
+class LightGlueForKeypointMatching(LightGluePreTrainedModel):
+ """
+ LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
+ SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
+ It consists of:
+ 1. Keypoint Encoder
+ 2. A Graph Neural Network with self and cross attention layers
+ 3.
Matching Assignment layers + + The correspondence ids use -1 to indicate non-matching points. + + Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. + In ICCV 2023. https://huggingface.co/papers/2306.13643 + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + self.keypoint_detector = AutoModelForKeypointDetection.from_config( + config.keypoint_detector_config, trust_remote_code=config.trust_remote_code + ) + + self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != self.keypoint_detector_descriptor_dim: + self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. 
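+ # Unlike earlier layers, no confidence ratio is computed here: every remaining pair is marked as stopped.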
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. + keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
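+ # pad_sequence right-pads each per-image tensor to the longest kept-keypoint count in the batch; indices are padded with -1 so padded slots stay distinguishable.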
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
+ pad_sequence(pruned_tensor, batch_first=True)
+ for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
+ )
+ pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
+ pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
+
+ return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
+
+ def _concat_early_stopped_outputs(
+ self,
+ early_stops_indices,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ matches,
+ matching_scores,
+ ):
+ early_stops_indices = torch.stack(early_stops_indices)
+ # Rearrange tensors to have the same order as the input batch
+ ids = torch.arange(early_stops_indices.shape[0])
+ order_indices = early_stops_indices[ids]
+ early_stops_indices = early_stops_indices[order_indices]
+ matches, final_pruned_keypoints_indices = (
+ pad_sequence(tensor, batch_first=True, padding_value=-1)
+ for tensor in [matches, final_pruned_keypoints_indices]
+ )
+ matching_scores, final_pruned_keypoints_iterations = (
+ pad_sequence(tensor, batch_first=True, padding_value=0)
+ for tensor in [matching_scores, final_pruned_keypoints_iterations]
+ )
+ matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
+ tensor[early_stops_indices]
+ for tensor in [
+ matches,
+ matching_scores,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ ]
+ )
+ return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
+
+ def _do_final_keypoint_pruning(
+ self,
+ indices: torch.Tensor,
+ matches: torch.Tensor,
+ matching_scores: torch.Tensor,
+ num_keypoints: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
+ # have tensors from each image of the pair
+ batch_size, _ = indices.shape
+ indices, matches, matching_scores = (
+ tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
+ )
+ indices0 = indices[:, 0]
+ indices1 = indices[:, 1]
+ matches0 = matches[:, 0]
+ matches1 = matches[:, 1]
+ matching_scores0 = matching_scores[:, 0]
+ matching_scores1 = matching_scores[:, 1]
+
+ # Prepare final matches and matching scores
+ _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
+ _matching_scores = torch.zeros(
+ (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
+ )
+ # Fill the matches and matching scores for each image pair
+ for i in range(batch_size // 2):
+ _matches[i, 0, indices0[i]] = torch.where(
+ matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
+ )
+ _matches[i, 1, indices1[i]] = torch.where(
+ matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
+ )
+ _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
+ _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
+ return _matches, _matching_scores
+
+ def _match_image_pair(
+ self,
+ keypoints: torch.Tensor,
+ descriptors: torch.Tensor,
+ height: int,
+ width: int,
+ mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if
keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. 
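+ # early_stopped_pairs holds one flag per image pair; repeat_interleave(2) expands it to one flag per image.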
+ early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + # If all pairs of images are early stopped, we stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or 
pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/lightglue/modular_lightglue.py b/venv/lib/python3.13/site-packages/transformers/models/lightglue/modular_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..29441344c9cdfbdc9c74afa7bec493c19166844c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/lightglue/modular_lightglue.py @@ -0,0 +1,1078 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
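+# A minimal post-processing sketch (illustration only). Assuming `outputs` is a
+# LightGlueKeypointMatchingOutput computed for a list of image pairs `image_pairs` and
+# `processor` is a LightGlueImageProcessor, matches can be extracted and drawn with:
+#
+#     image_sizes = [[(image.height, image.width) for image in pair] for pair in image_pairs]
+#     matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+#     visualizations = processor.visualize_keypoint_matching(image_pairs, matches)  # list of PIL images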
+import warnings +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import PretrainedConfig +from ...image_utils import ImageInput, is_vision_available, to_numpy_array +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging +from ...utils.generic import can_return_tuple +from ..auto import CONFIG_MAPPING, AutoConfig +from ..auto.modeling_auto import AutoModelForKeypointDetection +from ..clip.modeling_clip import CLIPMLP +from ..cohere.modeling_cohere import apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaAttention, eager_attention_forward +from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs +from ..superpoint import SuperPointConfig + + +if is_vision_available(): + from PIL import Image, ImageDraw + + +logger = logging.get_logger(__name__) + + +class LightGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to + instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LightGlue + [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + descriptor_dim (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + num_hidden_layers (`int`, *optional*, defaults to 9): + The number of self and cross attention layers. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the multi-head attention. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + depth_confidence (`float`, *optional*, defaults to 0.95): + The confidence threshold used to perform early stopping + width_confidence (`float`, *optional*, defaults to 0.99): + The confidence threshold used to prune points + filter_threshold (`float`, *optional*, defaults to 0.1): + The confidence threshold used to filter matches + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function to be used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to trust remote code when using other models than SuperPoint as keypoint detector. + + Examples: + ```python + >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching + + >>> # Initializing a LightGlue style configuration + >>> configuration = LightGlueConfig() + + >>> # Initializing a model from the LightGlue style configuration + >>> model = LightGlueForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightglue" + sub_configs = {"keypoint_detector_config": AutoConfig} + + def __init__( + self, + keypoint_detector_config: SuperPointConfig = None, + descriptor_dim: int = 256, + num_hidden_layers: int = 9, + num_attention_heads: int = 4, + num_key_value_heads=None, + depth_confidence: float = 0.95, + width_confidence: float = 0.99, + filter_threshold: float = 0.1, + initializer_range: float = 0.02, + hidden_act: str = "gelu", + attention_dropout=0.0, + attention_bias=True, + trust_remote_code: bool = False, + **kwargs, + ): + # LightGlue can be used with other models than SuperPoint as keypoint detector + # We provide the trust_remote_code argument to allow the use of other models + # that are not registered in the CONFIG_MAPPING dictionary (for example DISK) + self.trust_remote_code = trust_remote_code + + if descriptor_dim % num_attention_heads != 0: + raise ValueError("descriptor_dim % num_heads is different from zero") + + self.descriptor_dim = descriptor_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.initializer_range = initializer_range + + # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention + # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 + if isinstance(keypoint_detector_config, dict): + keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint") + if keypoint_detector_config["model_type"] not in CONFIG_MAPPING: + keypoint_detector_config = AutoConfig.from_pretrained( + keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code + ) + else: + keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( + **keypoint_detector_config, attn_implementation="eager" + ) + + if keypoint_detector_config is None: + keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") + + self.keypoint_detector_config = keypoint_detector_config + + self.hidden_size = descriptor_dim + self.intermediate_size = descriptor_dim * 2 + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + super().__init__(**kwargs) + + +@dataclass +@auto_docstring( + 
custom_intro="""
+ Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
+ the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
+ batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
+ tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
+ matching information.
+ """
+)
+class LightGlueKeypointMatchingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
+ Loss computed during training.
+ matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Index of keypoint matched in the other image.
+ matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Scores of predicted matches.
+ keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
+ Absolute (x, y) coordinates of predicted keypoints in a given image.
+ prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
+ Pruning mask indicating which keypoints are removed and at which layer.
+ mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
+ Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
+ information.
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
+ num_keypoints)` returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`
+ attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
+ num_keypoints)` returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ matches: Optional[torch.FloatTensor] = None
+ matching_scores: Optional[torch.FloatTensor] = None
+ keypoints: Optional[torch.FloatTensor] = None
+ prune: Optional[torch.IntTensor] = None
+ mask: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+class LightGlueImageProcessor(SuperGlueImageProcessor):
+ def post_process_keypoint_matching(
+ self,
+ outputs: LightGlueKeypointMatchingOutput,
+ target_sizes: Union[TensorType, list[tuple]],
+ threshold: float = 0.0,
+ ) -> list[dict[str, torch.Tensor]]:
+ return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output (List[Dict[str, torch.Tensor]]):
+ A post-processed keypoint matching output.
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
+
+ def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
+ matplotlib to be installed.
+
+ .. deprecated::
+ `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
+ a list of lists of 2 images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
+ Raw outputs of the model.
+ """
+ warnings.warn(
+ "`plot_keypoint_matching` is deprecated and will be removed in transformers v. 
" + "Use `visualize_keypoint_matching` instead.", + FutureWarning, + ) + + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2) + plt.show() + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +class LightGlueAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + 
key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(CLIPMLP): + def __init__(self, config: LightGlueConfig): + super().__init__(config) + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = 
(cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] + similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + matchability = self.matchability(descriptors) + matchability = nn.functional.sigmoid(matchability).squeeze(-1) + return matchability + + +class LightGlueTokenConfidenceLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.token = nn.Linear(config.descriptor_dim, 1) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + token = self.token(descriptors.detach()) + token = nn.functional.sigmoid(token).squeeze(-1) + 
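`sigmoid_log_double_softmax` above builds the (M+1) x (N+1) log-assignment matrix: a row-wise and a column-wise log-softmax over the similarity matrix are summed with per-keypoint matchability terms, and the extra row and column act as a "dustbin" for unmatched keypoints. A minimal self-contained sketch of the same construction, with toy shapes and random inputs (illustrative only, not the library API):

```python
import torch
from torch import nn

# Toy shapes: 3 keypoints in image 0, 4 in image 1, plus a dustbin each.
batch_size, m, n = 1, 3, 4
similarity = torch.randn(batch_size, m, n)
matchability0 = torch.randn(batch_size, m, 1)
matchability1 = torch.randn(batch_size, n, 1)

certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
scores0 = nn.functional.log_softmax(similarity, 2)
scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)

scores = similarity.new_full((batch_size, m + 1, n + 1), 0.0)
scores[:, :m, :n] = scores0 + scores1 + certainties
scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))  # "unmatched" mass for image 0
scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))  # "unmatched" mass for image 1
print(scores.shape)  # torch.Size([1, 4, 5])
```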
return token
+
+
+@auto_docstring
+class LightGluePreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config: LightGlueConfig
+    base_model_prefix = "lightglue"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
+    """Obtain matches from a score matrix of shape [B x M+1 x N+1]."""
+    batch_size, _, _ = scores.shape
+    # For each keypoint, get the best match
+    max0 = scores[:, :-1, :-1].max(2)
+    max1 = scores[:, :-1, :-1].max(1)
+    matches0 = max0.indices
+    matches1 = max1.indices
+
+    # Mutual check for matches
+    indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
+    indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
+    mutual0 = indices0 == matches1.gather(1, matches0)
+    mutual1 = indices1 == matches0.gather(1, matches1)
+
+    # Get matching scores and filter based on mutual check and thresholding
+    max0 = max0.values.exp()
+    zero = max0.new_tensor(0)
+    matching_scores0 = torch.where(mutual0, max0, zero)
+    matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
+    valid0 = mutual0 & (matching_scores0 > threshold)
+    valid1 = mutual1 & valid0.gather(1, matches1)
+
+    # Filter matches based on mutual check and thresholding of scores
+    matches0 = torch.where(valid0, matches0, -1)
+    matches1 = torch.where(valid1, matches1, -1)
+    matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
+    matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
+
+    return matches, matching_scores
+
+
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
+    """
+    Normalize keypoints locations based on the image shape.
+
+    Args:
+        keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
+            Keypoints locations in (x, y) format.
+        height (`int`):
+            Image height.
+        width (`int`):
+            Image width.
+
+    Returns:
+        Normalized keypoints locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
+    """
+    size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
+    shift = size / 2
+    scale = size.max(-1).values / 2
+    keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
+    return keypoints
+
+
+@auto_docstring(
+    custom_intro="""
+    LightGlue model taking images as inputs and outputting the matches between them.
+    """
+)
+class LightGlueForKeypointMatching(LightGluePreTrainedModel):
+    """
+    LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
+    SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
+    It consists of:
+    1. Keypoint Encoder
+    2. A Graph Neural Network with self and cross attention layers
+    3. Matching Assignment layers
+
+    The correspondence ids use -1 to indicate non-matching points.
+
+    Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
+    In ICCV 2023.
https://huggingface.co/papers/2306.13643 + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + self.keypoint_detector = AutoModelForKeypointDetection.from_config( + config.keypoint_detector_config, trust_remote_code=config.trust_remote_code + ) + + self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != self.keypoint_detector_descriptor_dim: + self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. 
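`_get_confidence_threshold` above implements the exponentially decaying early-exit threshold from the LightGlue paper. A small sketch of the schedule, assuming the nine transformer layers of the released LightGlue checkpoints:

```python
import numpy as np

# Threshold schedule of `_get_confidence_threshold`, assuming num_layers = 9.
num_layers = 9
thresholds = [
    float(np.clip(0.8 + 0.1 * np.exp(-4.0 * layer_index / num_layers), 0, 1))
    for layer_index in range(num_layers)
]
print([round(t, 3) for t in thresholds])
# Decays from 0.9 towards ~0.80: later layers accept lower keypoint
# confidences, so early exit becomes easier as refinement progresses.
```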
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. + keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
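The pruning rule in `_get_pruning_mask` keeps a keypoint when it still looks matchable or when its confidence is too low to judge safely. A toy sketch with hand-picked values (the `width_confidence` and layer threshold here are illustrative, not checkpoint defaults):

```python
import torch

# Hand-picked toy values for the rule in `_get_pruning_mask`.
width_confidence = 0.99
layer_threshold = 0.9
matchability = torch.tensor([0.005, 0.5, 0.002])  # from `get_matchability`
confidences = torch.tensor([0.95, 0.5, 0.99])     # from the token confidence layer

keep = matchability > (1 - width_confidence)   # still matchable
keep |= confidences <= layer_threshold         # too uncertain to prune safely
print(keep)  # tensor([False,  True, False]) -> confident-but-unmatchable points are pruned
```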
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( + pad_sequence(pruned_tensor, batch_first=True) + for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] + ) + pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) + pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) + + return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output + + def _concat_early_stopped_outputs( + self, + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ): + early_stops_indices = torch.stack(early_stops_indices) + # Rearrange tensors to have the same order as the input batch + ids = torch.arange(early_stops_indices.shape[0]) + order_indices = early_stops_indices[ids] + early_stops_indices = early_stops_indices[order_indices] + matches, final_pruned_keypoints_indices = ( + pad_sequence(tensor, batch_first=True, padding_value=-1) + for tensor in [matches, final_pruned_keypoints_indices] + ) + matching_scores, final_pruned_keypoints_iterations = ( + pad_sequence(tensor, batch_first=True, padding_value=0) + for tensor in [matching_scores, final_pruned_keypoints_iterations] + ) + matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( + tensor[early_stops_indices] + for tensor in [ + matches, + matching_scores, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + ] + ) + return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores + + def _do_final_keypoint_pruning( + self, + indices: torch.Tensor, + matches: torch.Tensor, + matching_scores: torch.Tensor, + num_keypoints: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to + # have tensors from + batch_size, _ = indices.shape + indices, matches, matching_scores = ( + tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] + ) + indices0 = indices[:, 0] + indices1 = indices[:, 1] + matches0 = matches[:, 0] + matches1 = matches[:, 1] + matching_scores0 = matching_scores[:, 0] + matching_scores1 = matching_scores[:, 1] + + # Prepare final matches and matching scores + _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) + _matching_scores = torch.zeros( + (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype + ) + # Fill the matches and matching scores for each image pair + for i in range(batch_size // 2): + _matches[i, 0, indices0[i]] = torch.where( + matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) + ) + _matches[i, 1, indices1[i]] = torch.where( + matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) + ) + _matching_scores[i, 0, indices0[i]] = matching_scores0[i] + _matching_scores[i, 1, indices1[i]] = matching_scores1[i] + return _matches, _matching_scores + + def _match_image_pair( + self, + keypoints: torch.Tensor, + descriptors: torch.Tensor, + height: int, + width: int, + mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if 
keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. 
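`_match_image_pair` above, together with `forward` further below, drives the whole matching pipeline. A hedged end-to-end sketch, assuming the public `ETH-CVG/lightglue_superpoint` checkpoint and two local image files:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

# End-to-end sketch, assuming the "ETH-CVG/lightglue_superpoint" checkpoint
# and two local images; any pair of RGB images should work.
processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")

images = [Image.open("image0.jpg"), Image.open("image1.jpg")]
inputs = processor(images, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# `matches` uses -1 for unmatched keypoints, as documented above.
print(outputs.matches.shape, outputs.matching_scores.shape)
```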
+ early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + # If all pairs of images are early stopped, we stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or 
pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/llama4/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/llama4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59fe1686cec1bbad80233353820eb1517e7bb27b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/llama4/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_llama4 import * + from .image_processing_llama4_fast import * + from .modeling_llama4 import * + from .processing_llama4 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/llama4/configuration_llama4.py b/venv/lib/python3.13/site-packages/transformers/models/llama4/configuration_llama4.py new file mode 100644 index 0000000000000000000000000000000000000000..932f4975dba79d5ed97e34c2017f5ed154d34fd6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/llama4/configuration_llama4.py @@ -0,0 +1,479 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Llama4VisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
+    Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        num_hidden_layers (`int`, *optional*, defaults to 34):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input image.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        vision_output_dim (`int`, *optional*, defaults to 7680):
+            Dimensionality of the vision model output. Includes output of transformer
+            encoder with intermediate layers and global transformer encoder.
+        image_size (`int`, *optional*, defaults to 448):
+            The size (resolution) of each image *tile*.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size (resolution) of each patch.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+            The feature selection strategy used to select the vision feature from the vision backbone.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5):
+            Ratio used by the pixel shuffle operation to downsample the grid of vision tokens before projection.
+        projector_input_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the input of the multimodal projector.
+        projector_output_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the output of the multimodal projector.
+        multi_modal_projector_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias term in the multimodal projector layers.
+        projector_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability of the multimodal projector.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_theta (`float`, *optional*, defaults to 10000):
+            The base period of the RoPE embeddings.
+    """
+
+    base_model_tp_plan = {
+        "model.layers.*.self_attn.q_proj": "colwise",
+        "model.layers.*.self_attn.k_proj": "colwise",
+        "model.layers.*.self_attn.v_proj": "colwise",
+        "model.layers.*.self_attn.o_proj": "rowwise",
+        "vision_adapter.mlp.fc1": "colwise",
+        "vision_adapter.mlp.fc2": "rowwise",
+        "patch_embedding.linear": "colwise_rep",
+    }
+    model_type = "llama4_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        hidden_act: str = "gelu",
+        num_hidden_layers: int = 34,
+        num_attention_heads: int = 16,
+        num_channels: int = 3,
+        intermediate_size: int = 5632,
+        vision_output_dim: int = 7680,
+        image_size: int = 448,
+        patch_size: int = 14,
+        norm_eps: float = 1e-5,
+        vision_feature_select_strategy="default",
+        initializer_range: float = 0.02,
+        pixel_shuffle_ratio=0.5,
+        projector_input_dim=4096,
+        projector_output_dim=4096,
+        multi_modal_projector_bias=False,
+        projector_dropout=0.0,
+        attention_dropout=0.0,
+        rope_theta=10000,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.num_hidden_layers = num_hidden_layers
+        self.num_channels = num_channels
+        self.intermediate_size = intermediate_size
+        self.image_size = image_size
+        self.vision_output_dim = vision_output_dim
+        self.patch_size = patch_size
+        self.norm_eps = norm_eps
+        self.num_attention_heads = num_attention_heads
+        self.initializer_range = initializer_range
+        self.pixel_shuffle_ratio = pixel_shuffle_ratio
+        self.projector_input_dim = projector_input_dim
+        self.projector_output_dim = projector_output_dim
+        self.multi_modal_projector_bias = multi_modal_projector_bias
+        self.projector_dropout = projector_dropout
+        self.attention_dropout = attention_dropout
+        self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.rope_theta = rope_theta
+
+        self._vision_feature_layer = kwargs.get("vision_feature_layer", -1)
+
+        super().__init__(**kwargs)
+
+    @property
+    def vision_feature_layer(self):
+        warnings.warn(
+            "The `vision_feature_layer` attribute is deprecated and will be removed in v4.58.0.",
+            FutureWarning,
+        )
+        return self._vision_feature_layer
+
+    @vision_feature_layer.setter
+    def vision_feature_layer(self, value):
+        self._vision_feature_layer = value
+
+
+class Llama4TextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
+    Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 202048):
+            Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`Llama4TextModel`].
+        hidden_size (`int`, *optional*, defaults to 5120):
+            Dimensionality of the embeddings and hidden states.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        intermediate_size_mlp (`int`, *optional*, defaults to 16384):
+            Dimensionality of the feed-forward layer in the dense (non-MoE) Transformer blocks.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 40):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
+            specified, will default to `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 128):
+            The dimension of each attention head. If not specified, defaults to `hidden_size // num_attention_heads`.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the encoder and pooler.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions.
+        pad_token_id (`int`, *optional*, defaults to 128004):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the beginning of sentence token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the end of sentence token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to `500000.0`):
+            The base period of the RoPE embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 1):
+            The number of experts each token is routed to in the MoE layers.
+        num_local_experts (`int`, *optional*, defaults to 16):
+            The number of experts in each MoE layer.
+        moe_layers (`list[int]`, *optional*):
+            Indices of the layers that use a mixture-of-experts feed-forward block. If not specified, it is derived
+            from `interleave_moe_layer_step`.
+        interleave_moe_layer_step (`int`, *optional*, defaults to 1):
+            Step size at which MoE layers are interleaved with dense layers.
+        use_qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to apply normalization to the query and key states in the attention layers.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also allow the model
+            to output the auxiliary load balancing loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The auxiliary loss factor for the total loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.0):
+            Amount of noise to add to the router.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+            `rope_type` (`str`):
+                The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                'llama3'], with 'default' being the original RoPE implementation.
+            `factor` (`float`, *optional*):
+                Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                original maximum pre-trained length.
+            `original_max_position_embeddings` (`int`, *optional*):
+                Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                pretraining.
+            `attention_factor` (`float`, *optional*):
+                Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                computation. If unspecified, it defaults to value recommended by the implementation, using the
+                `factor` field to infer the suggested value.
+            `beta_fast` (`float`, *optional*):
+                Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                ramp function. If unspecified, it defaults to 32.
+            `beta_slow` (`float`, *optional*):
+                Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                ramp function. If unspecified, it defaults to 1.
+            `short_factor` (`list[float]`, *optional*):
+                Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                size divided by the number of attention heads divided by 2.
+            `long_factor` (`list[float]`, *optional*):
+                Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                size divided by the number of attention heads divided by 2.
+            `low_freq_factor` (`float`, *optional*):
+                Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+            `high_freq_factor` (`float`, *optional*):
+                Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        no_rope_layers (`list[int]`, *optional*):
+            List with at least the same length as the number of layers in the model.
+            A `1` at an index position indicates that the corresponding layer will use RoPE,
+            while a `0` indicates that it's a NoPE layer.
+        no_rope_layer_interval (`int`, *optional*, defaults to 4):
+            If `no_rope_layers` is `None`, it will be created using a NoPE layer every
+            `no_rope_layer_interval` layers.
+        attention_chunk_size (`int`, *optional*, defaults to 8192):
+            Size of the local attention chunks used by the layers with chunked attention.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
+            Whether to dynamically scale the attention temperature for each query token based on sequence length.
+            Recommended for long sequences (e.g., >32k tokens) to maintain stable output results.
+        floor_scale (`int`, *optional*, defaults to 8192):
+            Scale of the floor operation used when computing the position-dependent attention temperature.
+        attn_scale (`float`, *optional*, defaults to 0.1):
+            Scaling coefficient applied to the attention temperature tuning term.
+
+    Example:
+
+    >>> from transformers import Llama4TextModel, Llama4TextConfig
+
+    >>> # Initializing a Llama4 text configuration
+    >>> configuration = Llama4TextConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = Llama4TextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    """
+
+    model_type = "llama4_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
+        "layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
+        "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
+        "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",  # row because not linear
+        "layers.*.feed_forward.experts.down_proj": "local_colwise",  # col because not linear
+        "layers.*.feed_forward.experts": "local",
+        "layers.*.feed_forward.gate_proj": "local_colwise",
+        "layers.*.feed_forward.up_proj": "local_colwise",
+        "layers.*.feed_forward.down_proj": "local_rowwise",
+        "layers.*.feed_forward": "gather",
+    }
+    base_model_ep_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm",  # row because not linear
+        "layers.*.feed_forward.experts.down_proj": "grouped_gemm",  # col because not linear
+        "layers.*.feed_forward.experts": "gather",  # all reduce
+        "layers.*.feed_forward.gate_proj": "local_colwise",
+        "layers.*.feed_forward.up_proj": "local_colwise",
+        "layers.*.feed_forward.down_proj": "local_rowwise",
+        "layers.*.feed_forward.router": "ep_router",
+    }
+
+    def __init__(
+        self,
+        vocab_size=202048,
+        hidden_size=5120,
+        intermediate_size=8192,
+        intermediate_size_mlp=16384,
+        num_hidden_layers=48,
+        num_attention_heads=40,
+        num_key_value_heads=8,
+        head_dim=128,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=500000,
+        attention_dropout=0.0,
+        num_experts_per_tok=1,
+        num_local_experts=16,
+        moe_layers=None,
+        interleave_moe_layer_step=1,
+        use_qk_norm=True,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
+        rope_scaling=None,
+        no_rope_layers=None,
+        no_rope_layer_interval=4,
+        attention_chunk_size=8192,
+        layer_types=None,
+        attn_temperature_tuning=True,
+        floor_scale=8192,
+        attn_scale=0.1,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.attn_temperature_tuning = attn_temperature_tuning
+        self.attn_scale = attn_scale
+        self.floor_scale = floor_scale
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.intermediate_size_mlp = intermediate_size_mlp
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.rope_scaling = rope_scaling
+        self.attention_bias = False
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.use_qk_norm = use_qk_norm
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+
+        # Backwards compatibility
+        if no_rope_layers == []:
+            no_rope_layers = None
+
+        default_no_rope_layers = [
+            int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
+        ]
+
+        self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
+
+        self.interleave_moe_layer_step = interleave_moe_layer_step
+        self.moe_layers = (
+            moe_layers
+            if moe_layers is not None
+            else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
+        )
+        self.attention_chunk_size = attention_chunk_size
+
+        self.layer_types = layer_types
+        if layer_types is None:
+            self.layer_types = [
+                "chunked_attention" if no_rope else "full_attention" for no_rope in self.no_rope_layers
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Llama4Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate a
+    Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (`Llama4VisionConfig`, *optional*):
+            The Llama4 Vision config.
+        text_config (`Llama4TextConfig`, *optional*):
+            The Llama4 Text config.
+        boi_token_index (`int`, *optional*, defaults to 200080):
+            The begin-of-image token index to wrap the image prompt.
+        eoi_token_index (`int`, *optional*, defaults to 200081):
+            The end-of-image token index to wrap the image prompt.
+        image_token_index (`int`, *optional*, defaults to 200092):
+            The image token index to encode the image prompt.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+ + ```python + >>> from transformers import Llama4Model, Llama4Config + + >>> # Initializing a Llama4 7B style configuration + >>> configuration = Llama4Config() + + >>> # Initializing a model from the Llama4 7B style configuration + >>> model = Llama4Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama4" + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } + sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig} + base_model_tp_plan = { + "multi_modal_projector.linear_1": "colwise_rep", + } + + def __init__( + self, + vision_config=None, + text_config=None, + boi_token_index=200080, + eoi_token_index=200081, + image_token_index=200092, + tie_word_embeddings=False, + **kwargs, + ): + if vision_config is None: + self.vision_config = Llama4VisionConfig() + logger.info("vision_config is None, using default llama4 vision config") + elif isinstance(vision_config, dict): + self.vision_config = Llama4VisionConfig(**vision_config) + elif isinstance(vision_config, Llama4VisionConfig): + self.vision_config = vision_config + + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + if text_config is None: + self.text_config = Llama4TextConfig() + logger.info("text_config is None, using default llama4 text config") + elif isinstance(text_config, dict): + self.text_config = Llama4TextConfig(**text_config) + elif isinstance(text_config, Llama4TextConfig): + self.text_config = text_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/llama4/image_processing_llama4_fast.py b/venv/lib/python3.13/site-packages/transformers/models/llama4/image_processing_llama4_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..6506d5749d946194c5d6757d6680ce50144bdd78 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/llama4/image_processing_llama4_fast.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
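`Llama4Config` above composes the text and vision sub-configs and exposes token-id aliases through `attribute_map`. A minimal sketch of that composition (the tiny layer counts are illustrative only, not trained defaults):

```python
from transformers import Llama4Config

# Composing the nested configs defined above; tiny sizes for illustration.
config = Llama4Config(
    text_config={"num_hidden_layers": 4, "no_rope_layer_interval": 4},
    vision_config={"num_hidden_layers": 2},
)
# With the default interval, every 4th layer is a NoPE/full-attention layer.
print(config.text_config.layer_types)
# ['chunked_attention', 'chunked_attention', 'chunked_attention', 'full_attention']

# `attribute_map` aliases `image_token_id` to `image_token_index`.
print(config.image_token_id)  # 200092
```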
+"""Fast Image processor class for Got-OCR-2.""" + +import math +from collections import defaultdict +from functools import lru_cache +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) + + +def get_factors(dividend: int) -> set[int]: + """ + Calculate all factors of a given number, i.e. a divisor that leaves + no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}. + + Args: + dividend (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + """ + factors_set = set() + + for i in range(1, int(dividend**0.5) + 1): + if dividend % i == 0: + factors_set.add(i) + factors_set.add(dividend // i) + return factors_set + + +def get_max_res_without_distortion( + image_size: tuple[int, int], + target_size: tuple[int, int], +) -> tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + Args: + image_size (tuple[int, int]): The original resolution of the image (height, width). + target_resolution (tuple[int, int]): The desired resolution to fit the image into (height, width). + Returns: + tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized. + Example: + >>> _get_max_res_without_distortion([200, 300], target_size = [450, 200]) + (134, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300]) + (450, 338) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width + + +def split_to_tiles(images: torch.Tensor, num_tiles_height: int, num_tiles_width: int) -> torch.Tensor: + # Split image into number of required tiles (width x height) + batch_size, num_channels, height, width = images.size() + images = images.view( + batch_size, + num_channels, + num_tiles_height, + height // num_tiles_height, + num_tiles_width, + width // num_tiles_width, + ) + # Permute dimensions to reorder the axes + image = images.permute(0, 2, 4, 1, 3, 5).contiguous() + # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) + image = image.view( + batch_size, + num_tiles_width * num_tiles_height, + num_channels, + height // num_tiles_height, + width // num_tiles_width, + ) + return image + + +@lru_cache(maxsize=1) +def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> torch.Tensor: + """ + Computes all of the allowed resolutions for a fixed number of chunks + and patch_size. Useful for when dividing an image into chunks. + + Args: + max_num_chunks (int): Maximum number of chunks for processing. + patch_size (int): Size of the side of the patch. 
+
+    Returns:
+        list: List of possible resolutions as tuples (height, width).
+
+    Example:
+        >>> max_num_chunks = 4
+        >>> patch_size = 224
+        >>> find_supported_resolutions(max_num_chunks, patch_size)
+        [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
+        (672, 224), (224, 448), (448, 224)]
+
+        Given max_num_chunks=4, patch_size=224, it will create a dictionary:
+        {
+            0.25: [(1, 4)],
+            1.0: [(2, 2), (1, 1)],
+            4.0: [(4, 1)],
+            0.33: [(1, 3)],
+            3.0: [(3, 1)],
+            0.5: [(1, 2)],
+            2.0: [(2, 1)]
+        }
+
+        and return the resolutions multiplied by the patch_size:
+        [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
+    """
+    height, width = patch_size.height, patch_size.width
+    if height != width:
+        raise ValueError("`size` must be square.")
+
+    patch_size = height
+
+    asp_dict = defaultdict(list)
+    for chunk_size in range(max_num_chunks, 0, -1):
+        _factors = sorted(get_factors(chunk_size))
+        _asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
+        for height, width in _asp_ratios:
+            ratio_float = height / width
+            asp_dict[ratio_float].append((height, width))
+
+    # get the resolutions multiplied by the patch_size
+    possible_resolutions = []
+    for value in asp_dict.values():
+        for height, depth in value:
+            possible_resolutions.append((height * patch_size, depth * patch_size))
+
+    return possible_resolutions
+
+
+def pad_to_best_fit(
+    images: "torch.Tensor",
+    target_size: tuple[int, int],
+    background_color: Union[int, tuple[int, int, int]] = 0,
+) -> "torch.Tensor":
+    """
+    Pads an image to fit the target size.
+
+    Args:
+        images (`torch.Tensor`):
+            The images to pad.
+        target_size (`tuple[int, int]`):
+            The (height, width) to pad the images to.
+        background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+            The color to use for the padding. Can be an integer for single-channel images or a
+            tuple of integers for multi-channel images. If passed as an integer
+            in multi-channel mode, it will default to `0` in subsequent channels.
+    Returns:
+        `torch.Tensor`: The padded images.
+    """
+
+    num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
+    if isinstance(background_color, int):
+        background_color = [background_color] + [0] * (num_channels - 1)
+    elif len(background_color) != num_channels:
+        raise ValueError(
+            f"background_color must have no more than {num_channels} elements to match the number of channels"
+        )
+
+    height, width = images.shape[-2:]
+    target_height, target_width = target_size
+    paste_x_right = target_width - width
+    paste_y_right = target_height - height
+    padded_images = F.pad(images, padding=[0, 0, paste_x_right, paste_y_right], fill=background_color)
+
+    return padded_images
+
+
+def get_best_fit(
+    image_size: tuple[int, int],
+    possible_resolutions: torch.Tensor,
+    resize_to_max_canvas: bool = False,
+) -> tuple[int, int]:
+    """
+    Determines the best canvas, from a list of possible resolutions, onto which an image can be
+    resized without distortion.
+
+    For each possible resolution, calculates the scaling factors for
+    width and height, and selects the smallest one, which is the limiting side.
+    E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
+    therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
+
+    If upscaling is possible (any of the scaling factors is greater than 1),
+    then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
+
+    If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
+    reduces downscaling as much as possible.
+
+    If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
+    to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
+    has more padding.
+
+    Args:
+        image_size (tuple[int, int]): A tuple containing the height and width of the image.
+        possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
+            row represents a possible resolution (height, width).
+        resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
+
+    Returns:
+        list[int]: The best resolution [height, width] for the given image.
+
+    Example:
+        >>> image_size = (200, 300)
+        >>> possible_resolutions = torch.tensor([[224, 672],
+        ...                                     [672, 224],
+        ...                                     [224, 448],
+        ...                                     [448, 224],
+        ...                                     [224, 224]])
+        >>> get_best_fit(image_size, possible_resolutions)
+        [224, 448]
+
+        We have:
+        scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
+        scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
+        scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
+        Two of the scales are > 1:
+        upscaling_possible = tensor([1.1200, 1.1200])
+        smallest_rescale = tensor(1.1200)
+        So we pick the resolution with the smallest area:
+        areas = tensor([150528, 100352])  # [224, 672], [224, 448]
+        optimal_canvas = tensor([224, 448])
+    """
+
+    original_height, original_width = image_size
+
+    # get all possible resolutions heights/widths
+    target_heights, target_widths = (
+        possible_resolutions[:, 0],
+        possible_resolutions[:, 1],
+    )
+
+    # get scaling factors to resize the image without distortion
+    scale_w = target_widths / original_width
+    scale_h = target_heights / original_height
+
+    # get the min scale between width and height (limiting side -> no distortion)
+    scales = torch.where(scale_h > scale_w, scale_w, scale_h)
+
+    # filter only scales that allow upscaling
+    upscaling_options = scales[scales >= 1]
+    if len(upscaling_options) > 0:
+        if resize_to_max_canvas:
+            selected_scale = torch.max(upscaling_options)
+        else:
+            selected_scale = torch.min(upscaling_options)
+    else:
+        # no upscaling possible,
+        # get the minimum downscaling (max scale for scales<1)
+        downscaling_options = scales[scales < 1]
+        selected_scale = torch.max(downscaling_options)
+
+    # get all resolutions that support this scaling factor,
+    # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
+    chosen_canvas = possible_resolutions[scales == selected_scale]
+
+    # if there are multiple resolutions,
+    # get the one with minimum area to reduce padding
+    if len(chosen_canvas) > 1:
+        areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
+        optimal_idx = torch.argmin(areas)
+        optimal_canvas = chosen_canvas[optimal_idx]
+    else:
+        optimal_canvas = chosen_canvas[0]
+
+    return optimal_canvas
+
+
+class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    """
+    max_patches (`int`, *optional*, defaults to 16):
+        The maximum number of patches to be extracted from the image.
+        Can be overridden by the `max_patches` parameter in the `preprocess` method.
+    resize_to_max_canvas (`bool`, *optional*, defaults to `False`):
+        Whether to resize the image to the maximum canvas size.
+        If True, picks the canvas that allows the largest resizing without distortion.
+        If False, downsamples as little as possible, including no resizing at all,
+        but never upsamples, unless the image is smaller than the patch size.
+ """ + + max_patches: Optional[int] + resize_to_max_canvas: Optional[bool] + + +@auto_docstring +class Llama4ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + size = {"height": 336, "width": 336} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + max_patches = 16 + resize_to_max_canvas = False + valid_kwargs = Llama4ImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]): + super().__init__(**kwargs) + + # Disable compilation here as conversion to bfloat16 causes differences in the output of the compiled and non-compiled versions + @torch.compiler.disable + def rescale_and_normalize( + self, + images: "torch.Tensor", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, list[float]], + image_std: Union[float, list[float]], + ) -> "torch.Tensor": + """ + Rescale and normalize images. + Override to rescale and normalize the images in torch.bfloat16 as in the original implementation + """ + if do_rescale and do_normalize: + images = images.to(dtype=torch.bfloat16) * rescale_factor + images = self.normalize(images, image_mean, image_std) + elif do_rescale: + images = images * rescale_factor + elif do_normalize: + images = self.normalize(images, image_mean, image_std) + + return images + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + size: SizeDict, + max_patches: int, + resize_to_max_canvas: bool, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size) + possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device) + # process images by batch, grouped by shape + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + grouped_processed_images = {} + grouped_aspect_ratios = {} + for shape, stacked_images in grouped_images.items(): + image_size = stacked_images.shape[-2:] + target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas) + # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size + max_upscaling_size = None if resize_to_max_canvas else size.height + if max_upscaling_size is not None: + new_target_height = min(max(image_size[0], max_upscaling_size), target_size[0]) + new_target_width = min(max(image_size[1], max_upscaling_size), target_size[1]) + target_size_without_distortion = (new_target_height, new_target_width) + + # resize to target_size while preserving aspect ratio + new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion) + new_size_without_distortion = SizeDict( + height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1) + ) + processed_images = self.resize( + stacked_images, + new_size_without_distortion, + interpolation=interpolation, + ) + + # pad to target_size to be able to split 
into tiles + processed_images = pad_to_best_fit(processed_images, target_size) + processed_images = self.rescale_and_normalize( + processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + + ratio_h, ratio_w = ( + target_size[0] // size.height, + target_size[1] // size.width, + ) + # split into tiles + processed_images = split_to_tiles(processed_images, ratio_h, ratio_w) + grouped_processed_images[shape] = processed_images + grouped_aspect_ratios[shape] = torch.tensor( + [[ratio_h, ratio_w]] * stacked_images.shape[0], device=images[0].device + ) + + # add a global tile to the processed tile if there are more than one tile + if ratio_h * ratio_w > 1: + global_tiles = self.resize( + stacked_images, + size, + interpolation=interpolation, + ) + global_tiles = self.rescale_and_normalize( + global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1) + processed_images = reorder_images(grouped_processed_images, grouped_images_index) + aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index) + + processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images + aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list + return BatchFeature( + data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors + ) + + +__all__ = ["Llama4ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/llama4/modeling_llama4.py b/venv/lib/python3.13/site-packages/transformers/models/llama4/modeling_llama4.py new file mode 100644 index 0000000000000000000000000000000000000000..a974ed81ba2fbf4c9ef934c6b1fbf4ab54bcc4d4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/llama4/modeling_llama4.py @@ -0,0 +1,1391 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
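+"""PyTorch Llama4 model.
+
+This module defines the Llama4 text decoder (dense and mixture-of-experts feed-forward layers,
+with RoPE skipped on some layers and full or chunked causal attention chosen per layer), the
+Llama4 vision tower, the multimodal projector, and `Llama4ForConditionalGeneration`, which ties
+them together.
+
+A minimal usage sketch; the checkpoint id below is a placeholder, not a real repository, and the
+example image URL is the one reused in the docstrings further down this file:
+
+    from PIL import Image
+    import requests
+    from transformers import AutoProcessor, Llama4ForConditionalGeneration
+
+    model_id = "<org>/<llama4-checkpoint>"  # placeholder, replace with a real checkpoint
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = Llama4ForConditionalGeneration.from_pretrained(model_id)
+
+    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+    inputs = processor(images=image, text="<|image|> Describe this image.", return_tensors="pt")
+
+    generated = model.generate(**inputs, max_new_tokens=32)
+    print(processor.batch_decode(generated, skip_special_tokens=True)[0])
+"""
+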
+import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_chunked_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_llama4 import Llama4Config, Llama4TextConfig + + +logger = logging.get_logger(__name__) + + +class Llama4TextExperts(nn.Module): + def __init__(self, config: Llama4TextConfig): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. 
+ - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + hidden_states = hidden_states.view(self.gate_up_proj.shape[0], -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + next_states = next_states.view(-1, self.hidden_size) + return next_states + + +# Phi3MLP +class Llama4TextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + + if intermediate_size is None: + intermediate_size = config.intermediate_size + + self.config = config + self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) + return self.down_proj(down_proj) + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class Llama4TextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + Llama4RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Llama4Router(nn.Linear): + def __init__(self, config): + super().__init__(config.hidden_size, config.num_local_experts, bias=False) + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + def forward(self, hidden_states): + router_logits = super().forward(hidden_states) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) + router_scores = torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value) + router_scores = torch.nn.functional.sigmoid(router_scores.float()).to(router_scores.dtype) + return router_scores, router_logits + + +@use_kernel_forward_from_hub("Llama4TextMoe") +class Llama4TextMoe(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.experts = Llama4TextExperts(config) + self.router = Llama4Router(config) + self.shared_expert = Llama4TextMLP(config) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + routed_in = hidden_states.repeat(router_scores.shape[1], 1) + routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1) + routed_out = self.experts(routed_in) + out 
= self.shared_expert(hidden_states) + out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0)) + return out, router_logits + + +class Llama4TextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Llama4TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + self.rope_type = "llama3" if config.rope_scaling is not None else "default" + + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + freqs_cis = freqs_cis * self.attention_scaling + + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32 +def vision_eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5 + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Llama4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Llama4TextConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.use_rope = config.no_rope_layers[layer_idx] + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( 
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://huggingface.co/papers/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log1p(torch.floor((cache_position.float() + 1.0) / self.floor_scale)) * self.attn_scale + 1.0 + ) + attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Llama4TextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Llama4TextAttention(config, layer_idx) + self.is_moe_layer = layer_idx in config.moe_layers + if self.is_moe_layer: # the 128E model interleaves dense / sparse + self.feed_forward = Llama4TextMoe(config) + else: + self.feed_forward = Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp) + + self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = 
Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attention_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + attention_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + if self.is_moe_layer: + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states.view(residual.shape) + return hidden_states + + +@auto_docstring +class Llama4PreTrainedModel(PreTrainedModel): + config: Llama4Config + supports_gradient_checkpointing = True + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Llama4TextRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Llama4TextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + elif isinstance(module, Llama4VisionModel): + module.class_embedding.data.normal_(std=module.scale) + module.positional_embedding_vlm.data.normal_(std=module.scale) + + +@auto_docstring +class Llama4TextModel(Llama4PreTrainedModel): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "model" + config: Llama4TextConfig + _can_record_outputs = { + "attentions": Llama4TextAttention, + "hidden_states": Llama4TextDecoderLayer, + "router_logits": Llama4TextMoe, + } + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Llama4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Llama4TextRotaryEmbedding(config=config) + 
self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "chunked_attention": create_chunked_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + freq_cis = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=freq_cis, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "language_model" + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + config: Llama4TextConfig + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.model = Llama4TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: 
Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Llama4ForCausalLM + + >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava causal language model (or autoregressive) outputs. + """ +) +class Llama4CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Llama4VisionMLP2(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False) + self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False) + self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] + self.dropout = config.projector_dropout + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + return self.activation_fn(self.fc2(hidden_states)) + + +class Llama4MultiModalProjector(nn.Module): + def __init__(self, config): + super().__init__() + self.linear_1 = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=False, + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + return hidden_states + + +def pixel_shuffle(input_tensor, shuffle_ratio): + # input_tensor: [batch_size, num_patches, channels] + batch_size, num_patches, channels = input_tensor.shape + patch_size = int(math.sqrt(num_patches)) + + input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) + batch_size, height, width, channels = input_tensor.size() + + reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + reshaped_tensor = reshaped_tensor.view( + batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2)) + ) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) + return output_tensor + + +class Llama4VisionPixelShuffleMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.pixel_shuffle_ratio = config.pixel_shuffle_ratio + self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2)) + self.output_dim = config.projector_output_dim + self.mlp = Llama4VisionMLP2(config) + + def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: + encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) + return self.mlp(encoded_patches) + + +# TODO there is a different RoPE for vision encoder, defined as below +def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): + ndim = query.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)] + return freqs_ci.view(*shape) + + +def vision_apply_rotary_emb( + query: torch.Tensor, + key: torch.Tensor, + freqs_ci: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) + freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_) # freqs_ci[:,:,None,:] + freqs_ci = freqs_ci.to(query_.device) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = 
torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) # but this drops to 8e-3 + + +class Llama4VisionAttention(nn.Module): + def __init__(self, config: Llama4VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_groups = 1 + self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attention_interface: Callable = vision_eager_attention_forward + # flex disable because breaks on TP 8, embed is 88 not power of 2 + if self.config._attn_implementation not in ["eager", "flex_attention"]: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim) + is_causal=False, # HAS TO BE ENFORCED + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Llama4VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Llama4VisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Llama4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Llama4VisionAttention(config) + self.mlp = Llama4VisionMLP(config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_state: torch.Tensor, + freqs_ci: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, 
+ output_attentions: Optional[bool] = None, + ): + # Self Attention + residual = hidden_state + + hidden_state = self.input_layernorm(hidden_state) + + hidden_state, attn_weights = self.self_attn( + hidden_state, + freqs_ci=freqs_ci, + attention_mask=attention_mask, + ) + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + + outputs = (hidden_state,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Llama4VisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Llama4VisionEncoderLayer`]. + + Args: + config: Llama4VisionConfig + """ + + def __init__(self, config: Llama4VisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + freqs_ci=freqs_ci, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Llama4UnfoldConvolution(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = config.patch_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) + self.linear = nn.Linear( + config.num_channels * kernel_size[0] * kernel_size[1], + config.hidden_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.unfold(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.linear(hidden_states) + return hidden_states + + +class Llama4VisionRotaryEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + idx = config.image_size // config.patch_size + img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # ID_CLS_TOKEN + frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x + frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y + freq_dim = config.hidden_size // config.num_attention_heads // 2 + rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)) + freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 + + def forward(self, hidden_states): + return self.freqs_ci.to(hidden_states.device) + + +class Llama4VisionModel(Llama4PreTrainedModel): + base_model_prefix = "vision_model" + _no_split_modules = ["Llama4VisionEncoderLayer"] + config: Llama4VisionConfig + + def __init__(self, config: Llama4VisionConfig): + super().__init__(config) + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + 
self.patch_embedding = Llama4UnfoldConvolution(config) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.positional_embedding_vlm = nn.Parameter(self.scale * torch.randn(self.num_patches, self.hidden_size)) + self.rotary_embedding = Llama4VisionRotaryEmbedding(config) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.model = Llama4VisionEncoder(config) + self.vision_adapter = Llama4VisionPixelShuffleMLP(config) + self.post_init() + + def get_input_embeddings(self): + """ + This function is used to fetch the first embedding layer to activate grads on inputs. + """ + return self.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]: + r""" + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MllamaVisionModel + + >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" + >>> model = MllamaVisionModel.from_pretrained(checkpoint) + >>> processor = AutoProcessor.from_pretrained(checkpoint) + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + + >>> output = model(**inputs) + + >>> print(output.last_hidden_state.shape) + torch.Size([1, 1, 4, 1025, 7680]) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # num_concurrent_media and num_chunks are both currently 1 + batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape + num_concurrent_media = 1 + num_chunks = 1 + hidden_state = self.patch_embedding(pixel_values) + _, num_patches, hidden_dim = hidden_state.shape + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim + ) + class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1]) + hidden_state = torch.cat([hidden_state, class_embedding], dim=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim + ) + positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device) + hidden_state = hidden_state + positional_embedding + + hidden_state = self.layernorm_pre(hidden_state) + + hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) + freqs_ci = self.rotary_embedding(pixel_values) + + output = self.model( + hidden_state, + attention_mask=None, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + freqs_ci=freqs_ci, + ) + + hidden_state = output.last_hidden_state + + hidden_state = self.layernorm_post(hidden_state) + + hidden_state = hidden_state[:, :-1, :] + + # now, we use Llama4VisionPixelShuffle + mlp to project embeddings + hidden_state = 
self.vision_adapter(hidden_state) + + hidden_states = output.hidden_states if output_hidden_states else None + + if output_attentions: + attentions = output[2] + else: + attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + +class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): + _no_split_modules = ["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"] + _tp_plan = {} + base_model_prefix = "" + config: Llama4Config + + def __init__(self, config: Llama4Config): + super().__init__(config) + self.vision_model = Llama4VisionModel(config.vision_config) + + self.multi_modal_projector = Llama4MultiModalProjector(config) + self.language_model = Llama4ForCausalLM(config.text_config) + self.vocab_size = config.text_config.vocab_size + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_select_strategy: str, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply al projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}") + kwargs = {k: v for k, v in kwargs.items() if v is not None} + image_outputs = self.vision_model(pixel_values, output_hidden_states=False, **kwargs) + hidden_state = image_outputs.last_hidden_state + return hidden_state + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + return special_image_mask + + @auto_docstring + @deprecate_kwarg("vision_feature_layer", version="4.58") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Llama4CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat).to( + inputs_embeds.device, inputs_embeds.dtype + ) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Llama4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = [ + "Llama4PreTrainedModel", + "Llama4TextModel", + "Llama4VisionModel", + "Llama4ForCausalLM", + "Llama4ForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/llama4/processing_llama4.py b/venv/lib/python3.13/site-packages/transformers/models/llama4/processing_llama4.py new file mode 100644 index 0000000000000000000000000000000000000000..ce590bc6f40bb9bcbb7283ecb78256d0e56361bf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/llama4/processing_llama4.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
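+"""Processor class for Llama4.
+
+`Llama4Processor` (defined below) pairs the Llama4 image processor with the tokenizer and uses
+special tokens such as `<|image|>` and `<|IMG_PATCH|>` to represent images and image patches in
+the text prompt.
+"""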
+ + +from typing import Optional, Union + +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput, make_flat_list_of_images + + +class Llama4ImagesKwargs(ImagesKwargs, total=False): + max_patches: Optional[int] + resize_to_max_canvas: Optional[bool] + + +class Llama4ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Llama4ImagesKwargs + _defaults = { + "text_kwargs": { + "padding_side": "left", + }, + } + + +chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\n\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\n\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- \"<|eot|>\" }}\n    {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n        {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n        {{- '<|python_start|>' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|python_end|>' }}\n        {%- for tool_call in message.tool_calls %}\n            {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n            {{- '\"parameters\": ' }}\n            {{- tool_call.function.arguments | tojson }}\n            {{- \"}\" }}\n        {%- endfor %}\n        {{- \"<|eot|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"
+
+
+class Llama4Processor(ProcessorMixin):
+    r"""
+    Constructs a Llama4 processor which wraps an [`AutoImageProcessor`] and
+    [`PreTrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
+    tokenizer functionalities. See the [`~Llama4Processor.__call__`] and [`~Llama4Processor.decode`] for more information.
+    Args:
+        image_processor ([`AutoImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size of image patches for tokenization.
+        pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5):
+            The pixel shuffle ratio of the vision encoder; the number of patch tokens emitted per image chunk is
+            downsampled by `1 / pixel_shuffle_ratio**2`.
+        fake_image_token (`str`, *optional*, defaults to `"<|image|>"`):
+            The placeholder token in the prompt that gets expanded into the full sequence of image tokens.
+        image_token (`str`, *optional*, defaults to `"<|image|>"`):
+            The token to be used to represent an image in the text.
+        start_of_image_token (`str`, *optional*, defaults to `"<|image_start|>"`):
+            The token to be used to represent the start of an image in the text.
+        end_of_image_token (`str`, *optional*, defaults to `"<|image_end|>"`):
+            The token to be used to represent the end of an image in the text.
+        patch_token (`str`, *optional*, defaults to `"<|patch|>"`):
+            The token to be used to represent an image patch in the text.
+        tile_x_separator_token (`str`, *optional*, defaults to `"<|tile_x_separator|>"`):
+            The token used to separate horizontally adjacent tiles of a tiled image.
+        tile_y_separator_token (`str`, *optional*, defaults to `"<|tile_y_separator|>"`):
+            The token appended after each row of tiles in a tiled image.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        patch_size: int = 14,
+        pixel_shuffle_ratio: float = 0.5,
+        fake_image_token="<|image|>",
+        image_token="<|image|>",
+        start_of_image_token="<|image_start|>",
+        end_of_image_token="<|image_end|>",
+        patch_token="<|patch|>",
+        tile_x_separator_token="<|tile_x_separator|>",
+        tile_y_separator_token="<|tile_y_separator|>",
+        chat_template=chat_template,
+        **kwargs,
+    ):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+        self.downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
+        self.patch_size = patch_size
+
+        self.fake_image_token = fake_image_token
+        self.image_token = image_token
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.start_of_img_token = start_of_image_token
+        self.end_of_img_token = end_of_image_token
+        self.img_patch_token = patch_token
+        self.tile_token = tile_x_separator_token
+        self.tile_global_token = tile_y_separator_token
+
+    def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
+        """
+        Create a structured string representation of image tokens
+
+        Args:
+            aspect_ratio: Tuple `(ratio_h, ratio_w)` describing the tile grid of the image
+            num_patches_per_chunk: Number of patch tokens emitted per image chunk
+
+        Returns:
+            String with appropriate image tokens
+        """
+        img_string = "<|image_start|>"
+        ratio_h, ratio_w = aspect_ratio
+        if ratio_h * ratio_w > 1:
+            for yy in range(ratio_h):
+                for xx in range(ratio_w):
+                    img_string += "<|patch|>" * num_patches_per_chunk
+                    if xx < ratio_w - 1:
+                        img_string += "<|tile_x_separator|>"
+
+                img_string += "<|tile_y_separator|>"
+
+        img_string += "<|image|>"
+        img_string += "<|patch|>" * num_patches_per_chunk
+        img_string += "<|image_end|>"
+
+        return img_string
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Llama4ProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
+        To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
+        Llama4ImageProcessor's [`~Llama4ImageProcessor.__call__`] if `images` is not `None`.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if text is None: + raise ValueError("You have to specify text.") + + output_kwargs = self._merge_kwargs( + Llama4ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if not isinstance(text, (list, tuple)): + text = [text] + + # Process images + image_inputs = {} + if images is not None: + images = self.image_processor.fetch_images(images) + images = make_flat_list_of_images(images) + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_height, image_width = image_inputs["pixel_values"][0].shape[-2:] + num_patches_per_chunk = int( + (image_height // self.patch_size) * (image_width // self.patch_size) // self.downsample_ratio + ) + aspect_ratios = image_inputs.pop("aspect_ratios") + + total_placeholders = sum(prompt.count(self.fake_image_token) for prompt in text) + if total_placeholders != len(images): + raise ValueError( + f"Found {total_placeholders} placeholders across the batch, " + f"but have {len(images)} flattened images." + ) + + image_index = 0 + processed_text = [] + for prompt in text: + placeholder_count = prompt.count(self.fake_image_token) + if placeholder_count == 0: + # do nothing if there is no image + processed_text.append(prompt) + continue + prompt_splits = prompt.split(self.fake_image_token) + new_prompt = [] + for local_image_index, split_part in enumerate(prompt_splits): + new_prompt.append(split_part) + if local_image_index < placeholder_count: + tokens_for_this_image = self._prompt_split_image( + aspect_ratios[image_index], num_patches_per_chunk + ) + image_index += 1 + new_prompt.append(tokens_for_this_image) + processed_text.append("".join(new_prompt)) + + if image_index != len(images): + raise ValueError("Number of image placeholders in the prompt does not match the number of images.") + + text = processed_text + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + +__all__ = ["Llama4Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87f53105424b76e5c18bd740ecfdd37a5b29d0d4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_longformer import *
+    from .modeling_longformer import *
+    from .modeling_tf_longformer import *
+    from .tokenization_longformer import *
+    from .tokenization_longformer_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/configuration_longformer.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/configuration_longformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..207cc18394796352a0d42bf28082baf04d11ff8e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/configuration_longformer.py
@@ -0,0 +1,207 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Longformer configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import TensorType, logging
+
+
+if TYPE_CHECKING:
+    from ...onnx.config import PatchingSpec
+    from ...tokenization_utils_base import PreTrainedTokenizerBase
+
+
+logger = logging.get_logger(__name__)
+
+
+class LongformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LongformerModel`] or a [`TFLongformerModel`].
+    It is used to instantiate a Longformer model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    LongFormer [allenai/longformer-base-4096](https://huggingface.co/allenai/longformer-base-4096) architecture with
+    a sequence length of 4,096.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+ + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Longformer model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`LongformerModel`] or [`TFLongformerModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`LongformerModel`] or + [`TFLongformerModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + attention_window (`int` or `list[int]`, *optional*, defaults to 512): + Size of an attention window around each token. If an `int`, use the same size for all layers. To specify a + different window size for each layer, use a `list[int]` where `len(attention_window) == num_hidden_layers`. 
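+        onnx_export (`bool`, *optional*, defaults to `False`):
+            Whether the configuration is being used for an ONNX export; when `True`, the attention layers fall back
+            to an export-friendly (but slower) chunking implementation instead of `as_strided`.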
+ + Example: + + ```python + >>> from transformers import LongformerConfig, LongformerModel + + >>> # Initializing a Longformer configuration + >>> configuration = LongformerConfig() + + >>> # Initializing a model from the configuration + >>> model = LongformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "longformer" + + def __init__( + self, + attention_window: Union[list[int], int] = 512, + sep_token_id: int = 2, + pad_token_id: int = 1, + bos_token_id: int = 0, + eos_token_id: int = 2, + vocab_size: int = 30522, + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + attention_probs_dropout_prob: float = 0.1, + max_position_embeddings: int = 512, + type_vocab_size: int = 2, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + onnx_export: bool = False, + **kwargs, + ): + """Constructs LongformerConfig.""" + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.attention_window = attention_window + self.sep_token_id = sep_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.onnx_export = onnx_export + + +class LongformerOnnxConfig(OnnxConfig): + def __init__( + self, config: "PretrainedConfig", task: str = "default", patching_specs: "Optional[list[PatchingSpec]]" = None + ): + super().__init__(config, task, patching_specs) + config.onnx_export = True + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("global_attention_mask", dynamic_axis), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + outputs = super().outputs + if self.task == "default": + outputs["pooler_output"] = {0: "batch"} + return outputs + + @property + def atol_for_validation(self) -> float: + """ + What absolute tolerance value to use during model conversion validation. + + Returns: + Float absolute tolerance value. 
+ """ + return 1e-4 + + @property + def default_onnx_opset(self) -> int: + # needs to be >= 14 to support tril operator + return max(super().default_onnx_opset, 14) + + def generate_dummy_inputs( + self, + tokenizer: "PreTrainedTokenizerBase", + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + inputs = super().generate_dummy_inputs( + preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + import torch + + # for some reason, replacing this code by inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64) + # makes the export fail randomly + inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"]) + # make every second token global + inputs["global_attention_mask"][:, ::2] = 1 + + return inputs + + +__all__ = ["LongformerConfig", "LongformerOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_longformer.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_longformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc7089249674b901662d644f94ae8ee8521d625 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_longformer.py @@ -0,0 +1,2222 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Longformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging +from .configuration_longformer import LongformerConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Longformer's outputs, with potential hidden states, local and global attentions. + """ +) +class LongformerBaseModelOutput(ModelOutput): + r""" + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. 
Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: torch.FloatTensor + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Longformer's outputs that also contains a pooling of the last hidden states. + """ +) +class LongformerBaseModelOutputWithPooling(ModelOutput): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. 
+ If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: torch.FloatTensor + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for masked language models outputs. + """ +) +class LongformerMaskedLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. 
+ + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of question answering Longformer models. + """ +) +class LongformerQuestionAnsweringModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of sentence classification models. 
+ """ +) +class LongformerSequenceClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of multiple choice Longformer models. + """ +) +class LongformerMultipleChoiceModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. 
+ + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, + where `x` is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of token classification models. + """ +) +class LongformerTokenClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. 
+        If the attention window contains a token with global attention, the attention weight at the corresponding
+        index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+        attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
+        accessed from `global_attentions`.
+    global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+        where `x` is the number of tokens with global attention mask.
+
+        Global attentions weights after the attention softmax, used to compute the weighted average in the
+        self-attention heads. Those are the attention weights from every token with global attention to every token
+        in the sequence.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+def _get_question_end_index(input_ids, sep_token_id):
+    """
+    Computes the index of the first occurrence of `sep_token_id` in each sample.
+    """
+
+    sep_token_indices = (input_ids == sep_token_id).nonzero()
+    batch_size = input_ids.shape[0]
+
+    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+    assert sep_token_indices.shape[0] == 3 * batch_size, (
+        f"There should be exactly three separator tokens: {sep_token_id} in every sample for question answering. You"
+        " might also consider setting `global_attention_mask` manually in the forward function to avoid this error."
+    )
+    return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
+
+
+def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
+    """
+    Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is
+    True` else after `sep_token_id`.
+    """
+    question_end_index = _get_question_end_index(input_ids, sep_token_id)
+    question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
+    # bool attention mask with True in locations of global attention
+    attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
+    if before_sep_token is True:
+        attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
+    else:
+        # last token is separation token and should not be counted and in the middle are two separation tokens
+        attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
+            attention_mask.expand_as(input_ids) < input_ids.shape[-1]
+        ).to(torch.bool)
+
+    return attention_mask
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids (`torch.Tensor`): Input token ids.
+        padding_idx (`int`): Id of the padding token.
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
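+    # For example, with padding_idx=1: input_ids = [[0, 7, 9, 1]] gives mask = [[1, 1, 1, 0]],
+    # incremental_indices = [[1, 2, 3, 0]] and position ids [[2, 3, 4, 1]], so padding keeps position padding_idx.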
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +class LongformerEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+
+        Args:
+            inputs_embeds (`torch.Tensor`): Directly provided input embeddings.
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class LongformerSelfAttention(nn.Module):
+    def __init__(self, config, layer_id):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+        self.num_heads = config.num_attention_heads
+        self.head_dim = int(config.hidden_size / config.num_attention_heads)
+        self.embed_dim = config.hidden_size
+
+        self.query = nn.Linear(config.hidden_size, self.embed_dim)
+        self.key = nn.Linear(config.hidden_size, self.embed_dim)
+        self.value = nn.Linear(config.hidden_size, self.embed_dim)
+
+        # separate projection layers for tokens with global attention
+        self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
+        self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
+        self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
+
+        self.dropout = config.attention_probs_dropout_prob
+
+        self.layer_id = layer_id
+        attention_window = config.attention_window[self.layer_id]
+        assert attention_window % 2 == 0, (
+            f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
+        )
+        assert attention_window > 0, (
+            f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
+        )
+
+        self.one_sided_attn_window_size = attention_window // 2
+
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        layer_head_mask=None,
+        is_index_masked=None,
+        is_index_global_attn=None,
+        is_global_attn=None,
+        output_attentions=False,
+    ):
+        """
+        [`LongformerSelfAttention`] expects *len(hidden_states)* to be a multiple of *attention_window*. Padding to
+        *attention_window* happens in [`LongformerModel.forward`] to avoid redoing the padding on each layer.
+
+        The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:
+
+            - -10000: no attention
+            - 0: local attention
+            - +10000: global attention
+        """
+        hidden_states = hidden_states.transpose(0, 1)
+
+        # project hidden states
+        query_vectors = self.query(hidden_states)
+        key_vectors = self.key(hidden_states)
+        value_vectors = self.value(hidden_states)
+
+        seq_len, batch_size, embed_dim = hidden_states.size()
+        assert embed_dim == self.embed_dim, (
+            f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
+        )
+
+        # normalize query
+        query_vectors /= math.sqrt(self.head_dim)
+
+        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
+        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
+
+        attn_scores = self._sliding_chunks_query_key_matmul(
+            query_vectors, key_vectors, self.one_sided_attn_window_size
+        )
+
+        # values to pad for attention probs
+        remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
+
+        # cast to fp32/fp16 then replace 1's with -inf
+        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
+            remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
+        )
+        # diagonal mask with zeros everywhere and -inf in place of padding
+        diagonal_mask = self._sliding_chunks_query_key_matmul(
+            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
+        )
+
+        # pad local attention probs
+        attn_scores += diagonal_mask
+
+        assert list(attn_scores.size()) == [
+            batch_size,
+            seq_len,
+            self.num_heads,
+            self.one_sided_attn_window_size * 2 + 1,
+        ], (
+            f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+            f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+        )
+
+        # compute local attention probs from global attention keys and concat over window dim
+        if is_global_attn:
+            # compute global attn indices required throughout the forward fn
+            (
+                max_num_global_attn_indices,
+                is_index_global_attn_nonzero,
+                is_local_index_global_attn_nonzero,
+                is_local_index_no_global_attn_nonzero,
+            ) = self._get_global_attn_indices(is_index_global_attn)
+            # calculate global attn probs from global key
+
+            global_key_attn_scores = self._concat_with_global_key_attn_probs(
+                query_vectors=query_vectors,
+                key_vectors=key_vectors,
+                max_num_global_attn_indices=max_num_global_attn_indices,
+                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+            )
+            # concat to local_attn_probs
+            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
+            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
+
+            # free memory
+            del global_key_attn_scores
+
+        attn_probs = nn.functional.softmax(
+            attn_scores, dim=-1, dtype=torch.float32
+        )  # use fp32 for numerical stability
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            )
+            attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
+
+        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
+        attn_probs = attn_probs.type_as(attn_scores)
+
+        # free
memory + del attn_scores + + # apply dropout + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + ) + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] + ] + + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 + ) + # The attention weights for tokens with global attention are + # just filler values, they were never used to compute the output. + # Fill with 0 now, the correct values are in 'global_attn_probs'. + attn_probs[is_index_global_attn_nonzero] = 0 + + outputs = (attn_output.transpose(0, 1),) + + if output_attentions: + outputs += (attn_probs,) + + return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """pads rows and then flips rows and columns""" + hidden_states_padded = nn.functional.pad( + hidden_states_padded, padding + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """ + shift every row 1 step right, converting columns into diagonals. 
+
+ Example:
+
+ ```python
+ chunked_hidden_states: [
+ 0.4983, 2.6918, -0.0071, 1.0492,
+ -1.8348, 0.7672, 0.2986, 0.0285,
+ -0.7584, 0.4206, -0.0405, 0.1599,
+ 2.0514, -1.1600, 0.5372, 0.2629,
+ ]
+ window_overlap = num_rows = 4
+ ```
+
+ (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
+ 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
+ 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
+ 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
+ """
+ total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
+ chunked_hidden_states = nn.functional.pad(
+ chunked_hidden_states, (0, window_overlap + 1)
+ ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
+ chunked_hidden_states = chunked_hidden_states.view(
+ total_num_heads, num_chunks, -1
+ ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
+ chunked_hidden_states = chunked_hidden_states[
+ :, :, :-window_overlap
+ ] # total_num_heads x num_chunks x window_overlap*window_overlap
+ chunked_hidden_states = chunked_hidden_states.view(
+ total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
+ )
+ chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
+ return chunked_hidden_states
+
+ @staticmethod
+ def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
+ """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
+ if not onnx_export:
+ # non-overlapping chunks of size = 2w
+ hidden_states = hidden_states.view(
+ hidden_states.size(0),
+ torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
+ window_overlap * 2,
+ hidden_states.size(2),
+ )
+ # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
+ chunk_size = list(hidden_states.size())
+ chunk_size[1] = chunk_size[1] * 2 - 1
+
+ chunk_stride = list(hidden_states.stride())
+ chunk_stride[1] = chunk_stride[1] // 2
+ return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+
+ # When exporting to ONNX, use this separate logic
+ # we have to use the slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+
+ # TODO replace this with
+ # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+ # once `unfold` is supported
+ # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+
+ chunk_size = [
+ hidden_states.size(0),
+ torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+ window_overlap * 2,
+ hidden_states.size(2),
+ ]
+
+ overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
+ for chunk in range(chunk_size[1]):
+ overlapping_chunks[:, chunk, :, :] = hidden_states[
+ :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+ ]
+ return overlapping_chunks
+
+ @staticmethod
+ def _mask_invalid_locations(input_tensor, affected_seq_len):
+ beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
+ beginning_mask = beginning_mask_2d[None, :, None, :]
+ ending_mask = beginning_mask.flip(dims=(1, 3))
+ beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
+ beginning_mask = beginning_mask.expand(beginning_input.size())
+ input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
+ beginning_input, -float("inf")
+ ).where(beginning_mask.bool(), beginning_input)
+ ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
+ ending_mask = ending_mask.expand(ending_input.size())
+ input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
+ ending_input, -float("inf")
+ ).where(ending_mask.bool(), ending_input)
+
+ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
+ """
+ Matrix multiplication of query and key tensors using a sliding window attention pattern. This
+ implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
+ overlap of size window_overlap
+ """
+ batch_size, seq_len, num_heads, head_dim = query.size()
+ assert seq_len % (window_overlap * 2) == 0, (
+ f"Sequence length should be a multiple of {window_overlap * 2}. Given {seq_len}"
+ )
+ assert query.size() == key.size()
+
+ chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
+
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
+ query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+ key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+
+ query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False))
+ key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False))
+
+ # matrix multiplication
+ # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+ diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
+
+ # convert diagonals into columns
+ diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
+ diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
+ )
+
+ # allocate space for the overall attention matrix where the chunks are combined. The last dimension
+ # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
+ # window_overlap previous words). The following column is the attention score from each word to itself,
+ # followed by window_overlap columns for the upper triangle.
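+ # For intuition: with window_overlap = 2 each row has 2 * 2 + 1 = 5 columns, read as
+ # [score(i, i-2), score(i, i-1), score(i, i), score(i, i+1), score(i, i+2)].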
+
+ diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
+ (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
+ )
+
+ # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
+ # - copying the main diagonal and the upper triangle
+ diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
+ :, :, :window_overlap, : window_overlap + 1
+ ]
+ diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
+ :, -1, window_overlap:, : window_overlap + 1
+ ]
+ # - copying the lower triangle
+ diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
+ :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
+ ]
+
+ diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
+ :, 0, : window_overlap - 1, 1 - window_overlap :
+ ]
+
+ # separate batch_size and num_heads dimensions again
+ diagonal_attention_scores = diagonal_attention_scores.view(
+ batch_size, num_heads, seq_len, 2 * window_overlap + 1
+ ).transpose(2, 1)
+
+ self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
+ return diagonal_attention_scores
+
+ def _sliding_chunks_matmul_attn_probs_value(
+ self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
+ ):
+ """
+ Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
+ same shape as `value`
+ """
+ batch_size, seq_len, num_heads, head_dim = value.size()
+
+ assert seq_len % (window_overlap * 2) == 0
+ assert attn_probs.size()[:3] == value.size()[:3]
+ assert attn_probs.size(3) == 2 * window_overlap + 1
+ chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 * window_overlap
+
+ chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
+ batch_size * num_heads,
+ torch.div(seq_len, window_overlap, rounding_mode="trunc"),
+ window_overlap,
+ 2 * window_overlap + 1,
+ )
+
+ # group batch_size and num_heads dimensions into one
+ value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+
+ # pad seq_len with window_overlap at the beginning and at the end of the sequence
+ padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
+
+ # chunk padded_value into chunks of size 3 * window_overlap and an overlap of size window_overlap
+ chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
+ chunked_value_stride = padded_value.stride()
+ chunked_value_stride = (
+ chunked_value_stride[0],
+ window_overlap * chunked_value_stride[1],
+ chunked_value_stride[1],
+ chunked_value_stride[2],
+ )
+ chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
+
+ chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
+
+ context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
+ return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
+
+ @staticmethod
+ def _get_global_attn_indices(is_index_global_attn):
+ """compute global attn indices required throughout forward pass"""
+ # helper variable
+ num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
+
+ # max number of global attn indices in batch
+ max_num_global_attn_indices = num_global_attn_indices.max()
+
+ # indices of global attn
+ is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
+
+ # helper variable
+ is_local_index_global_attn = torch.arange(
+ max_num_global_attn_indices, device=is_index_global_attn.device
+ ) < num_global_attn_indices.unsqueeze(dim=-1)
+
+ # location of the non-padding values within global attention indices
+ is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
+
+ # location of the padding values within global attention indices
+ is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
+ return (
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ )
+
+ def _concat_with_global_key_attn_probs(
+ self,
+ key_vectors,
+ query_vectors,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ ):
+ batch_size = key_vectors.shape[0]
+
+ # create only global key vectors
+ key_vectors_only_global = key_vectors.new_zeros(
+ batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
+ )
+
+ key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
+
+ # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
+ attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
+
+ # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+ attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
+ attn_probs_from_global_key[
+ is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
+ attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
+
+ return attn_probs_from_global_key
+
+ def _compute_attn_output_with_global_indices(
+ self,
+ value_vectors,
+ attn_probs,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ ):
+ batch_size = attn_probs.shape[0]
+
+ # cut local attn probs to global only
+ attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
+ # get value vectors for global only
+ value_vectors_only_global = value_vectors.new_zeros(
+ batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
+ )
+ value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]
+
+ # use `matmul` because `einsum` crashes sometimes with fp16
+ # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
+ # compute attn output only global
+ attn_output_only_global = torch.matmul(
+ attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
+ ).transpose(1, 2)
+
+ # narrow local attn probs to exclude the global key columns
+ attn_probs_without_global = attn_probs.narrow(
+ -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
+ ).contiguous()
+
+ # compute attn output without global attention (local attention only)
+ attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
+ attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
+ )
+ return attn_output_only_global + attn_output_without_global
+
+ def _compute_global_attn_output_from_hidden(
+ self,
+ hidden_states,
+ max_num_global_attn_indices,
+ layer_head_mask,
+ is_local_index_global_attn_nonzero,
+ is_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ is_index_masked,
+ ):
+ seq_len, batch_size = hidden_states.shape[:2]
+
+ # prepare global hidden states
+ global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
+ global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
+ is_index_global_attn_nonzero[::-1]
+ ]
+
+ # global key, query, value
+ global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
+ global_key_vectors = self.key_global(hidden_states)
+ global_value_vectors = self.value_global(hidden_states)
+
+ # normalize
+ global_query_vectors_only_global /= math.sqrt(self.head_dim)
+
+ # reshape
+ global_query_vectors_only_global = (
+ global_query_vectors_only_global.contiguous()
+ .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
+ global_key_vectors = (
+ global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ ) # (batch_size * self.num_heads, seq_len, head_dim)
+ global_value_vectors = (
+ global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ ) # (batch_size * self.num_heads, seq_len, head_dim)
+
+ # compute attn scores
+ global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))
+
+ assert list(global_attn_scores.size()) == [
+ batch_size * self.num_heads,
+ max_num_global_attn_indices,
+ seq_len,
+ ], (
+ "global_attn_scores has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ ) + + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + global_attn_scores = global_attn_scores.transpose(1, 2) + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : + ] = torch.finfo(global_attn_scores.dtype).min + global_attn_scores = global_attn_scores.transpose(1, 2) + + global_attn_scores = global_attn_scores.masked_fill( + is_index_masked[:, None, None, :], + torch.finfo(global_attn_scores.dtype).min, + ) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = nn.functional.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + # apply layer head masking + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_heads,), ( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view( + batch_size, self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs_float = global_attn_probs_float.view( + batch_size * self.num_heads, max_num_global_attn_indices, seq_len + ) + + global_attn_probs = nn.functional.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], ( + "global_attn_output tensor has the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" + f" {global_attn_output.size()}." 
+ ) + + global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_output = global_attn_output.view( + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim + ) + return global_attn_output, global_attn_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class LongformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LongformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.self = LongformerSelfAttention(config, layer_id) + self.output = LongformerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + layer_head_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) + attn_output = self.output(self_outputs[0], hidden_states) + outputs = (attn_output,) + self_outputs[1:] + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class LongformerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class LongformerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + 
hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LongformerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.attention = LongformerAttention(config, layer_id) + self.intermediate = LongformerIntermediate(config) + self.output = LongformerOutput(config) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + def forward( + self, + hidden_states, + attention_mask=None, + layer_head_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): + self_attn_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) + attn_output = self_attn_outputs[0] + outputs = self_attn_outputs[1:] + + layer_output = apply_chunking_to_forward( + self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output + ) + outputs = (layer_output,) + outputs + return outputs + + def ff_chunk(self, attn_output): + intermediate_output = self.intermediate(attn_output) + layer_output = self.output(intermediate_output, attn_output) + return layer_output + + +class LongformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + padding_len=0, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + + # Record `is_global_attn == True` to enable ONNX export + is_global_attn = is_index_global_attn.flatten().any().item() + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None # All local attentions. + all_global_attentions = () if (output_attentions and is_global_attn) else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layer)), ( + f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}." 
+ ) + for idx, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) + all_attentions = all_attentions + (layer_outputs[1].transpose(1, 2),) + + if is_global_attn: + # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn + all_global_attentions = all_global_attentions + (layer_outputs[2].transpose(2, 3),) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # undo padding if necessary + # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) + hidden_states = hidden_states[:, : hidden_states.shape[1] - padding_len] + if output_hidden_states: + all_hidden_states = tuple(state[:, : state.shape[1] - padding_len] for state in all_hidden_states) + + if output_attentions: + all_attentions = tuple(state[:, :, : state.shape[2] - padding_len, :] for state in all_attentions) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None + ) + return LongformerBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + global_attentions=all_global_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class LongformerPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Longformer
+class LongformerLMHead(nn.Module):
+ """Longformer Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.decoder.bias = self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+ def _tie_weights(self):
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ self.bias = self.decoder.bias
+
+
+@auto_docstring
+class LongformerPreTrainedModel(PreTrainedModel):
+ config: LongformerConfig
+ base_model_prefix = "longformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LongformerSelfAttention"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class LongformerModel(LongformerPreTrainedModel):
+ """
+ This class copies code from [`RobertaModel`] and overrides standard self-attention with Longformer self-attention
+ to provide the ability to process long sequences following the self-attention approach described in [Longformer:
+ the Long-Document Transformer](https://huggingface.co/papers/2004.05150) by Iz Beltagy, Matthew E. Peters, and Arman Cohan.
+ Longformer self-attention combines local (sliding window) and global attention to extend to long documents
+ without the O(n^2) increase in memory and compute.
+
+ The self-attention module `LongformerSelfAttention` implemented here supports the combination of local and global
+ attention, but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
+ attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. A future
+ release will add support for autoregressive attention, but the support for dilated attention requires a custom CUDA
+ kernel to be memory and compute efficient.
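+
+ With an attention window of size `w`, the sliding-window component scales linearly in the sequence length `n`
+ (roughly O(n x w) rather than O(n^2)), which is what makes the long-document setting tractable.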
+ + """ + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + if isinstance(config.attention_window, int): + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + else: + assert len(config.attention_window) == config.num_hidden_layers, ( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. " + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + self.embeddings = LongformerEmbeddings(config) + self.encoder = LongformerEncoder(config) + self.pooler = LongformerPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def _pad_to_window_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + pad_token_id: int, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer self-attention.""" + # padding + attention_window = ( + self.config.attention_window + if isinstance(self.config.attention_window, int) + else max(self.config.attention_window) + ) + + assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" + input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (attention_window - seq_len % attention_window) % attention_window + + # this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well + if padding_len > 0: + logger.warning_once( + f"Input ids are automatically padded to be a multiple of `config.attention_window`: {attention_window}" + ) + if input_ids is not None: + input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings + position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = inputs_embeds.new_full( + (batch_size, padding_len), + self.config.pad_token_id, + dtype=torch.long, + ) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + + attention_mask = nn.functional.pad( + attention_mask, (0, padding_len), value=0 + ) # no attention on the padding tokens + token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + global_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LongformerBaseModelOutputWithPooling]: + r""" + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. Tokens with global + attention attends to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the token should be given global attention. For QA, all question tokens should also + have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more + details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). 
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import LongformerModel, AutoTokenizer
+
+ >>> model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+
+ >>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000)  # long input document
+ >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1
+
+ >>> attention_mask = torch.ones(
+ ... input_ids.shape, dtype=torch.long, device=input_ids.device
+ ... )  # initialize to local attention
+ >>> global_attention_mask = torch.zeros(
+ ... input_ids.shape, dtype=torch.long, device=input_ids.device
+ ... )  # initialize with global attention deactivated for all tokens
+ >>> global_attention_mask[
+ ... :,
+ ... [
+ ... 1,
+ ... 4,
+ ... 21,
+ ... ],
+ ... ] = 1  # Set global attention to random tokens for the sake of this example
+ >>> # Usually, set global attention based on the task. For example,
+ >>> # classification: the `<s>` token
+ >>> # QA: question tokens
+ >>> # LM: potentially on the beginning of sentences and paragraphs
+ >>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
+ >>> sequence_output = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output
+ ```
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # merge `global_attention_mask` and `attention_mask`
+ if global_attention_mask is not None:
+ attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
+
+ padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ pad_token_id=self.config.pad_token_id,
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves, in which case we just need to make it broadcastable to all heads.
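+ # Note: `get_extended_attention_mask` turns the merged mask (0 = no attention, 1 = local, 2 = global) into
+ # additive scores of shape (batch_size, 1, 1, seq_len); the `[:, 0, 0, :]` slice below recovers a
+ # (batch_size, seq_len) tensor in which padding is a large negative value, local attention is 0 and global
+ # attention is a large positive value, matching the convention documented in `LongformerSelfAttention`.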
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
+ :, 0, 0, :
+ ]
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ padding_len=padding_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return LongformerBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ global_attentions=encoder_outputs.global_attentions,
+ )
+
+
+@auto_docstring
+class LongformerForMaskedLM(LongformerPreTrainedModel):
+ _tied_weights_keys = ["lm_head.decoder"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.longformer = LongformerModel(config, add_pooling_layer=False)
+ self.lm_head = LongformerLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerMaskedLMOutput]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the `<s>` token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Example (mask filling):
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongformerForMaskedLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
+ ```
+
+ Let's try a very long input.
+
+ ```python
+ >>> TXT = (
+ ... "My friends are <mask> but they eat too many carbs."
+ ... + " That's why I decide not to eat with them." * 300
+ ... )
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(prediction_scores.device)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return LongformerMaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """ +) +class LongformerForSequenceClassification(LongformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.longformer = LongformerModel(config, add_pooling_layer=False) + self.classifier = LongformerClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + global_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LongformerSequenceClassifierOutput]: + r""" + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. Tokens with global + attention attends to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the token should be given global attention. For QA, all question tokens should also + have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more + details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if global_attention_mask is None: + logger.warning_once("Initializing global attention on CLS token...") + global_attention_mask = torch.zeros_like(input_ids) + # global attention on cls token + global_attention_mask[:, 0] = 1 + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return LongformerSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + global_attentions=outputs.global_attentions, + ) + + +class LongformerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, hidden_states, **kwargs): + hidden_states = hidden_states[:, 0, :] # take token (equiv. 
to [CLS]) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + output = self.out_proj(hidden_states) + return output + + +@auto_docstring +class LongformerForQuestionAnswering(LongformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + global_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LongformerQuestionAnsweringModelOutput]: + r""" + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. Tokens with global + attention attends to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the token should be given global attention. For QA, all question tokens should also + have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more + details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LongformerForQuestionAnswering + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") + >>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> encoding = tokenizer(question, text, return_tensors="pt") + >>> input_ids = encoding["input_ids"] + + >>> # default is local attention everywhere + >>> # the forward method will automatically set global attention on question tokens + >>> attention_mask = encoding["attention_mask"] + + >>> outputs = model(input_ids, attention_mask=attention_mask) + >>> start_logits = outputs.start_logits + >>> end_logits = outputs.end_logits + >>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) + + >>> answer_tokens = all_tokens[torch.argmax(start_logits) : torch.argmax(end_logits) + 1] + >>> answer = tokenizer.decode( + ... tokenizer.convert_tokens_to_ids(answer_tokens) + ... 
) # remove space prepending space token + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if global_attention_mask is None: + if input_ids is None: + logger.warning( + "It is not possible to automatically generate the `global_attention_mask` because input_ids is" + " None. Please make sure that it is correctly set." + ) + else: + # set global attention on question tokens automatically + global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return LongformerQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + global_attentions=outputs.global_attentions, + ) + + +@auto_docstring +class LongformerForTokenClassification(LongformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + global_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LongformerTokenClassifierOutput]: + r""" + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. 
+ Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is
+ important for task-specific finetuning because it makes the model more flexible at representing the task.
+ For example, for classification, the `<s>` token should be given global attention. For QA, all question
+ tokens should also have global attention. Please refer to the
+ [Longformer paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return LongformerTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+@auto_docstring
+class LongformerForMultipleChoice(LongformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.longformer = LongformerModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerMultipleChoiceModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ + [What are token type IDs?](../glossary#token-type-ids) + global_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Mask to decide the attention given on each token, local attention or global attention. Tokens with global + attention attends to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the token should be given global attention. For QA, all question tokens should also + have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more + details. Mask values selected in `[0, 1]`: + + - 0 for local attention (a sliding window attention), + - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
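+
+        Example (a minimal usage sketch; the checkpoint name is illustrative rather than prescribed here):
+
+        ```python
+        >>> from transformers import AutoTokenizer, LongformerForMultipleChoice
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+        >>> model = LongformerForMultipleChoice.from_pretrained("allenai/longformer-base-4096")
+
+        >>> prompt = "Longformer attention over long documents is"
+        >>> choices = ["sparse and windowed.", "fully quadratic."]
+        >>> encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
+        >>> inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # add the batch dimension
+        >>> outputs = model(**inputs)  # global attention is initialized automatically when not passed
+        >>> predicted_choice = int(outputs.logits.argmax(-1))
+        ```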
+ """ + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # set global attention on question tokens + if global_attention_mask is None and input_ids is not None: + logger.warning_once("Initializing global attention on multiple choice...") + # put global attention on all tokens after `config.sep_token_id` + global_attention_mask = torch.stack( + [ + _compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False) + for i in range(num_choices) + ], + dim=1, + ) + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_global_attention_mask = ( + global_attention_mask.view(-1, global_attention_mask.size(-1)) + if global_attention_mask is not None + else None + ) + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.longformer( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + global_attention_mask=flat_global_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(reshaped_logits.device) + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return LongformerMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + global_attentions=outputs.global_attentions, + ) + + +__all__ = [ + "LongformerForMaskedLM", + "LongformerForMultipleChoice", + "LongformerForQuestionAnswering", + "LongformerForSequenceClassification", + "LongformerForTokenClassification", + "LongformerModel", + "LongformerPreTrainedModel", + "LongformerSelfAttention", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_tf_longformer.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_tf_longformer.py new file mode 100644 index 0000000000000000000000000000000000000000..891f5d76c95ce1f572ea33baf6db51cf21278a80 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/modeling_tf_longformer.py @@ -0,0 +1,2783 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tensorflow Longformer model.""" + +from __future__ import annotations + +import warnings +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_longformer import LongformerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096" +_CONFIG_FOR_DOC = "LongformerConfig" + +LARGE_NEGATIVE = -1e8 + + +@dataclass +class TFLongformerBaseModelOutput(ModelOutput): + """ + Base class for Longformer's outputs, with potential hidden states, local and global attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. 
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLongformerBaseModelOutputWithPooling(ModelOutput): + """ + Base class for Longformer's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. 
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLongformerMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. 
Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLongformerQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering Longformer models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + start_logits: tf.Tensor | None = None + end_logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] 
| None = None + + +@dataclass +class TFLongformerSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLongformerMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). 
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFLongformerTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + + attention_window + 1)`, where `x` is the number of tokens with global attention mask. + + Local attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token in the sequence to every token with + global attention (first `x` values) and to every token in the attention window (remaining `attention_window + + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the + remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a + token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding + (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. + If the attention window contains a token with global attention, the attention weight at the corresponding + index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global + attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be + accessed from `global_attentions`. + global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` + is the number of tokens with global attention mask. + + Global attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. Those are the attention weights from every token with global attention to every token + in the sequence. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + global_attentions: tuple[tf.Tensor, ...] | None = None + + +def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True): + """ + Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is + True` else after `sep_token_id`. 
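+
+    For example (illustrative values): if `input_ids_shape == (1, 8)` and the three separator tokens sit at
+    positions 2, 5 and 7, `question_end_index` is 2, so with `before_sep_token=True` the returned mask is
+    `[[1, 1, 0, 0, 0, 0, 0, 0]]` and only the question tokens before the first separator get global attention.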
+ """ + assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions" + question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None] + # bool attention mask with True in locations of global attention + attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0) + attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1)) + if before_sep_token is True: + question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1])) + attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype) + else: + # last token is separation token and should not be counted and in the middle are two separation tokens + question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1])) + attention_mask = tf.cast( + attention_mask > question_end_index, + dtype=question_end_index.dtype, + ) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype) + + return attention_mask + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer +class TFLongformerLMHead(keras.layers.Layer): + """Longformer Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +class TFLongformerEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting. 
+ """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+ position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64), + axis=0, + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer +class TFLongformerIntermediate(keras.layers.Layer): + def __init__(self, config: LongformerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer +class TFLongformerOutput(keras.layers.Layer): + def __init__(self, config: LongformerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer +class TFLongformerPooler(keras.layers.Layer): + def __init__(self, config: LongformerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to 
the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(inputs=first_token_tensor)
+
+        return pooled_output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
+class TFLongformerSelfOutput(keras.layers.Layer):
+    def __init__(self, config: LongformerConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "LayerNorm", None) is not None:
+            with tf.name_scope(self.LayerNorm.name):
+                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFLongformerSelfAttention(keras.layers.Layer):
+    def __init__(self, config, layer_id, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_heads = config.num_attention_heads
+        self.head_dim = int(config.hidden_size / config.num_attention_heads)
+        self.embed_dim = config.hidden_size
+        self.query = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query",
+        )
+        self.key = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
+        )
+        self.value = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value",
+        )
+
+        # separate projection layers for tokens with global attention
+        self.query_global = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query_global",
+        )
+        self.key_global = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key_global",
+        )
+        self.value_global = keras.layers.Dense(
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value_global",
+        )
+        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+        self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+        self.layer_id = layer_id
+        attention_window = config.attention_window[self.layer_id]
+
+        assert attention_window % 2 == 0, (
+            f"`attention_window` for layer {self.layer_id} has to be an even value.
Given {attention_window}" + ) + assert attention_window > 0, ( + f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" + ) + + self.one_sided_attn_window_size = attention_window // 2 + + def build(self, input_shape=None): + if not self.built: + with tf.name_scope("query_global"): + self.query_global.build((self.config.hidden_size,)) + with tf.name_scope("key_global"): + self.key_global.build((self.config.hidden_size,)) + with tf.name_scope("value_global"): + self.value_global.build((self.config.hidden_size,)) + + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "query_global", None) is not None: + with tf.name_scope(self.query_global.name): + self.query_global.build([None, None, self.config.hidden_size]) + if getattr(self, "key_global", None) is not None: + with tf.name_scope(self.key_global.name): + self.key_global.build([None, None, self.config.hidden_size]) + if getattr(self, "value_global", None) is not None: + with tf.name_scope(self.value_global.name): + self.value_global.build([None, None, self.config.hidden_size]) + + def call( + self, + inputs, + training=False, + ): + """ + LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to + *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer. 
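+
+        For instance, with `attention_window = 512` a 1000-token batch is padded to 1024 tokens (the next
+        multiple of the window) before the hidden states reach this layer.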
+ + The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to: + + - -10000: no attention + - 0: local attention + - +10000: global attention + """ + # retrieve input args + ( + hidden_states, + attention_mask, + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) = inputs + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + batch_size, seq_len, embed_dim = shape_list(hidden_states) + + tf.debugging.assert_equal( + embed_dim, + self.embed_dim, + message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", + ) + + # normalize query + query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) + query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = attention_mask != 0 + # cast to fp32/fp16 then replace 1's with -inf + float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE + + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + tf.ones(shape_list(attention_mask)), + float_mask, + self.one_sided_attn_window_size, + ) + + # pad local attention probs + attn_scores += diagonal_mask + + tf.debugging.assert_equal( + shape_list(attn_scores), + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + message=( + f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}," + f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}" + ), + ) + + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + + # this function is only relevant for global attention + if is_global_attn: + attn_scores = self._concat_with_global_key_attn_probs( + attn_scores=attn_scores, + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + + attn_probs = stable_softmax(attn_scores, axis=-1) + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + if is_global_attn: + masked_index = tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ) + else: + masked_index = tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, 
self.one_sided_attn_window_size * 2 + 1), + ) + attn_probs = tf.where( + masked_index, + tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), + attn_probs, + ) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs + + # apply dropout + attn_probs = self.dropout(attn_probs, training=training) + value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + + # if global attention, compute sum of global and local attn + + if is_global_attn: + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + tf.debugging.assert_equal( + shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" + ) + + attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) + + # compute value for global attention and overwrite to attention output + if is_global_attn: + attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( + attn_output=attn_output, + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + training=training, + ) + else: + # Leave attn_output unchanged + global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) + + # make sure that local attention probabilities are set to 0 for indices of global attn + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + if is_global_attn: + masked_global_attn_index = tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ) + else: + masked_global_attn_index = tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ) + attn_probs = tf.where( + masked_global_attn_index, + tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), + attn_probs, + ) + + outputs = (attn_output, attn_probs, global_attn_probs) + + return outputs + + def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): + """ + Matrix multiplication of query and key tensors using with a sliding window attention pattern. This + implementation splits the input into overlapping chunks of size 2w (e.g. 
512 for pretrained Longformer) with an + overlap of size window_overlap + """ + batch_size, seq_len, num_heads, head_dim = shape_list(query) + + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", + ) + tf.debugging.assert_equal( + shape_list(query), + shape_list(key), + message=( + f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:" + f" {shape_list(key)}" + ), + ) + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = tf.reshape( + tf.transpose(query, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) + key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) + + # matrix multiplication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) + chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply + + # convert diagonals into columns + paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. 
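+
+        # Worked example of the layout (illustrative numbers): with seq_len = 8 and window_overlap = 2 the
+        # combined matrix below holds 2 * window_overlap + 1 = 5 scores per position: two left neighbours,
+        # the token itself, and two right neighbours.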
+ + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + # TODO: This code is most likely not very efficient and should be improved + diagonal_attn_scores_up_triang = tf.concat( + [ + diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], + diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], + ], + axis=1, + ) + + # - copying the lower triangle + diagonal_attn_scores_low_triang = tf.concat( + [ + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), + diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], + ], + axis=1, + ) + diagonal_attn_scores_first_chunk = tf.concat( + [ + tf.roll( + diagonal_chunked_attention_scores, + shift=[1, window_overlap], + axis=[2, 3], + )[:, :, :window_overlap, :window_overlap], + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), + ], + axis=1, + ) + first_chunk_mask = ( + tf.tile( + tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None], + (batch_size * num_heads, 1, window_overlap, window_overlap), + ) + < 1 + ) + diagonal_attn_scores_low_triang = tf.where( + first_chunk_mask, + diagonal_attn_scores_first_chunk, + diagonal_attn_scores_low_triang, + ) + + # merging upper and lower triangle + diagonal_attention_scores = tf.concat( + [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1 + ) + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = tf.transpose( + tf.reshape( + diagonal_attention_scores, + (batch_size, num_heads, seq_len, 2 * window_overlap + 1), + ), + (0, 2, 1, 3), + ) + + diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + + return diagonal_attention_scores + + @staticmethod + def _mask_invalid_locations(input_tensor, window_overlap): + # create correct upper triangle bool mask + mask_2d_upper = tf.reverse( + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), + axis=[0], + ) + + # pad to full matrix + padding = tf.convert_to_tensor( + [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] + ) + + # create lower mask + mask_2d = tf.pad(mask_2d_upper, padding) + + # combine with upper mask + mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) + + # broadcast to full matrix + mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) + + # inf tensor used for masking + inf_tensor = -float("inf") * tf.ones_like(input_tensor) + + # mask + input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) + + return input_tensor + + def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): + """ + Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
Returned tensor will be of the + same shape as `attn_probs` + """ + + batch_size, seq_len, num_heads, head_dim = shape_list(value) + + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[:3], + shape_list(value)[:3], + message="value and attn_probs must have same dims (except head_dim)", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[3], + 2 * window_overlap + 1, + message="attn_probs last dim has to be 2 * window_overlap + 1", + ) + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attn_probs = tf.reshape( + tf.transpose(attn_probs, (0, 2, 1, 3)), + ( + batch_size * num_heads, + seq_len // window_overlap, + window_overlap, + 2 * window_overlap + 1, + ), + ) + + # group batch_size and num_heads dimensions into one + value = tf.reshape( + tf.transpose(value, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) + padded_value = tf.pad(value, paddings, constant_values=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + frame_size = 3 * window_overlap * head_dim + frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count + chunked_value = tf.signal.frame( + tf.reshape(padded_value, (batch_size * num_heads, -1)), + frame_size, + frame_hop_size, + ) + chunked_value = tf.reshape( + chunked_value, + (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), + ) + + tf.debugging.assert_equal( + shape_list(chunked_value), + [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], + message="Chunked value has the wrong shape", + ) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) + context = tf.transpose( + tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), + (0, 2, 1, 3), + ) + + return context + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): + """pads rows and then flips rows and columns""" + hidden_states_padded = tf.pad( + hidden_states_padded, paddings + ) # padding value is not important because it will be overwritten + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """ + shift every row 1 step right, converting columns into diagonals. 
+
+        Example:
+
+        ```python
+        chunked_hidden_states: [
+            0.4983,
+            2.6918,
+            -0.0071,
+            1.0492,
+            -1.8348,
+            0.7672,
+            0.2986,
+            0.0285,
+            -0.7584,
+            0.4206,
+            -0.0405,
+            0.1599,
+            2.0514,
+            -1.1600,
+            0.5372,
+            0.2629,
+        ]
+        window_overlap = num_rows = 4
+        ```
+
+        (pad & diagonalize) => [ 0.4983,  2.6918, -0.0071,  1.0492,  0.0000,  0.0000,  0.0000
+                                 0.0000, -1.8348,  0.7672,  0.2986,  0.0285,  0.0000,  0.0000
+                                 0.0000,  0.0000, -0.7584,  0.4206, -0.0405,  0.1599,  0.0000
+                                 0.0000,  0.0000,  0.0000,  2.0514, -1.1600,  0.5372,  0.2629 ]
+        """
+        total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
+        chunked_hidden_states = tf.pad(
+            chunked_hidden_states, paddings
+        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1). Padding value is not important because it'll be overwritten
+        chunked_hidden_states = tf.reshape(
+            chunked_hidden_states, (total_num_heads, num_chunks, -1)
+        )  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap + 1)
+        chunked_hidden_states = chunked_hidden_states[
+            :, :, :-window_overlap
+        ]  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap)
+        chunked_hidden_states = tf.reshape(
+            chunked_hidden_states,
+            (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
+        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
+        chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
+
+        return chunked_hidden_states
+
+    @staticmethod
+    def _chunk(hidden_states, window_overlap):
+        """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
+        batch_size, seq_length, hidden_dim = shape_list(hidden_states)
+        num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
+
+        # define frame size and frame stride (similar to convolution)
+        frame_hop_size = window_overlap * hidden_dim
+        frame_size = 2 * frame_hop_size
+        hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
+
+        # chunk with overlap
+        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
+
+        tf.debugging.assert_equal(
+            shape_list(chunked_hidden_states),
+            [batch_size, num_output_chunks, frame_size],
+            message=(
+                "Make sure chunking is correctly applied. `Chunked hidden states` should have output dimension"
+                f" {[batch_size, num_output_chunks, frame_size]}, but got {shape_list(chunked_hidden_states)}."
+ ), + ) + + chunked_hidden_states = tf.reshape( + chunked_hidden_states, + (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), + ) + + return chunked_hidden_states + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """compute global attn indices required throughout forward pass""" + # helper variable + num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) + num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) + + # max number of global attn indices in batch + max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) + + # indices of global attn + is_index_global_attn_nonzero = tf.where(is_index_global_attn) + + # helper variable + is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( + num_global_attn_indices, axis=-1 + ) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) + + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + attn_scores, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = shape_list(key_vectors)[0] + + # select global key vectors + global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) + + # create only global key vectors + key_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_key_vectors, + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) + + # (batch_size, max_num_global_attn_indices, seq_len, num_heads) + attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(attn_probs_from_global_key_trans)[-2:] + ) + mask = tf.ones(mask_shape) * -10000.0 + mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) + + # scatter mask + attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( + attn_probs_from_global_key_trans, + is_local_index_no_global_attn_nonzero, + mask, + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) + + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1) + + return attn_scores + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = shape_list(attn_probs)[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices] + + # select global value vectors + global_value_vectors = tf.gather_nd(value_vectors, 
is_index_global_attn_nonzero) + + # create only global value vectors + value_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_value_vectors, + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), + ) + + # compute attn output only global + attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) + + # reshape attn probs + attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + attn_output, + hidden_states, + max_num_global_attn_indices, + layer_head_mask, + is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + training, + ): + batch_size, seq_len = shape_list(hidden_states)[:2] + + # prepare global hidden states + global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) + global_attn_hidden_states = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_attn_hidden_states, + shape=(batch_size, max_num_global_attn_indices, self.embed_dim), + ) + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= tf.math.sqrt( + tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) + ) + global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) + global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) + global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) + + # compute attn scores + global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(global_attn_scores), + [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], + message=( + "global_attn_scores have the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is" + f" {shape_list(global_attn_scores)}." 
+ ), + ) + + global_attn_scores = tf.reshape( + global_attn_scores, + (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), + ) + global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(global_attn_scores_trans)[-2:] + ) + global_attn_mask = tf.ones(mask_shape) * -10000.0 + global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) + + # scatter mask + global_attn_scores_trans = tf.tensor_scatter_nd_update( + global_attn_scores_trans, + is_local_index_no_global_attn_nonzero, + global_attn_mask, + ) + global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) + + # mask global attn scores + attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) + global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) + global_attn_scores = tf.reshape( + global_attn_scores, + (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), + ) + + # compute global attn probs + global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1) + + # apply layer head masking + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + global_attn_probs_float = tf.reshape( + global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + ) + + # dropout + global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) + + # global attn output + global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) + + tf.debugging.assert_equal( + shape_list(global_attn_output), + [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], + message=( + "global_attn_output tensor has the wrong size. Size should be" + f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" + f" {shape_list(global_attn_output)}." 
+ ), + ) + + global_attn_output = tf.reshape( + global_attn_output, + (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), + ) + + # get only non zero global attn output + nonzero_global_attn_output = tf.gather_nd( + tf.transpose(global_attn_output, (0, 2, 1, 3)), + is_local_index_global_attn_nonzero, + ) + nonzero_global_attn_output = tf.reshape( + nonzero_global_attn_output, + (shape_list(is_local_index_global_attn_nonzero)[0], -1), + ) + + # overwrite values with global attention + attn_output = tf.tensor_scatter_nd_update( + attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output + ) + + global_attn_probs = tf.reshape( + global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + + return attn_output, global_attn_probs + + def reshape_and_transpose(self, vector, batch_size): + return tf.reshape( + tf.transpose( + tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), + (0, 2, 1, 3), + ), + (batch_size * self.num_heads, -1, self.head_dim), + ) + + +class TFLongformerAttention(keras.layers.Layer): + def __init__(self, config, layer_id=0, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self") + self.dense_output = TFLongformerSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, inputs, training=False): + ( + hidden_states, + attention_mask, + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) = inputs + + self_outputs = self.self_attention( + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], + training=training, + ) + attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFLongformerLayer(keras.layers.Layer): + def __init__(self, config, layer_id=0, **kwargs): + super().__init__(**kwargs) + + self.attention = TFLongformerAttention(config, layer_id, name="attention") + self.intermediate = TFLongformerIntermediate(config, name="intermediate") + self.longformer_output = TFLongformerOutput(config, name="output") + + def call(self, inputs, training=False): + ( + hidden_states, + attention_mask, + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) = inputs + + attention_outputs = self.attention( + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], + training=training, + ) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.longformer_output(intermediate_output, attention_output, training=training) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with 
tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "longformer_output", None) is not None: + with tf.name_scope(self.longformer_output.name): + self.longformer_output.build(None) + + +class TFLongformerEncoder(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.layer = [TFLongformerLayer(config, i, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + padding_len=0, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + all_hidden_states = () if output_hidden_states else None + all_attentions = all_global_attentions = () if output_attentions else None + + for idx, layer_module in enumerate(self.layer): + if output_hidden_states: + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + all_hidden_states = all_hidden_states + (hidden_states_to_add,) + + layer_outputs = layer_module( + [ + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + is_index_masked, + is_index_global_attn, + is_global_attn, + ], + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) + all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) + + # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn + all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),) + + # Add last layer + if output_hidden_states: + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + all_hidden_states = all_hidden_states + (hidden_states_to_add,) + + # undo padding + # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) + hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + if output_attentions: + all_attentions = ( + tuple(state[:, :, :-padding_len, :] for state in all_attentions) if padding_len > 0 else all_attentions + ) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None + ) + + return TFLongformerBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + global_attentions=all_global_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFLongformerMainLayer(keras.layers.Layer): + config_class = LongformerConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + if isinstance(config.attention_window, int): + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * 
config.num_hidden_layers # one value per layer + else: + assert len(config.attention_window) == config.num_hidden_layers, ( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. " + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.pad_token_id = config.pad_token_id + self.attention_window = config.attention_window + self.embeddings = TFLongformerEmbeddings(config, name="embeddings") + self.encoder = TFLongformerEncoder(config, name="encoder") + self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + if input_ids is not None and not isinstance(input_ids, tf.Tensor): + input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64) + elif input_ids is not None: + input_ids = tf.cast(input_ids, tf.int64) + + if attention_mask is not None and not isinstance(attention_mask, tf.Tensor): + attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64) + elif attention_mask is not None: + attention_mask = tf.cast(attention_mask, tf.int64) + + if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor): + global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64) + elif global_attention_mask is not None: + global_attention_mask = tf.cast(global_attention_mask, tf.int64) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.cast(tf.fill(input_shape, 1), tf.int64) + + if token_type_ids is None: + token_type_ids = tf.cast(tf.fill(input_shape, 0), tf.int64) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.pad_token_id, + ) + + # is index masked or global attention + is_index_masked = tf.math.less(attention_mask, 1) + is_index_global_attn = tf.math.greater(attention_mask, 1) + 
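# After `_merge_to_attention_mask`, `attention_mask` holds 0 (padding, no attention),
+        # 1 (local sliding-window attention) or 2 (global attention), so the two boolean
+        # masks above flag padding and global positions; `is_global_attn` below is a scalar
+        # that is True when any token in the batch requests global attention.
+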
is_global_attn = tf.math.reduce_any(is_index_global_attn) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, to_seq_length, 1, 1] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + extended_attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], attention_mask_shape[1], 1, 1)) + + # Since attention_mask is 1.0 for positions we want to attend locally and 0.0 for + # masked and global attn positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 + embedding_output = self.embeddings( + input_ids, + position_ids, + token_type_ids, + inputs_embeds, + training=training, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + padding_len=padding_len, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFLongformerBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + global_attentions=encoder_outputs.global_attentions, + ) + + def _pad_to_window_size( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + pad_token_id, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) + ) + + assert attention_window % 2 == 0, f"`attention_window` should be an even value. 
Given {attention_window}" + + input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) + batch_size, seq_len = input_shape[:2] + padding_len = (attention_window - seq_len % attention_window) % attention_window + + paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) + + if input_ids is not None: + input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) + + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings + position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) + + if inputs_embeds is not None: + if padding_len > 0: + input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + + attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens + token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 + + return ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) + + @staticmethod + def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + + return attention_mask + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +class TFLongformerPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + @property + def input_signature(self): + sig = super().input_signature + sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask") + return sig + + +LONGFORMER_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. 
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    Parameters:
+        config ([`LongformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+LONGFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+            [`PreTrainedTokenizer.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        global_attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+            attention attend to all other tokens, and all other tokens attend to them. This is important for
+            task-specific finetuning because it makes the model more flexible at representing the task. For example,
+            for classification, the `<s>` token should be given global attention. For QA, all question tokens should also
+            have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+            details. Mask values selected in `[0, 1]`:
+
+            - 0 for local attention (a sliding window attention),
+            - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
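+
+            As a minimal sketch (made-up token ids), a mask giving global attention only to the first token of
+            each sequence can be built with plain TensorFlow ops:
+
+            ```python
+            import tensorflow as tf
+
+            input_ids = tf.constant([[0, 9064, 232, 2]])  # any (batch_size, seq_len) batch of token ids
+            # 1 on the first ([CLS]-like) position, 0 (local attention) everywhere else
+            global_attention_mask = tf.concat(
+                [tf.ones_like(input_ids[:, :1]), tf.zeros_like(input_ids[:, 1:])], axis=-1
+            )
+            ```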
+ + token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare Longformer Model outputting raw hidden-states without any specific head on top.", + LONGFORMER_START_DOCSTRING, +) +class TFLongformerModel(TFLongformerPreTrainedModel): + """ + + This class copies code from [`TFRobertaModel`] and overwrites standard self-attention with longformer + self-attention to provide the ability to process long sequences following the self-attention approach described in + [Longformer: the Long-Document Transformer](https://huggingface.co/papers/2004.05150) by Iz Beltagy, Matthew E. Peters, and + Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long + documents without the O(n^2) increase in memory and compute. + + The self-attention module `TFLongformerSelfAttention` implemented here supports the combination of local and global + attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated + attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. Future + release will add support for autoregressive attention, but the support for dilated attention requires a custom CUDA + kernel to be memory and compute efficient. 
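+
+    Example (a minimal usage sketch, assuming the `allenai/longformer-base-4096` checkpoint used elsewhere in
+    this file):
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFLongformerModel
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+    >>> model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
+
+    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+    >>> outputs = model(inputs)
+    >>> last_hidden_state = outputs.last_hidden_state  # shape (batch_size, sequence_length, hidden_size)
+    ```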
+ + """ + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.longformer = TFLongformerMainLayer(config, name="longformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + global_attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFLongformerBaseModelOutputWithPooling | tuple[tf.Tensor]: + outputs = self.longformer( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "longformer", None) is not None: + with tf.name_scope(self.longformer.name): + self.longformer.build(None) + + +@add_start_docstrings( + """Longformer Model with a `language modeling` head on top.""", + LONGFORMER_START_DOCSTRING, +) +class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") + self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning)
+        return self.name + "/" + self.lm_head.name
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="allenai/longformer-base-4096",
+        output_type=TFLongformerMaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.44,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        global_attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFLongformerMaskedLMOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        """
+
+        outputs = self.longformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            global_attention_mask=global_attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output, training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFLongformerMaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            global_attentions=outputs.global_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "longformer", None) is not None:
+            with tf.name_scope(self.longformer.name):
+                self.longformer.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
+
+
+@add_start_docstrings(
+    """
+    Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD /
+    TriviaQA (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
+    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
+        self.qa_outputs = keras.layers.Dense(
+            config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="qa_outputs",
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
+        output_type=TFLongformerQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="' puppet'",
+        expected_loss=0.96,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        global_attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        start_positions: np.ndarray | tf.Tensor | None = None,
+        end_positions: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFLongformerQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+        r"""
+        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        """
+
+        if input_ids is not None and not isinstance(input_ids, tf.Tensor):
+            input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
+        elif input_ids is not None:
+            input_ids = tf.cast(input_ids, tf.int64)
+
+        if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
+            attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
+        elif attention_mask is not None:
+            attention_mask = tf.cast(attention_mask, tf.int64)
+
+        if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
+            global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
+        elif global_attention_mask is not None:
+            global_attention_mask = tf.cast(global_attention_mask, tf.int64)
+
+        # set global attention on question tokens
+        if global_attention_mask is None and input_ids is not None:
+            if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
+                logger.warning(
+                    f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for"
+                    " question answering. You might also consider setting `global_attention_mask` manually in the"
+                    " forward function to avoid this. This is most likely an error. The global attention is disabled"
+                    " for this forward pass."
+                )
+                global_attention_mask = tf.cast(tf.fill(shape_list(input_ids), value=0), tf.int64)
+            else:
+                logger.warning_once("Initializing global attention on question tokens...")
+                # put global attention on all tokens until `config.sep_token_id` is reached
+                sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
+                sep_token_indices = tf.cast(sep_token_indices, dtype=tf.int64)
+                global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
+
+        outputs = self.longformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            global_attention_mask=global_attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = outputs[0]
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+        loss = None
+
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.hf_compute_loss(labels, (start_logits, end_logits))
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFLongformerQuestionAnsweringModelOutput(
+            loss=loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            global_attentions=outputs.global_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "longformer", None) is not None:
+            with tf.name_scope(self.longformer.name):
+                self.longformer.build(None)
+        if getattr(self, "qa_outputs", None) is not None:
+            with tf.name_scope(self.qa_outputs.name):
+                self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+class TFLongformerClassificationHead(keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.out_proj = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+        )
+        self.config = config
+
+    def call(self, hidden_states, training=False):
+        hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv.
to [CLS]) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + output = self.out_proj(hidden_states) + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + LONGFORMER_START_DOCSTRING, +) +class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") + self.classifier = TFLongformerClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFLongformerSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + global_attention_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFLongformerSequenceClassifierOutput | tuple[tf.Tensor]: + if input_ids is not None and not isinstance(input_ids, tf.Tensor): + input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64) + elif input_ids is not None: + input_ids = tf.cast(input_ids, tf.int64) + + if attention_mask is not None and not isinstance(attention_mask, tf.Tensor): + attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64) + elif attention_mask is not None: + attention_mask = tf.cast(attention_mask, tf.int64) + + if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor): + global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64) + elif global_attention_mask is not None: + global_attention_mask = tf.cast(global_attention_mask, tf.int64) + + if global_attention_mask is None and input_ids is not None: + logger.warning_once("Initializing global attention on CLS token...") + # global attention on cls token + global_attention_mask = tf.zeros_like(input_ids) + updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int64) + indices = tf.pad( + tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0], dtype=tf.int64), axis=1), + paddings=[[0, 0], [0, 1]], + constant_values=0, + ) + 
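# `indices` has shape (batch_size, 2) with rows of the form [i, 0]: padding the column
+            # of batch indices with a zero column points every update at sequence position 0,
+            # so the scatter below writes a 1 (global attention) onto each CLS token.
+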
global_attention_mask = tf.tensor_scatter_nd_update( + global_attention_mask, + indices, + updates, + ) + + outputs = self.longformer( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFLongformerSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + global_attentions=outputs.global_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "longformer", None) is not None: + with tf.name_scope(self.longformer.name): + self.longformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + LONGFORMER_START_DOCSTRING, +) +class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.longformer = TFLongformerMainLayer(config, name="longformer") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @property + def input_signature(self): + return { + "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), + "global_attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="global_attention_mask"), + } + + @unpack_inputs + @add_start_docstrings_to_model_forward( + LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFLongformerMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + global_attention_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFLongformerMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for 
computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`,
+            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
+        """
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_global_attention_mask = (
+            tf.reshape(global_attention_mask, (-1, shape_list(global_attention_mask)[-1]))
+            if global_attention_mask is not None
+            else None
+        )
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.longformer(
+            flat_input_ids,
+            position_ids=flat_position_ids,
+            token_type_ids=flat_token_type_ids,
+            attention_mask=flat_attention_mask,
+            head_mask=head_mask,
+            global_attention_mask=flat_global_attention_mask,
+            inputs_embeds=flat_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output, training=training)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFLongformerMultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            global_attentions=outputs.global_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "longformer", None) is not None:
+            with tf.name_scope(self.longformer.name):
+                self.longformer.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
+    for Named-Entity-Recognition (NER) tasks.
+    """,
+    LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
+    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"dropout"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFLongformerTokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        global_attention_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFLongformerTokenClassifierOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
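+
+            A shape-only illustration (hypothetical values): with `config.num_labels == 3` and one
+            four-token sequence, `labels = tf.constant([[0, 2, 1, 0]])`.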
+ """ + + outputs = self.longformer( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFLongformerTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + global_attentions=outputs.global_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "longformer", None) is not None: + with tf.name_scope(self.longformer.name): + self.longformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFLongformerForMaskedLM", + "TFLongformerForMultipleChoice", + "TFLongformerForQuestionAnswering", + "TFLongformerForSequenceClassification", + "TFLongformerForTokenClassification", + "TFLongformerModel", + "TFLongformerPreTrainedModel", + "TFLongformerSelfAttention", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer.py new file mode 100644 index 0000000000000000000000000000000000000000..104bdd7a9b99f5a4809557dcee97cda4873cea12 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer.py @@ -0,0 +1,402 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from functools import lru_cache +from typing import Optional + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} + + +@lru_cache +# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. 
This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer with FacebookAI/roberta-base->allenai/longformer-base-4096, RoBERTa->Longformer all-casing, RobertaTokenizer->LongformerTokenizer
+class LongformerTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a Longformer tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import LongformerTokenizer
+
+    >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just as any
+            other word. (The Longformer tokenizer detects beginnings of words by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+
+        # Mask token behaves like a normal word, i.e.
include the space before it + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + # these special tokens are not part of the vocab.json, let's add them in the correct order + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + vocab = dict(self.encoder).copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + 
return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+        adding special tokens. A Longformer sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
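+
+        For example, assuming `tokenizer` was loaded as in the class docstring above (the ids `5, 6, 7, 8` are
+        arbitrary placeholders; only the length of the result matters):
+
+        ```python
+        >>> tokenizer.create_token_type_ids_from_sequences([5, 6], [7, 8])
+        [0, 0, 0, 0, 0, 0, 0, 0]
+        ```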
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + +__all__ = ["LongformerTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer_fast.py b/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..bde6bb55fec609ae3550c5cf01b43f0dd3bc866c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/longformer/tokenization_longformer_fast.py @@ -0,0 +1,265 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for Longformer.""" + +import json +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_longformer import LongformerTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +# Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast with FacebookAI/roberta-base->allenai/longformer-base-4096, RoBERTa->Longformer all-casing, Roberta->Longformer +class LongformerTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Longformer tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 + tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import LongformerTokenizerFast + + >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. 
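+
+    For instance (the ids shown reuse the vocabulary entries from the examples above):
+
+    ```python
+    >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096", add_prefix_space=True)
+    >>> tokenizer(["Hello", "world"], is_split_into_words=True)["input_ids"]
+    [0, 20920, 232, 2]
+    ```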
+
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just as any
+            other word. (The Longformer tokenizer detects beginnings of words by the preceding space.)
+        trim_offsets (`bool`, *optional*, defaults to `True`):
+            Whether the post processing step should trim offsets to avoid including whitespaces.
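+
+    For instance, with the default `trim_offsets=True` the character offsets reported for ` world` exclude the leading
+    space (a sketch: the exact tuples assume the default post-processor, and special tokens map to `(0, 0)`):
+
+    ```python
+    >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
+    >>> tokenizer("Hello world", return_offsets_mapping=True)["offset_mapping"]
+    [(0, 0), (0, 5), (6, 11), (0, 0)]
+    ```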
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = LongformerTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        trim_offsets=True,
+        **kwargs,
+    ):
+        mask_token = (
+            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
+            if isinstance(mask_token, str)
+            else mask_token
+        )
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            trim_offsets=trim_offsets,
+            **kwargs,
+        )
+
+        tokenizer_component = "post_processor"
+        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+        if tokenizer_component_instance:
+            state = json.loads(tokenizer_component_instance.__getstate__())
+
+            # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+            if "sep" in state:
+                state["sep"] = tuple(state["sep"])
+            if "cls" in state:
+                state["cls"] = tuple(state["cls"])
+
+            changes_to_apply = False
+
+            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+                state["add_prefix_space"] = add_prefix_space
+                changes_to_apply = True
+
+            if state.get("trim_offsets", trim_offsets) != trim_offsets:
+                state["trim_offsets"] = trim_offsets
+                changes_to_apply = True
+
+            if changes_to_apply:
+                component_class = getattr(processors, state.pop("type"))
+                new_value = component_class(**state)
+                setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+    @property
+    def mask_token(self) -> str:
+        """
+        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
+        having been set.
+
+        Longformer tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
+        comprise the space before the *<mask>*.
+        """
+        if self._mask_token is None:
+            if self.verbose:
+                logger.error("Using mask_token, but it is not set yet.")
+            return None
+        return str(self._mask_token)
+
+    @mask_token.setter
+    def mask_token(self, value):
+        """
+        Overriding the default behavior of the mask token to have it eat the space before it.
+
+        This is needed to preserve backward compatibility with all the previously used models based on Longformer.
+        """
+        # Mask token behaves like a normal word, i.e. include the space before it
+        # So we set lstrip to True
+        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
+        self._mask_token = value
+
+    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+
+        return super()._batch_encode_plus(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+ ) + + return super()._encode_plus(*args, **kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["LongformerTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/luke/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/luke/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db588fb2da1976194806932bf5d39438f9769e19 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/luke/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_luke import * + from .modeling_luke import * + from .tokenization_luke import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/luke/configuration_luke.py b/venv/lib/python3.13/site-packages/transformers/models/luke/configuration_luke.py new file mode 100644 index 0000000000000000000000000000000000000000..fbac7c36d7641808f7d6f0acecdd45c43c0a2ddc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/luke/configuration_luke.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright Studio Ousia and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUKE configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class LukeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LukeModel`]. It is used to instantiate a LUKE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LUKE + [studio-ousia/luke-base](https://huggingface.co/studio-ousia/luke-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50267): + Vocabulary size of the LUKE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LukeModel`]. + entity_vocab_size (`int`, *optional*, defaults to 500000): + Entity vocabulary size of the LUKE model. Defines the number of different entities that can be represented + by the `entity_ids` passed when calling [`LukeModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + entity_emb_size (`int`, *optional*, defaults to 256): + The number of dimensions of the entity embedding. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`LukeModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. 
+ use_entity_aware_attention (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep + Contextualized Entity Representations with Entity-aware Self-attention (Yamada et + al.)](https://huggingface.co/papers/2010.01057). + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + + Examples: + + ```python + >>> from transformers import LukeConfig, LukeModel + + >>> # Initializing a LUKE configuration + >>> configuration = LukeConfig() + + >>> # Initializing a model from the configuration + >>> model = LukeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "luke" + + def __init__( + self, + vocab_size=50267, + entity_vocab_size=500000, + hidden_size=768, + entity_emb_size=256, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_entity_aware_attention=True, + classifier_dropout=None, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + """Constructs LukeConfig.""" + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.entity_vocab_size = entity_vocab_size + self.hidden_size = hidden_size + self.entity_emb_size = entity_emb_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_entity_aware_attention = use_entity_aware_attention + self.classifier_dropout = classifier_dropout + + +__all__ = ["LukeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/luke/modeling_luke.py b/venv/lib/python3.13/site-packages/transformers/models/luke/modeling_luke.py new file mode 100644 index 0000000000000000000000000000000000000000..e78197beeb5792f263c2369d6f4789ed29c98c54 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/luke/modeling_luke.py @@ -0,0 +1,2179 @@ +# coding=utf-8 +# Copyright Studio Ousia and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch LUKE model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward +from ...utils import ModelOutput, auto_docstring, logging +from .configuration_luke import LukeConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the LUKE model. + """ +) +class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. + entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`): + Sequence of entity hidden-states at the output of the last layer of the model. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + entity_last_hidden_state: Optional[torch.FloatTensor] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model's outputs, with potential hidden states and attentions. + """ +) +class BaseLukeModelOutput(BaseModelOutput): + r""" + entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`): + Sequence of entity hidden-states at the output of the last layer of the model. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + entity_last_hidden_state: Optional[torch.FloatTensor] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model's outputs, with potential hidden states and attentions. + """ +) +class LukeMaskedLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + The sum of masked language modeling (MLM) loss and entity prediction loss. + mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + mep_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked entity prediction (MEP) loss. 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + entity_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + mlm_loss: Optional[torch.FloatTensor] = None + mep_loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + entity_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of entity classification models. + """ +) +class EntityClassificationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of entity pair classification models. + """ +) +class EntityPairClassificationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of entity span classification models. + """ +) +class EntitySpanClassificationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`): + Classification scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of sentence classification models. + """ +) +class LukeSequenceClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of token classification models. + """ +) +class LukeTokenClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of question answering models. + """ +) +class LukeQuestionAnsweringModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Outputs of multiple choice models. + """ +) +class LukeMultipleChoiceModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class LukeEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
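+    Concretely, position ids are offset by `padding_idx + 1` and padded tokens keep `padding_idx` as their position id,
+    following the RoBERTa-style indexing computed by `create_position_ids_from_input_ids`.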
+ """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
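+        The generated ids therefore run from `padding_idx + 1` to `sequence_length + padding_idx`, matching the offset
+        used on the input-id path above.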
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +class LukeEntityEmbeddings(nn.Module): + def __init__(self, config: LukeConfig): + super().__init__() + self.config = config + + self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0) + if config.entity_emb_size != config.hidden_size: + self.entity_embedding_dense = nn.Linear(config.entity_emb_size, config.hidden_size, bias=False) + + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + entity_ids: torch.LongTensor, + position_ids: torch.LongTensor, + token_type_ids: Optional[torch.LongTensor] = None, + ): + if token_type_ids is None: + token_type_ids = torch.zeros_like(entity_ids) + + entity_embeddings = self.entity_embeddings(entity_ids) + if self.config.entity_emb_size != self.config.hidden_size: + entity_embeddings = self.entity_embedding_dense(entity_embeddings) + + position_embeddings = self.position_embeddings(position_ids.clamp(min=0)) + position_embedding_mask = (position_ids != -1).type_as(position_embeddings).unsqueeze(-1) + position_embeddings = position_embeddings * position_embedding_mask + position_embeddings = torch.sum(position_embeddings, dim=-2) + position_embeddings = position_embeddings / position_embedding_mask.sum(dim=-2).clamp(min=1e-7) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = entity_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class LukeSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.use_entity_aware_attention = config.use_entity_aware_attention + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + if self.use_entity_aware_attention: + self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size) + self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size) + self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + word_size = word_hidden_states.size(1) + + if entity_hidden_states is None: + concat_hidden_states = word_hidden_states + else: + concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + + key_layer = self.transpose_for_scores(self.key(concat_hidden_states)) + value_layer = self.transpose_for_scores(self.value(concat_hidden_states)) + + if self.use_entity_aware_attention and entity_hidden_states is not None: + # compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e) + # query layers + w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states)) + w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states)) + e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states)) + e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states)) + + # compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above + w2w_key_layer = key_layer[:, :, :word_size, :] + e2w_key_layer = key_layer[:, :, :word_size, :] + w2e_key_layer = key_layer[:, :, word_size:, :] + e2e_key_layer = key_layer[:, :, word_size:, :] + + # compute attention scores based on the dot product between the query and key vectors + w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2)) + w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2)) + e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2)) + e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2)) + + # combine attention scores to create the final attention score matrix + word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3) + entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3) + attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2) + + else: + query_layer = self.transpose_for_scores(self.query(concat_hidden_states)) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in LukeModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to 
probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + output_word_hidden_states = context_layer[:, :word_size, :] + if entity_hidden_states is None: + output_entity_hidden_states = None + else: + output_entity_hidden_states = context_layer[:, word_size:, :] + + if output_attentions: + outputs = (output_word_hidden_states, output_entity_hidden_states, attention_probs) + else: + outputs = (output_word_hidden_states, output_entity_hidden_states) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class LukeSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LukeAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LukeSelfAttention(config) + self.output = LukeSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + raise NotImplementedError("LUKE does not support the pruning of attention heads") + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + word_size = word_hidden_states.size(1) + self_outputs = self.self( + word_hidden_states, + entity_hidden_states, + attention_mask, + head_mask, + output_attentions, + ) + if entity_hidden_states is None: + concat_self_outputs = self_outputs[0] + concat_hidden_states = word_hidden_states + else: + concat_self_outputs = torch.cat(self_outputs[:2], dim=1) + concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + + attention_output = self.output(concat_self_outputs, concat_hidden_states) + + word_attention_output = attention_output[:, :word_size, :] + if entity_hidden_states is None: + entity_attention_output = None + else: + entity_attention_output = attention_output[:, word_size:, :] + + # add attentions if we output them + outputs = (word_attention_output, entity_attention_output) + self_outputs[2:] + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class LukeIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = 
self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class LukeOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LukeLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = LukeAttention(config) + self.intermediate = LukeIntermediate(config) + self.output = LukeOutput(config) + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + word_size = word_hidden_states.size(1) + + self_attention_outputs = self.attention( + word_hidden_states, + entity_hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + if entity_hidden_states is None: + concat_attention_output = self_attention_outputs[0] + else: + concat_attention_output = torch.cat(self_attention_outputs[:2], dim=1) + + outputs = self_attention_outputs[2:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, concat_attention_output + ) + word_layer_output = layer_output[:, :word_size, :] + if entity_hidden_states is None: + entity_layer_output = None + else: + entity_layer_output = layer_output[:, word_size:, :] + + outputs = (word_layer_output, entity_layer_output) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class LukeEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_word_hidden_states = () if output_hidden_states else None + all_entity_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_word_hidden_states = all_word_hidden_states + (word_hidden_states,) + all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + layer_outputs = layer_module( + word_hidden_states, + entity_hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) + + word_hidden_states = layer_outputs[0] + + if entity_hidden_states is not None: + entity_hidden_states = layer_outputs[1] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[2],) + + if output_hidden_states: + 
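+ # the loop above recorded the *inputs* to each layer, so this final append adds the last layer's outputs, giving config.num_hidden_layers + 1 entries in each hidden-states tuple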
all_word_hidden_states = all_word_hidden_states + (word_hidden_states,) + all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + word_hidden_states, + all_word_hidden_states, + all_self_attentions, + entity_hidden_states, + all_entity_hidden_states, + ] + if v is not None + ) + return BaseLukeModelOutput( + last_hidden_state=word_hidden_states, + hidden_states=all_word_hidden_states, + attentions=all_self_attentions, + entity_last_hidden_state=entity_hidden_states, + entity_hidden_states=all_entity_hidden_states, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class LukePooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class EntityPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.entity_emb_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class EntityPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.transform = EntityPredictionHeadTransform(config) + self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + + return hidden_states + + +@auto_docstring +class LukePreTrainedModel(PreTrainedModel): + config: LukeConfig + base_model_prefix = "luke" + supports_gradient_checkpointing = True + _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + if module.embedding_dim == 1: # embedding for bias parameters + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring( + custom_intro=""" + The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any + specific head on top. + """ +) +class LukeModel(LukePreTrainedModel): + def __init__(self, config: LukeConfig, add_pooling_layer: bool = True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer. + """ +
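+ # LUKE keeps two separate embedding tables, one for word tokens (LukeEmbeddings) and one for entities (LukeEntityEmbeddings); both streams feed the shared LukeEncoder instantiated below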
super().__init__(config) + self.config = config + + self.embeddings = LukeEmbeddings(config) + self.entity_embeddings = LukeEntityEmbeddings(config) + self.encoder = LukeEncoder(config) + + self.pooler = LukePooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def get_entity_embeddings(self): + return self.entity_embeddings.entity_embeddings + + def set_entity_embeddings(self, value): + self.entity_embeddings.entity_embeddings = value + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError("LUKE does not support the pruning of attention heads") + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseLukeModelOutputWithPooling]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LukeModel + + >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base") + >>> model = LukeModel.from_pretrained("studio-ousia/luke-base") + # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé" + + >>> text = "Beyoncé lives in Los Angeles." 
+ >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé" + + >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + # Input Wikipedia entities to obtain enriched contextualized representations of word tokens + + >>> text = "Beyoncé lives in Los Angeles." + >>> entities = [ + ...     "Beyoncé", + ...     "Los Angeles", + ... ]  # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles" + >>> entity_spans = [ + ...     (0, 7), + ...     (17, 28), + ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles" + + >>> encoding = tokenizer( + ...     text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt" + ... ) + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + if entity_ids is not None: + entity_seq_length = entity_ids.size(1) + if entity_attention_mask is None: + entity_attention_mask = torch.ones((batch_size, entity_seq_length), device=device) + if entity_token_type_ids is None: + entity_token_type_ids = torch.zeros((batch_size, entity_seq_length), dtype=torch.long, device=device) + + # Prepare head mask if needed + # 1.0 in head_mask indicates we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # First, compute word embeddings + word_embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + # Second, compute extended attention mask + extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask) + + # Third, compute entity embeddings and concatenate with word embeddings + if entity_ids is None: + entity_embedding_output = None + else: + entity_embedding_output = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids) + + # Fourth, send embeddings through the model +
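+ # (word and entity embeddings are passed separately; LukeSelfAttention concatenates them along the sequence dimension and, when config.use_entity_aware_attention is set, applies a distinct query projection per token type)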
encoder_outputs = self.encoder( + word_embedding_output, + entity_embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Fifth, get the output. LukeModel outputs the same as BertModel, namely sequence_output of shape (batch_size, seq_len, hidden_size) + sequence_output = encoder_outputs[0] + + # Sixth, compute the pooled_output from the sequence_output + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseLukeModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + entity_last_hidden_state=encoder_outputs.entity_last_hidden_state, + entity_hidden_states=encoder_outputs.entity_hidden_states, + ) + + def get_extended_attention_mask( + self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor] + ): + """ + Makes the attention mask broadcastable over heads and query positions so that padded and masked tokens are + ignored. + + Arguments: + word_attention_mask (`torch.LongTensor`): + Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore. + entity_attention_mask (`torch.LongTensor`, *optional*): + Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore. + + Returns: + `torch.Tensor`: The extended attention mask, with the same dtype as the model's `self.dtype`. + """ + attention_mask = word_attention_mask + if entity_attention_mask is not None: + attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1) + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})") + + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min + return extended_attention_mask + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids (`torch.Tensor`): The input token ids. + padding_idx (`int`): The index of the padding token. + + Returns: + `torch.Tensor`: The corresponding position ids. + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
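+ # e.g. with padding_idx=1 and input_ids [[1, 5, 7], [1, 1, 9]]: mask -> [[0, 1, 1], [0, 0, 1]], cumsum * mask -> [[0, 1, 2], [0, 0, 1]], result -> [[1, 2, 3], [1, 1, 2]]: padding keeps padding_idx, real tokens count up from padding_idx + 1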
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead +class LukeLMHead(nn.Module): + """Roberta Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@auto_docstring( + custom_intro=""" + The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and + masked entity prediction. + """ +) +class LukeForMaskedLM(LukePreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.lm_head = LukeLMHead(config) + self.entity_predictions = EntityPredictionHead(config) + + self.loss_fn = nn.CrossEntropyLoss() + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + super().tie_weights() + self._tie_or_clone_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.LongTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + entity_labels: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LukeMaskedLMOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. 
Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + entity_labels (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Labels for computing the masked entity prediction (MEP) loss. Indices should be in `[-100, 0, ..., + config.entity_vocab_size]` (see `entity_ids` docstring). Entities with indices set to `-100` are ignored (masked); the + loss is only computed for the entities with labels in `[0, ..., config.entity_vocab_size]`. + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + loss = None + + mlm_loss = None + logits = self.lm_head(outputs.last_hidden_state) + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1)) + if loss is None: + loss = mlm_loss + + mep_loss = None + entity_logits = None + if outputs.entity_last_hidden_state is not None: + entity_logits = self.entity_predictions(outputs.entity_last_hidden_state) + if entity_labels is not None: + mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1)) + if loss is None: + loss = mep_loss + else: + loss = loss + mep_loss + + if not return_dict: + return tuple( + v + for v in [ + loss, + mlm_loss, + mep_loss, + logits, + entity_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) + + return LukeMaskedLMOutput( + loss=loss, + mlm_loss=mlm_loss, + mep_loss=mep_loss, + logits=logits, + entity_logits=entity_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity + token) for entity classification tasks, such as Open Entity.
+ """ +) +class LukeForEntityClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.num_labels = config.num_labels + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, EntityClassificationOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*): + Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is + used for the single-label classification. In this case, labels should contain the indices that should be in + `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy + loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0 + and 1 indicate false and true, respectively. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LukeForEntityClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity") + >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity") + + >>> text = "Beyoncé lives in Los Angeles." 
+ >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé" + >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: person + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + feature_vector = outputs.entity_last_hidden_state[:, 0, :] + feature_vector = self.dropout(feature_vector) + logits = self.classifier(feature_vector) + + loss = None + if labels is not None: + # When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise. + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if labels.ndim == 1: + loss = nn.functional.cross_entropy(logits, labels) + else: + loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return EntityClassificationOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity + tokens) for entity pair classification tasks, such as TACRED. + """ +) +class LukeForEntityPairClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.num_labels = config.num_labels + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels, False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, EntityPairClassificationOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary.
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*): + Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is + used for the single-label classification. In this case, labels should contain the indices that should be in + `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy + loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0 + and 1 indicate false and true, respectively. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LukeForEntityPairClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + + >>> text = "Beyoncé lives in Los Angeles." + >>> entity_spans = [ + ...     (0, 7), + ...     (17, 28), + ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles" + >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: per:cities_of_residence + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + feature_vector = torch.cat( + [outputs.entity_last_hidden_state[:, 0, :], outputs.entity_last_hidden_state[:, 1, :]], dim=1 + ) + feature_vector = self.dropout(feature_vector) + logits = self.classifier(feature_vector) + + loss = None + if labels is not None: + # When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise.
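+ # e.g. labels of shape (batch_size,) holding class indices -> cross_entropy over num_labels classes; labels of shape (batch_size, num_labels) holding 0/1 indicators -> binary_cross_entropy_with_logits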
+ # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if labels.ndim == 1: + loss = nn.functional.cross_entropy(logits, labels) + else: + loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return EntityPairClassificationOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks + such as named entity recognition. + """ +) +class LukeForEntitySpanClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.num_labels = config.num_labels + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.LongTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + entity_start_positions: Optional[torch.LongTensor] = None, + entity_end_positions: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, EntitySpanClassificationOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + entity_start_positions (`torch.LongTensor`): + The start positions of entities in the word token sequence. + entity_end_positions (`torch.LongTensor`): + The end positions of entities in the word token sequence. 
+ labels (`torch.LongTensor` of shape `(batch_size, entity_length)` or `(batch_size, entity_length, num_labels)`, *optional*): + Labels for computing the classification loss. If the shape is `(batch_size, entity_length)`, the cross + entropy loss is used for the single-label classification. In this case, labels should contain the indices + that should be in `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, entity_length, + num_labels)`, the binary cross entropy loss is used for the multi-label classification. In this case, + labels should only contain `[0, 1]`, where 0 and 1 indicate false and true, respectively. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, LukeForEntitySpanClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003") + >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003") + + >>> text = "Beyoncé lives in Los Angeles" + # List all possible entity spans in the text + + >>> word_start_positions = [0, 8, 14, 17, 21]  # character-based start positions of word tokens + >>> word_end_positions = [7, 13, 16, 20, 28]  # character-based end positions of word tokens + >>> entity_spans = [] + >>> for i, start_pos in enumerate(word_start_positions): + ...     for end_pos in word_end_positions[i:]: + ...         entity_spans.append((start_pos, end_pos)) + + >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist() + >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices): + ...     if predicted_class_idx != 0: + ...         print(text[span[0] : span[1]], model.config.id2label[predicted_class_idx]) + Beyoncé PER + Los Angeles LOC + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + hidden_size = outputs.last_hidden_state.size(-1) + + entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size) + if entity_start_positions.device != outputs.last_hidden_state.device: + entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device) + start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions) + + entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size) + if entity_end_positions.device != outputs.last_hidden_state.device: + entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device) + end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions) + + feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2) + + feature_vector = self.dropout(feature_vector) + logits = self.classifier(feature_vector) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # When the number of dimensions of `labels` is 2, cross entropy is used as the loss
function. The binary + # cross entropy is used otherwise. + if labels.ndim == 2: + loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + else: + loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return EntitySpanClassificationOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class LukeForSequenceClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.luke = LukeModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LukeSequenceClassifierOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = outputs.pooler_output + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return LukeSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To + solve the Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this + class.
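+ The head scores every word token from the encoder's `last_hidden_state`; entity inputs, when provided, influence the predictions only through the shared attention layers.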
+ """ +) +class LukeForTokenClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.luke = LukeModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LukeTokenClassifierOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs.last_hidden_state + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return LukeTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class LukeForQuestionAnswering(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + + self.luke = LukeModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LukeQuestionAnsweringModelOutput]: + r""" + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. 
+ entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + return tuple( + v + for v in [ + total_loss, + start_logits, + end_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) + + return LukeQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class LukeForMultipleChoice(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, LukeMultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of 
input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*): + Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in `[0, 1]`: + + - 0 corresponds to a *portion A* entity token, + - 1 corresponds to a *portion B* entity token. + entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*): + Indices of positions of each input entity in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None + entity_attention_mask = ( + entity_attention_mask.view(-1, entity_attention_mask.size(-1)) + if entity_attention_mask is not None + else None + ) + entity_token_type_ids = ( + entity_token_type_ids.view(-1, entity_token_type_ids.size(-1)) + if entity_token_type_ids is not None + else None + ) + entity_position_ids = ( + entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1)) + if entity_position_ids is not None + else None + ) + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = outputs.pooler_output + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(reshaped_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + return tuple( + v + for v in [ + loss, + reshaped_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) + + return LukeMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "LukeForEntityClassification", + "LukeForEntityPairClassification", + "LukeForEntitySpanClassification", + "LukeForMultipleChoice", + "LukeForQuestionAnswering", + "LukeForSequenceClassification", + "LukeForTokenClassification", + "LukeForMaskedLM", + "LukeModel", + "LukePreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/luke/tokenization_luke.py b/venv/lib/python3.13/site-packages/transformers/models/luke/tokenization_luke.py new file mode 100644 index 0000000000000000000000000000000000000000..efbc757e86307c7b70a4d36d7d6930f5557a43e7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/luke/tokenization_luke.py @@ -0,0 +1,1731 @@ +# coding=utf-8 +# Copyright Studio-Ouisa and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for LUKE.""" + +import itertools +import json +import os +from collections.abc import Mapping +from functools import lru_cache +from typing import Optional, Union + +import numpy as np +import regex as re + +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + PaddingStrategy, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, + to_py_obj, +) +from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging + + +logger = logging.get_logger(__name__) + +EntitySpan = tuple[int, int] +EntitySpanInput = list[EntitySpan] +Entity = str +EntityInput = list[Entity] + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "entity_vocab_file": "entity_vocab.json", +} + + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). 
+
+            [What are token type IDs?](../glossary#token-type-ids)
+
+        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+          [What are attention masks?](../glossary#attention-mask)
+
+        - **entity_ids** -- List of entity ids to be fed to a model.
+
+          [What are input IDs?](../glossary#input-ids)
+
+        - **entity_position_ids** -- List of entity positions in the input sequence to be fed to a model.
+
+        - **entity_token_type_ids** -- List of entity token type ids to be fed to a model (when
+          `return_token_type_ids=True` or if *"entity_token_type_ids"* is in `self.model_input_names`).
+
+          [What are token type IDs?](../glossary#token-type-ids)
+
+        - **entity_attention_mask** -- List of indices specifying which entities should be attended to by the model
+          (when `return_attention_mask=True` or if *"entity_attention_mask"* is in `self.model_input_names`).
+
+          [What are attention masks?](../glossary#attention-mask)
+
+        - **entity_start_positions** -- List of the start positions of entities in the word token sequence (when
+          `task="entity_span_classification"`).
+        - **entity_end_positions** -- List of the end positions of entities in the word token sequence (when
+          `task="entity_span_classification"`).
+        - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+          `return_overflowing_tokens=True`).
+        - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+          `return_overflowing_tokens=True`).
+        - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+          regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+        - **length** -- The length of the inputs (when `return_length=True`)
+
+"""
+
+
+@lru_cache
+# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+    whitespace/control characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your
+    vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K
+    for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want
+    lookup tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class LukeTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a LUKE tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import LukeTokenizer
+
+    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods. It also creates entity sequences, namely
+    `entity_ids`, `entity_attention_mask`, `entity_token_type_ids`, and `entity_position_ids` to be used by the LUKE
+    model.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        entity_vocab_file (`str`):
+            Path to the entity vocabulary file.
+        task (`str`, *optional*):
+            Task for which you want to prepare sequences. One of `"entity_classification"`,
+            `"entity_pair_classification"`, or `"entity_span_classification"`. If you specify this argument, the entity
+            sequence is automatically created based on the given entity span(s).
+        max_entity_length (`int`, *optional*, defaults to 32):
+            The maximum length of `entity_ids`.
+        max_mention_length (`int`, *optional*, defaults to 30):
+            The maximum number of tokens inside an entity span.
+        entity_token_1 (`str`, *optional*, defaults to `<ent>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            `task` is set to `"entity_classification"` or `"entity_pair_classification"`.
+        entity_token_2 (`str`, *optional*, defaults to `<ent2>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            `task` is set to `"entity_pair_classification"`.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just as any
+            other word. (The LUKE tokenizer detects the beginning of words by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        entity_vocab_file,
+        task=None,
+        max_entity_length=32,
+        max_mention_length=30,
+        entity_token_1="<ent>",
+        entity_token_2="<ent2>",
+        entity_unk_token="[UNK]",
+        entity_pad_token="[PAD]",
+        entity_mask_token="[MASK]",
+        entity_mask2_token="[MASK2]",
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e.
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + # we add 2 special tokens for downstream tasks + # for more information about lstrip and rstrip, see https://github.com/huggingface/transformers/pull/2778 + entity_token_1 = ( + AddedToken(entity_token_1, lstrip=False, rstrip=False) + if isinstance(entity_token_1, str) + else entity_token_1 + ) + entity_token_2 = ( + AddedToken(entity_token_2, lstrip=False, rstrip=False) + if isinstance(entity_token_2, str) + else entity_token_2 + ) + kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) + kwargs["additional_special_tokens"] += [entity_token_1, entity_token_2] + + with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle: + self.entity_vocab = json.load(entity_vocab_handle) + for entity_special_token in [entity_unk_token, entity_pad_token, entity_mask_token, entity_mask2_token]: + if entity_special_token not in self.entity_vocab: + raise ValueError( + f"Specified entity special token ``{entity_special_token}`` is not found in entity_vocab. " + f"Probably an incorrect entity vocab file is loaded: {entity_vocab_file}." + ) + self.entity_unk_token_id = self.entity_vocab[entity_unk_token] + self.entity_pad_token_id = self.entity_vocab[entity_pad_token] + self.entity_mask_token_id = self.entity_vocab[entity_mask_token] + self.entity_mask2_token_id = self.entity_vocab[entity_mask2_token] + + self.task = task + if task is None or task == "entity_span_classification": + self.max_entity_length = max_entity_length + elif task == "entity_classification": + self.max_entity_length = 1 + elif task == "entity_pair_classification": + self.max_entity_length = 2 + else: + raise ValueError( + f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification'," + " 'entity_span_classification'] only." 
+            )
+
+        self.max_mention_length = max_mention_length
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            task=task,
+            max_entity_length=32,
+            max_mention_length=30,
+            entity_token_1="<ent>",
+            entity_token_2="<ent2>",
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
+            **kwargs,
+        )
+
+    @property
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size with Roberta->Luke, RoBERTa->LUKE
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab with Roberta->Luke, RoBERTa->LUKE
+    def get_vocab(self):
+        vocab = dict(self.encoder).copy()
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe with Roberta->Luke, RoBERTa->LUKE
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize with Roberta->Luke, RoBERTa->LUKE
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id with Roberta->Luke, RoBERTa->LUKE
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token with Roberta->Luke, RoBERTa->LUKE
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string with Roberta->Luke, RoBERTa->LUKE
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
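+    # Editorial note (not part of the original file): a minimal round-trip sketch of the byte-level
+    # BPE machinery above, assuming the `studio-ousia/luke-base` vocabulary is loaded. `_tokenize`
+    # maps raw bytes to printable unicode via `byte_encoder` before BPE, and
+    # `convert_tokens_to_string` inverts that mapping via `byte_decoder`:
+    #
+    #     tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+    #     tokens = tokenizer.tokenize(" Hello")        # e.g. ["ĠHello"]; "Ġ" encodes the leading space byte
+    #     tokenizer.convert_tokens_to_string(tokens)   # " Hello" -- byte decoding restores the space
+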
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.build_inputs_with_special_tokens with Roberta->Luke, RoBERTa->LUKE
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A LUKE sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask with Roberta->Luke, RoBERTa->LUKE
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences with Roberta->Luke, RoBERTa->LUKE
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. LUKE does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
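+
+        Example (editorial sketch, not part of the original file; the single-token id lists are
+        arbitrary placeholders, and the output follows from the `<s> A </s></s> B </s>` layout above):
+
+        ```python
+        >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+        >>> tokenizer.create_token_type_ids_from_sequences([31414], [232])
+        [0, 0, 0, 0, 0, 0]
+        ```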
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.prepare_for_tokenization with Roberta->Luke, RoBERTa->LUKE + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, list[TextInput]], + text_pair: Optional[Union[TextInput, list[TextInput]]] = None, + entity_spans: Optional[Union[EntitySpanInput, list[EntitySpanInput]]] = None, + entity_spans_pair: Optional[Union[EntitySpanInput, list[EntitySpanInput]]] = None, + entities: Optional[Union[EntityInput, list[EntityInput]]] = None, + entities_pair: Optional[Union[EntityInput, list[EntityInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences, depending on the task you want to prepare them for. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + text_pair (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + entity_spans (`list[tuple[int, int]]`, `list[list[tuple[int, int]]]`, *optional*): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify + `"entity_classification"` or `"entity_pair_classification"` as the `task` argument in the constructor, + the length of each sequence must be 1 or 2, respectively. If you specify `entities`, the length of each + sequence must be equal to the length of each sequence of `entities`. + entity_spans_pair (`list[tuple[int, int]]`, `list[list[tuple[int, int]]]`, *optional*): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify the + `task` argument in the constructor, this argument is ignored. 
If you specify `entities_pair`, the + length of each sequence must be equal to the length of each sequence of `entities_pair`. + entities (`list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the `task` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of `entity_spans`. If you specify + `entity_spans` without specifying this argument, the entity sequence or the batch of entity sequences + is automatically constructed by filling it with the [MASK] entity. + entities_pair (`list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the `task` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of `entity_spans_pair`. If you specify + `entity_spans_pair` without specifying this argument, the entity sequence or the batch of entity + sequences is automatically constructed by filling it with the [MASK] entity. + max_entity_length (`int`, *optional*): + The maximum length of `entity_ids`. + """ + # Input type checking for clearer error + is_valid_single_text = isinstance(text, str) + is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + if not (is_valid_single_text or is_valid_batch_text): + raise ValueError("text input must be of type `str` (single example) or `list[str]` (batch).") + + is_valid_single_text_pair = isinstance(text_pair, str) + is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( + len(text_pair) == 0 or isinstance(text_pair[0], str) + ) + if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair): + raise ValueError("text_pair input must be of type `str` (single example) or `list[str]` (batch).") + + is_batched = bool(isinstance(text, (list, tuple))) + + if is_batched: + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + if entities is None: + batch_entities_or_entities_pairs = None + else: + batch_entities_or_entities_pairs = ( + list(zip(entities, entities_pair)) if entities_pair is not None else entities + ) + + if entity_spans is None: + batch_entity_spans_or_entity_spans_pairs = None + else: + batch_entity_spans_or_entity_spans_pairs = ( + list(zip(entity_spans, entity_spans_pair)) if entity_spans_pair is not None else entity_spans + ) + + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + 
return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + entities=entities, + entities_pair=entities_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. 
" + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + + # prepare_for_model will create the attention_mask and token_type_ids + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[list[TextInput], list[TextInputPair]], + batch_entity_spans_or_entity_spans_pairs: Optional[ + Union[list[EntitySpanInput], list[tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[list[EntityInput], list[tuple[EntityInput, EntityInput]]] + ] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + # input_ids is a list of tuples (one for each example in the batch) + input_ids = [] + entity_ids = [] + entity_token_spans = [] + for index, text_or_text_pair in enumerate(batch_text_or_text_pairs): + if not isinstance(text_or_text_pair, (list, tuple)): + text, text_pair = text_or_text_pair, None + else: + text, text_pair = text_or_text_pair + + entities, entities_pair = None, None + if batch_entities_or_entities_pairs is not None: + entities_or_entities_pairs = batch_entities_or_entities_pairs[index] + if entities_or_entities_pairs: + if isinstance(entities_or_entities_pairs[0], str): + entities, entities_pair = entities_or_entities_pairs, None + else: + entities, entities_pair = entities_or_entities_pairs + + entity_spans, entity_spans_pair = None, None + if batch_entity_spans_or_entity_spans_pairs is not None: + entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index] + if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance( + entity_spans_or_entity_spans_pairs[0], list + ): + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs + else: + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + input_ids.append((first_ids, second_ids)) + entity_ids.append((first_entity_ids, second_entity_ids)) + entity_token_spans.append((first_entity_token_spans, second_entity_token_spans)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + batch_entity_ids_pairs=entity_ids, + batch_entity_token_spans_pairs=entity_token_spans, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]): + if not isinstance(entity_spans, list): + raise TypeError("entity_spans should be given as a list") + elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple): + raise ValueError( + "entity_spans should be given as a list of tuples containing the start and end character indices" + ) + + if entities is not None: + if not isinstance(entities, list): + raise ValueError("If you specify entities, they should be given as a list") + + if len(entities) > 0 and not isinstance(entities[0], str): + raise ValueError("If you specify entities, they should be given as a list of entity names") + + if len(entities) != len(entity_spans): + raise ValueError("If you specify entities, entities and entity_spans must be the same length") + + def _create_input_sequence( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = 
None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + **kwargs, + ) -> tuple[list, list, list, list, list, list]: + def get_input_ids(text): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + + def get_input_ids_and_entity_token_spans(text, entity_spans): + if entity_spans is None: + return get_input_ids(text), None + + cur = 0 + input_ids = [] + entity_token_spans = [None] * len(entity_spans) + + split_char_positions = sorted(frozenset(itertools.chain(*entity_spans))) + char_pos2token_pos = {} + + for split_char_position in split_char_positions: + orig_split_char_position = split_char_position + if ( + split_char_position > 0 and text[split_char_position - 1] == " " + ): # whitespace should be prepended to the following token + split_char_position -= 1 + if cur != split_char_position: + input_ids += get_input_ids(text[cur:split_char_position]) + cur = split_char_position + char_pos2token_pos[orig_split_char_position] = len(input_ids) + + input_ids += get_input_ids(text[cur:]) + + entity_token_spans = [ + (char_pos2token_pos[char_start], char_pos2token_pos[char_end]) for char_start, char_end in entity_spans + ] + + return input_ids, entity_token_spans + + first_ids, second_ids = None, None + first_entity_ids, second_entity_ids = None, None + first_entity_token_spans, second_entity_token_spans = None, None + + if self.task is None: + if entity_spans is None: + first_ids = get_input_ids(text) + else: + self._check_entity_input_format(entities, entity_spans) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + if entities is None: + first_entity_ids = [self.entity_mask_token_id] * len(entity_spans) + else: + first_entity_ids = [self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities] + + if text_pair is not None: + if entity_spans_pair is None: + second_ids = get_input_ids(text_pair) + else: + self._check_entity_input_format(entities_pair, entity_spans_pair) + + second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans( + text_pair, entity_spans_pair + ) + if entities_pair is None: + second_entity_ids = [self.entity_mask_token_id] * len(entity_spans_pair) + else: + second_entity_ids = [ + self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities_pair + ] + + elif self.task == "entity_classification": + if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be a list containing a single tuple " + "containing the start and end character indices of an entity" + ) + first_entity_ids = [self.entity_mask_token_id] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + # add special tokens to input ids + entity_token_start, entity_token_end = first_entity_token_spans[0] + first_ids = ( + first_ids[:entity_token_end] + [self.additional_special_tokens_ids[0]] + first_ids[entity_token_end:] + ) + first_ids = ( + first_ids[:entity_token_start] + + [self.additional_special_tokens_ids[0]] + + first_ids[entity_token_start:] + ) + first_entity_token_spans = [(entity_token_start, entity_token_end + 2)] + + elif self.task == "entity_pair_classification": + if not ( + isinstance(entity_spans, list) + and len(entity_spans) == 2 + and isinstance(entity_spans[0], tuple) + and 
isinstance(entity_spans[1], tuple) + ): + raise ValueError( + "Entity spans should be provided as a list of two tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + head_span, tail_span = entity_spans + first_entity_ids = [self.entity_mask_token_id, self.entity_mask2_token_id] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + head_token_span, tail_token_span = first_entity_token_spans + token_span_with_special_token_ids = [ + (head_token_span, self.additional_special_tokens_ids[0]), + (tail_token_span, self.additional_special_tokens_ids[1]), + ] + if head_token_span[0] < tail_token_span[0]: + first_entity_token_spans[0] = (head_token_span[0], head_token_span[1] + 2) + first_entity_token_spans[1] = (tail_token_span[0] + 2, tail_token_span[1] + 4) + token_span_with_special_token_ids = reversed(token_span_with_special_token_ids) + else: + first_entity_token_spans[0] = (head_token_span[0] + 2, head_token_span[1] + 4) + first_entity_token_spans[1] = (tail_token_span[0], tail_token_span[1] + 2) + + for (entity_token_start, entity_token_end), special_token_id in token_span_with_special_token_ids: + first_ids = first_ids[:entity_token_end] + [special_token_id] + first_ids[entity_token_end:] + first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:] + + elif self.task == "entity_span_classification": + if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be provided as a list of tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + first_entity_ids = [self.entity_mask_token_id] * len(entity_spans) + + else: + raise ValueError(f"Task {self.task} not supported") + + return ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: list[tuple[list[int], None]], + batch_entity_ids_pairs: list[tuple[Optional[list[int]], Optional[list[int]]]], + batch_entity_token_spans_pairs: list[tuple[Optional[list[tuple[int, int]]], Optional[list[tuple[int, int]]]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + batch_entity_ids_pairs: list of entity ids or entity ids pairs + batch_entity_token_spans_pairs: list of entity spans or entity spans pairs + max_entity_length: The maximum length of the entity sequence. + """ + + batch_outputs = {} + for input_ids, entity_ids, entity_token_span_pairs in zip( + batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs + ): + first_ids, second_ids = input_ids + first_entity_ids, second_entity_ids = entity_ids + first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs + outputs = self.prepare_for_model( + first_ids, + second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: list[int], + pair_ids: Optional[list[int]] = None, + entity_ids: Optional[list[int]] = None, + pair_entity_ids: Optional[list[int]] = None, + entity_token_spans: Optional[list[tuple[int, int]]] = None, + pair_entity_token_spans: Optional[list[tuple[int, int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids, + entity spans so that it can be used by the model. 
It adds special tokens, truncates sequences if overflowing + while taking into account the special tokens and manages a moving window (with user defined stride) for + overflowing tokens. Please Note, for *pair_ids* different than `None` and *truncation_strategy = longest_first* + or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an + error. + + Args: + ids (`list[int]`): + Tokenized input ids of the first sequence. + pair_ids (`list[int]`, *optional*): + Tokenized input ids of the second sequence. + entity_ids (`list[int]`, *optional*): + Entity ids of the first sequence. + pair_entity_ids (`list[int]`, *optional*): + Entity ids of the second sequence. + entity_token_spans (`list[tuple[int, int]]`, *optional*): + Entity spans of the first sequence. + pair_entity_token_spans (`list[tuple[int, int]]`, *optional*): + Entity spans of the second sequence. + max_entity_length (`int`, *optional*): + The maximum length of the entity sequence. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + # Compute lengths + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." 
+            )
+
+        # Load from model defaults
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Compute the total size of the returned word encodings
+        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+        # Truncation: Handle max sequence length and max_entity_length
+        overflowing_tokens = []
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+            # truncate words up to max_length
+            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                ids,
+                pair_ids=pair_ids,
+                num_tokens_to_remove=total_len - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+
+            if return_overflowing_tokens:
+                encoded_inputs["overflowing_tokens"] = overflowing_tokens
+                encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+        # Add special tokens
+        if add_special_tokens:
+            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+            entity_token_offset = 1  # 1 * <s> token
+            pair_entity_token_offset = len(ids) + 3  # 1 * <s> token & 2 * </s> tokens
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+            entity_token_offset = 0
+            pair_entity_token_offset = len(ids)
+
+        # Build output dictionary
+        encoded_inputs["input_ids"] = sequence
+        if return_token_type_ids:
+            encoded_inputs["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+        # Set max entity length
+        if not max_entity_length:
+            max_entity_length = self.max_entity_length
+
+        if entity_ids is not None:
+            total_entity_len = 0
+            num_invalid_entities = 0
+            valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)]
+            valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)]
+
+            total_entity_len += len(valid_entity_ids)
+            num_invalid_entities += len(entity_ids) - len(valid_entity_ids)
+
+            valid_pair_entity_ids, valid_pair_entity_token_spans = None, None
+            if pair_entity_ids is not None:
+                valid_pair_entity_ids = [
+                    ent_id
+                    for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans)
+                    if span[1] <= len(pair_ids)
+                ]
+                valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)]
+                total_entity_len += len(valid_pair_entity_ids)
+                num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids)
+
+            if num_invalid_entities != 0:
+                logger.warning(
+                    f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+                    " truncation of input tokens"
+                )
+
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
+                # truncate entities up to max_entity_length
+                valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences(
+                    valid_entity_ids,
+                    pair_ids=valid_pair_entity_ids,
+                    num_tokens_to_remove=total_entity_len - max_entity_length,
+                    truncation_strategy=truncation_strategy,
+                    stride=stride,
+                )
+                valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)]
+                if
valid_pair_entity_token_spans is not None: + valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)] + + if return_overflowing_tokens: + encoded_inputs["overflowing_entities"] = overflowing_entities + encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length + + final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids + encoded_inputs["entity_ids"] = list(final_entity_ids) + entity_position_ids = [] + entity_start_positions = [] + entity_end_positions = [] + for token_spans, offset in ( + (valid_entity_token_spans, entity_token_offset), + (valid_pair_entity_token_spans, pair_entity_token_offset), + ): + if token_spans is not None: + for start, end in token_spans: + start += offset + end += offset + position_ids = list(range(start, end))[: self.max_mention_length] + position_ids += [-1] * (self.max_mention_length - end + start) + entity_position_ids.append(position_ids) + entity_start_positions.append(start) + entity_end_positions.append(end - 1) + + encoded_inputs["entity_position_ids"] = entity_position_ids + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = entity_start_positions + encoded_inputs["entity_end_positions"] = entity_end_positions + + if return_token_type_ids: + encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + list[BatchEncoding], + dict[str, EncodedInput], + dict[str, list[EncodedInput]], + list[dict[str, EncodedInput]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with + `self.padding_side`, `self.pad_token_id` and `self.pad_token_type_id`) .. note:: If the `encoded_inputs` passed + are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless + you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the + specific device of your tensors however. + + Args: + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`): + Tokenized inputs. 
+                tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str,
+                list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+                collate function. Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or
+                TensorFlow tensors), see the note above for the return type.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            max_entity_length (`int`, *optional*):
+                The maximum length of the entity sequence.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute. [What are attention
+                masks?](../glossary#attention-mask)
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
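+
+        Example (an illustrative sketch, not taken from the original file; it assumes the `studio-ousia/luke-base`
+        checkpoint is available and that PyTorch is installed):
+
+        ```python
+        >>> from transformers import LukeTokenizer
+
+        >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+        >>> # Two encodings of different lengths are padded together into one rectangular batch
+        >>> batch = [tokenizer("LUKE is a language model"), tokenizer("Paris")]
+        >>> padded = tokenizer.pad(batch, padding="longest", return_tensors="pt")
+        >>> padded["input_ids"].shape == padded["attention_mask"].shape
+        True
+        ```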
+ """ + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]} + + # The model's main input name, usually `input_ids`, has be passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + if not required_input: + if return_attention_mask: + encoded_inputs["attention_mask"] = [] + return encoded_inputs + + # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + index = 0 + while len(required_input[index]) == 0: + index += 1 + if index < len(required_input): + first_element = required_input[index][0] + # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. + if not isinstance(first_element, (int, list, tuple)): + if is_tf_tensor(first_element): + return_tensors = "tf" if return_tensors is None else return_tensors + elif is_torch_tensor(first_element): + return_tensors = "pt" if return_tensors is None else return_tensors + elif isinstance(first_element, np.ndarray): + return_tensors = "np" if return_tensors is None else return_tensors + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." 
+ ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + if max_entity_length is None: + max_entity_length = self.max_entity_length + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + if any(len(v) != batch_size for v in encoded_inputs.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + max_entity_length = ( + max(len(inputs) for inputs in encoded_inputs["entity_ids"]) if "entity_ids" in encoded_inputs else 0 + ) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = {k: v[i] for k, v in encoded_inputs.items()} + outputs = self._pad( + inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + max_entity_length: The maximum length of the entity sequence. + padding_strategy: PaddingStrategy to use for padding. + + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. 
+ return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + entities_provided = bool("entity_ids" in encoded_inputs) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + if entities_provided: + max_entity_length = len(encoded_inputs["entity_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if ( + entities_provided + and max_entity_length is not None + and pad_to_multiple_of is not None + and (max_entity_length % pad_to_multiple_of != 0) + ): + max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ( + len(encoded_inputs["input_ids"]) != max_length + or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length) + ) + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + padding_side = padding_side if padding_side is not None else self.padding_side + if entities_provided: + entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if entities_provided: + encoded_inputs["entity_attention_mask"] = ( + encoded_inputs["entity_attention_mask"] + [0] * entity_difference + ) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference + if entities_provided: + encoded_inputs["entity_token_type_ids"] = ( + encoded_inputs["entity_token_type_ids"] + [0] * entity_difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + if entities_provided: + encoded_inputs["entity_ids"] = ( + encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference + ) + encoded_inputs["entity_position_ids"] = ( + encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference + ) + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = ( + encoded_inputs["entity_start_positions"] + [0] * entity_difference + ) + encoded_inputs["entity_end_positions"] = ( + encoded_inputs["entity_end_positions"] + [0] * entity_difference + ) + + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if entities_provided: + encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[ + "entity_attention_mask" + ] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [0] * difference 
+ encoded_inputs["token_type_ids"]
+                    if entities_provided:
+                        encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[
+                            "entity_token_type_ids"
+                        ]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
+                if entities_provided:
+                    encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
+                        "entity_ids"
+                    ]
+                    encoded_inputs["entity_position_ids"] = [
+                        [-1] * self.max_mention_length
+                    ] * entity_difference + encoded_inputs["entity_position_ids"]
+                    if self.task == "entity_span_classification":
+                        encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_start_positions"
+                        ]
+                        encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_end_positions"
+                        ]
+            else:
+                raise ValueError("Invalid padding side: " + str(padding_side))
+
+        return encoded_inputs
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str, str, str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        entity_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
+        )
+
+        with open(entity_vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        return vocab_file, merge_file, entity_vocab_file
+
+
+__all__ = ["LukeTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b03aa2e625637809438c4b47d035ddf6db26127
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_markuplm import *
+    from .feature_extraction_markuplm import *
+    from .modeling_markuplm import *
+    from .processing_markuplm import *
+    from .tokenization_markuplm import *
+    from .tokenization_markuplm_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/configuration_markuplm.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/configuration_markuplm.py
new file mode 100644
index 0000000000000000000000000000000000000000..34c2083df85a9ba0a1472ae76785cc3313ed0809
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/configuration_markuplm.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2021, The Microsoft Research Asia MarkupLM Team authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MarkupLM model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MarkupLMConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MarkupLMModel`]. It is used to instantiate a
+    MarkupLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the MarkupLM
+    [microsoft/markuplm-base](https://huggingface.co/microsoft/markuplm-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the MarkupLM model. Defines the different tokens that can be represented by the
+            *inputs_ids* passed to the forward method of [`MarkupLMModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed into [`MarkupLMModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        max_xpath_tag_unit_embeddings (`int`, *optional*, defaults to 256):
+            The maximum value that the xpath tag unit embedding might ever use. Typically set this to something large
+            just in case (e.g., 256).
+        max_xpath_subs_unit_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum value that the xpath subscript unit embedding might ever use. Typically set this to something
+            large just in case (e.g., 1024).
+        tag_pad_id (`int`, *optional*, defaults to 216):
+            The id of the padding token in the xpath tags.
+        subs_pad_id (`int`, *optional*, defaults to 1001):
+            The id of the padding token in the xpath subscripts.
+        xpath_unit_hidden_size (`int`, *optional*, defaults to 32):
+            The hidden size of each xpath unit embedding. One complete xpath index will have
+            (max_depth * xpath_unit_hidden_size)-dim.
+        max_depth (`int`, *optional*, defaults to 50):
+            The maximum depth in xpath.
+ + Examples: + + ```python + >>> from transformers import MarkupLMModel, MarkupLMConfig + + >>> # Initializing a MarkupLM microsoft/markuplm-base style configuration + >>> configuration = MarkupLMConfig() + + >>> # Initializing a model from the microsoft/markuplm-base style configuration + >>> model = MarkupLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "markuplm" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + bos_token_id=0, + eos_token_id=2, + max_xpath_tag_unit_embeddings=256, + max_xpath_subs_unit_embeddings=1024, + tag_pad_id=216, + subs_pad_id=1001, + xpath_unit_hidden_size=32, + max_depth=50, + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + # additional properties + self.max_depth = max_depth + self.max_xpath_tag_unit_embeddings = max_xpath_tag_unit_embeddings + self.max_xpath_subs_unit_embeddings = max_xpath_subs_unit_embeddings + self.tag_pad_id = tag_pad_id + self.subs_pad_id = subs_pad_id + self.xpath_unit_hidden_size = xpath_unit_hidden_size + + +__all__ = ["MarkupLMConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/feature_extraction_markuplm.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/feature_extraction_markuplm.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea35b6da56f12f5f835b4da03e298aef344b297 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/feature_extraction_markuplm.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for MarkupLM. 
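+
+A minimal usage sketch (added for illustration, not part of the original file; it assumes `bs4` is installed):
+
+```python
+>>> from transformers import MarkupLMFeatureExtractor
+
+>>> feature_extractor = MarkupLMFeatureExtractor()
+>>> encoding = feature_extractor("<html><body><p>Hello</p><p>World</p></body></html>")
+>>> encoding["nodes"]
+[['Hello', 'World']]
+>>> encoding["xpaths"]
+[['/html/body/p[1]', '/html/body/p[2]']]
+```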
+""" + +import html + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...utils import is_bs4_available, logging, requires_backends + + +if is_bs4_available(): + import bs4 + from bs4 import BeautifulSoup + + +logger = logging.get_logger(__name__) + + +class MarkupLMFeatureExtractor(FeatureExtractionMixin): + r""" + Constructs a MarkupLM feature extractor. This can be used to get a list of nodes and corresponding xpaths from HTML + strings. + + This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most + of the main methods. Users should refer to this superclass for more information regarding those methods. + + """ + + def __init__(self, **kwargs): + requires_backends(self, ["bs4"]) + super().__init__(**kwargs) + + def xpath_soup(self, element): + xpath_tags = [] + xpath_subscripts = [] + child = element if element.name else element.parent + for parent in child.parents: # type: bs4.element.Tag + siblings = parent.find_all(child.name, recursive=False) + xpath_tags.append(child.name) + xpath_subscripts.append( + 0 if 1 == len(siblings) else next(i for i, s in enumerate(siblings, 1) if s is child) + ) + child = parent + xpath_tags.reverse() + xpath_subscripts.reverse() + return xpath_tags, xpath_subscripts + + def get_three_from_single(self, html_string): + html_code = BeautifulSoup(html_string, "html.parser") + + all_doc_strings = [] + string2xtag_seq = [] + string2xsubs_seq = [] + + for element in html_code.descendants: + if isinstance(element, bs4.element.NavigableString): + if type(element.parent) is not bs4.element.Tag: + continue + + text_in_this_tag = html.unescape(element).strip() + if not text_in_this_tag: + continue + + all_doc_strings.append(text_in_this_tag) + + xpath_tags, xpath_subscripts = self.xpath_soup(element) + string2xtag_seq.append(xpath_tags) + string2xsubs_seq.append(xpath_subscripts) + + if len(all_doc_strings) != len(string2xtag_seq): + raise ValueError("Number of doc strings and xtags does not correspond") + if len(all_doc_strings) != len(string2xsubs_seq): + raise ValueError("Number of doc strings and xsubs does not correspond") + + return all_doc_strings, string2xtag_seq, string2xsubs_seq + + def construct_xpath(self, xpath_tags, xpath_subscripts): + xpath = "" + for tagname, subs in zip(xpath_tags, xpath_subscripts): + xpath += f"/{tagname}" + if subs != 0: + xpath += f"[{subs}]" + return xpath + + def __call__(self, html_strings) -> BatchFeature: + """ + Main method to prepare for the model one or several HTML strings. + + Args: + html_strings (`str`, `list[str]`): + The HTML string or batch of HTML strings from which to extract nodes and corresponding xpaths. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **nodes** -- Nodes. + - **xpaths** -- Corresponding xpaths. + + Examples: + + ```python + >>> from transformers import MarkupLMFeatureExtractor + + >>> page_name_1 = "page1.html" + >>> page_name_2 = "page2.html" + >>> page_name_3 = "page3.html" + + >>> with open(page_name_1) as f: + ... single_html_string = f.read() + + >>> feature_extractor = MarkupLMFeatureExtractor() + + >>> # single example + >>> encoding = feature_extractor(single_html_string) + >>> print(encoding.keys()) + >>> # dict_keys(['nodes', 'xpaths']) + + >>> # batched example + + >>> multi_html_strings = [] + + >>> with open(page_name_2) as f: + ... multi_html_strings.append(f.read()) + >>> with open(page_name_3) as f: + ... 
multi_html_strings.append(f.read())
+
+        >>> encoding = feature_extractor(multi_html_strings)
+        >>> print(encoding.keys())
+        >>> # dict_keys(['nodes', 'xpaths'])
+        ```"""
+
+        # Input type checking for clearer error
+        valid_strings = False
+
+        # Check that strings has a valid type
+        if isinstance(html_strings, str):
+            valid_strings = True
+        elif isinstance(html_strings, (list, tuple)):
+            if len(html_strings) == 0 or isinstance(html_strings[0], str):
+                valid_strings = True
+
+        if not valid_strings:
+            raise ValueError(
+                "HTML strings must be of type `str`, `list[str]` (batch of examples), "
+                f"but is of type {type(html_strings)}."
+            )
+
+        # An empty list counts as an (empty) batch rather than a single example,
+        # so indexing `html_strings[0]` below cannot raise on validated input.
+        is_batched = isinstance(html_strings, (list, tuple)) and (
+            len(html_strings) == 0 or isinstance(html_strings[0], str)
+        )
+
+        if not is_batched:
+            html_strings = [html_strings]
+
+        # Get nodes + xpaths
+        nodes = []
+        xpaths = []
+        for html_string in html_strings:
+            all_doc_strings, string2xtag_seq, string2xsubs_seq = self.get_three_from_single(html_string)
+            nodes.append(all_doc_strings)
+            xpath_strings = []
+            for node, tag_list, sub_list in zip(all_doc_strings, string2xtag_seq, string2xsubs_seq):
+                xpath_string = self.construct_xpath(tag_list, sub_list)
+                xpath_strings.append(xpath_string)
+            xpaths.append(xpath_strings)
+
+        # return as Dict
+        data = {"nodes": nodes, "xpaths": xpaths}
+        encoded_inputs = BatchFeature(data=data, tensor_type=None)
+
+        return encoded_inputs
+
+
+__all__ = ["MarkupLMFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/modeling_markuplm.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/modeling_markuplm.py
new file mode 100644
index 0000000000000000000000000000000000000000..78fbf8f215aad9f6753bb6cf3001a1c7563da65d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/modeling_markuplm.py
@@ -0,0 +1,1051 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research Asia and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MarkupLM model."""
+
+import os
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, can_return_tuple, logging
+from .configuration_markuplm import MarkupLMConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class XPathEmbeddings(nn.Module):
+    """Construct the embeddings from xpath tags and subscripts.
+
+    We drop tree-id in this version, as its info can be covered by xpath.
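+
+    Each depth level of the xpath has its own tag and subscript embedding table; the per-depth embeddings are
+    concatenated, added, and projected down to `hidden_size`. A shape sketch (illustrative only, not part of the
+    original file):
+
+    ```python
+    >>> import torch
+    >>> from transformers import MarkupLMConfig
+    >>> from transformers.models.markuplm.modeling_markuplm import XPathEmbeddings
+
+    >>> config = MarkupLMConfig()  # hidden_size=768, max_depth=50 by default
+    >>> emb = XPathEmbeddings(config)
+    >>> # (batch_size=2, seq_len=3, max_depth) integer ids for xpath tags and subscripts
+    >>> tags = torch.zeros(2, 3, config.max_depth, dtype=torch.long)
+    >>> subs = torch.zeros(2, 3, config.max_depth, dtype=torch.long)
+    >>> emb(tags, subs).shape
+    torch.Size([2, 3, 768])
+    ```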
+ """ + + def __init__(self, config): + super().__init__() + self.max_depth = config.max_depth + + self.xpath_unitseq2_embeddings = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, config.hidden_size) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.activation = nn.ReLU() + self.xpath_unitseq2_inner = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, 4 * config.hidden_size) + self.inner2emb = nn.Linear(4 * config.hidden_size, config.hidden_size) + + self.xpath_tag_sub_embeddings = nn.ModuleList( + [ + nn.Embedding(config.max_xpath_tag_unit_embeddings, config.xpath_unit_hidden_size) + for _ in range(self.max_depth) + ] + ) + + self.xpath_subs_sub_embeddings = nn.ModuleList( + [ + nn.Embedding(config.max_xpath_subs_unit_embeddings, config.xpath_unit_hidden_size) + for _ in range(self.max_depth) + ] + ) + + def forward(self, xpath_tags_seq=None, xpath_subs_seq=None): + xpath_tags_embeddings = [] + xpath_subs_embeddings = [] + + for i in range(self.max_depth): + xpath_tags_embeddings.append(self.xpath_tag_sub_embeddings[i](xpath_tags_seq[:, :, i])) + xpath_subs_embeddings.append(self.xpath_subs_sub_embeddings[i](xpath_subs_seq[:, :, i])) + + xpath_tags_embeddings = torch.cat(xpath_tags_embeddings, dim=-1) + xpath_subs_embeddings = torch.cat(xpath_subs_embeddings, dim=-1) + + xpath_embeddings = xpath_tags_embeddings + xpath_subs_embeddings + + xpath_embeddings = self.inner2emb(self.dropout(self.activation(self.xpath_unitseq2_inner(xpath_embeddings)))) + + return xpath_embeddings + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +class MarkupLMEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + self.max_depth = config.max_depth + + self.xpath_embeddings = XPathEmbeddings(config) + + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_inputs_embeds + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. 
We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + def forward( + self, + input_ids=None, + xpath_tags_seq=None, + xpath_subs_seq=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # prepare xpath seq + if xpath_tags_seq is None: + xpath_tags_seq = self.config.tag_pad_id * torch.ones( + tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device + ) + if xpath_subs_seq is None: + xpath_subs_seq = self.config.subs_pad_id * torch.ones( + tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device + ) + + words_embeddings = inputs_embeds + position_embeddings = self.position_embeddings(position_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + xpath_embeddings = self.xpath_embeddings(xpath_tags_seq, xpath_subs_seq) + embeddings = words_embeddings + position_embeddings + token_type_embeddings + xpath_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->MarkupLM +class MarkupLMSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class MarkupLMIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->MarkupLM +class MarkupLMOutput(nn.Module): + def __init__(self, config): + super().__init__() + 
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class MarkupLMPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MarkupLM +class MarkupLMPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MarkupLM +class MarkupLMLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = MarkupLMPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MarkupLM +class MarkupLMOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MarkupLMLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM +class MarkupLMSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = 
eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM +class MarkupLMAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = MarkupLMSelfAttention(config) + self.output = MarkupLMSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM +class MarkupLMLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = MarkupLMAttention(config) + self.intermediate = MarkupLMIntermediate(config) + self.output = MarkupLMOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, 
attention_output) + return layer_output + + +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM +class MarkupLMEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class MarkupLMPreTrainedModel(PreTrainedModel): + config: MarkupLMConfig + base_model_prefix = "markuplm" + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->MarkupLM + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, MarkupLMLMPredictionHead): + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +@auto_docstring +class MarkupLMModel(MarkupLMPreTrainedModel): + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = MarkupLMEmbeddings(config) + self.encoder = MarkupLMEncoder(config) + + self.pooler = MarkupLMPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, 
value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        xpath_tags_seq: Optional[torch.LongTensor] = None,
+        xpath_subs_seq: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
+            Tag IDs for each token in the input sequence, padded up to config.max_depth.
+        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
+            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoProcessor, MarkupLMModel
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
+        >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")
+
+        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
+
+        >>> encoding = processor(html_string, return_tensors="pt")
+
+        >>> outputs = model(**encoding)
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 4, 768]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(input_shape, device=device)
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            xpath_tags_seq=xpath_tags_seq,
+            xpath_subs_seq=xpath_subs_seq,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            extended_attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring
+class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
+    # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with bert->markuplm, Bert->MarkupLM
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        xpath_tags_seq: Optional[torch.Tensor] = None,
+        xpath_subs_seq: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        r"""
+        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
+            Tag IDs for each token in the input sequence, padded up to config.max_depth.
+        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
+            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
+        >>> import torch
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
+        >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
+
+        >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
+        >>> question = "What's his name?"
+
+        >>> encoding = processor(html_string, questions=question, return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**encoding)
outputs = model(**encoding) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1] + >>> processor.decode(predict_answer_tokens).strip() + 'Niels' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.markuplm( + input_ids, + xpath_tags_seq=xpath_tags_seq, + xpath_subs_seq=xpath_subs_seq, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, splitting adds a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + MarkupLM Model with a `token_classification` head on top. + """ +) +class MarkupLMForTokenClassification(MarkupLMPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with bert->markuplm, Bert->MarkupLM + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.markuplm = MarkupLMModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + xpath_tags_seq: Optional[torch.Tensor] = None, + xpath_subs_seq: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): + Tag IDs for each token in the input sequence, padded up to config.max_depth. 
+ xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): + Subscript IDs for each token in the input sequence, padded up to config.max_depth. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModelForTokenClassification + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base") + >>> processor.parse_html = False + >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7) + + >>> nodes = ["hello", "world"] + >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"] + >>> node_labels = [1, 2] + >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**encoding) + + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.markuplm( + input_ids, + xpath_tags_seq=xpath_tags_seq, + xpath_subs_seq=xpath_subs_seq, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs[0] + prediction_scores = self.classifier(sequence_output) # (batch_size, seq_length, node_type_size) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct( + prediction_scores.view(-1, self.config.num_labels), + labels.view(-1), + ) + + return TokenClassifierOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """ +) +class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with bert->markuplm, Bert->MarkupLM + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.markuplm = MarkupLMModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + xpath_tags_seq: Optional[torch.Tensor] = None, + xpath_subs_seq: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): + Tag IDs for each token in the input sequence, padded up to config.max_depth. + xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): + Subscript IDs for each token in the input sequence, padded up to config.max_depth. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModelForSequenceClassification + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base") + >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7) + + >>> html_string = " Page Title " + >>> encoding = processor(html_string, return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**encoding) + + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.markuplm( + input_ids, + xpath_tags_seq=xpath_tags_seq, + xpath_subs_seq=xpath_subs_seq, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MarkupLMForQuestionAnswering", + "MarkupLMForSequenceClassification", + "MarkupLMForTokenClassification", + "MarkupLMModel", + "MarkupLMPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/processing_markuplm.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/processing_markuplm.py new file mode 100644 index 0000000000000000000000000000000000000000..c6208b07bbea4a48fee3577a509315bda5db5db9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/processing_markuplm.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for MarkupLM. +""" + +from typing import Optional, Union + +from ...file_utils import TensorType +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy + + +class MarkupLMProcessor(ProcessorMixin): + r""" + Constructs a MarkupLM processor which combines a MarkupLM feature extractor and a MarkupLM tokenizer into a single + processor. + + [`MarkupLMProcessor`] offers all the functionalities you need to prepare data for the model. + + It first uses [`MarkupLMFeatureExtractor`] to extract nodes and corresponding xpaths from one or more HTML strings. 
+ Next, these are provided to [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`], which turns them into token-level + `input_ids`, `attention_mask`, `token_type_ids`, `xpath_tags_seq` and `xpath_subs_seq`. + + Args: + feature_extractor (`MarkupLMFeatureExtractor`): + An instance of [`MarkupLMFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`MarkupLMTokenizer` or `MarkupLMTokenizerFast`): + An instance of [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`]. The tokenizer is a required input. + parse_html (`bool`, *optional*, defaults to `True`): + Whether or not to use `MarkupLMFeatureExtractor` to parse HTML strings into nodes and corresponding xpaths. + """ + + feature_extractor_class = "MarkupLMFeatureExtractor" + tokenizer_class = ("MarkupLMTokenizer", "MarkupLMTokenizerFast") + parse_html = True + + def __call__( + self, + html_strings=None, + nodes=None, + xpaths=None, + node_labels=None, + questions=None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchEncoding: + """ + This method first forwards the `html_strings` argument to [`~MarkupLMFeatureExtractor.__call__`]. Next, it + passes the `nodes` and `xpaths` along with the additional arguments to [`~MarkupLMTokenizer.__call__`] and + returns the output. + + Optionally, one can also provide a `questions` argument, which is then passed along as the first sequence (with + the nodes as the second). + + Please refer to the docstring of the above two methods for more information. 
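+ + Example (a minimal sketch; the exact keys returned depend on the tokenizer defaults, so the printed list is + illustrative): + + ```python + >>> from transformers import MarkupLMProcessor + + >>> processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base") + >>> encoding = processor("<html> <head> <title>Page Title</title> </head> </html>", return_tensors="pt") + >>> sorted(encoding.keys()) + ['attention_mask', 'input_ids', 'token_type_ids', 'xpath_subs_seq', 'xpath_tags_seq'] + ```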
+ """ + # first, create nodes and xpaths + if self.parse_html: + if html_strings is None: + raise ValueError("Make sure to pass HTML strings in case `parse_html` is set to `True`") + + if nodes is not None or xpaths is not None or node_labels is not None: + raise ValueError( + "Please don't pass nodes, xpaths nor node labels in case `parse_html` is set to `True`" + ) + + features = self.feature_extractor(html_strings) + nodes = features["nodes"] + xpaths = features["xpaths"] + else: + if html_strings is not None: + raise ValueError("You have passed HTML strings but `parse_html` is set to `False`.") + if nodes is None or xpaths is None: + raise ValueError("Make sure to pass nodes and xpaths in case `parse_html` is set to `False`") + + # # second, apply the tokenizer + if questions is not None and self.parse_html: + if isinstance(questions, str): + questions = [questions] # add batch dimension (as the feature extractor always adds a batch dimension) + + encoded_inputs = self.tokenizer( + text=questions if questions is not None else nodes, + text_pair=nodes if questions is not None else None, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs + + +__all__ = ["MarkupLMProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm.py new file mode 100644 index 0000000000000000000000000000000000000000..a090e11ec36dff560134c91ed8053423fe35b7bb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm.py @@ -0,0 +1,1470 @@ +# coding=utf-8 +# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization class for MarkupLM.""" + +import json +import os +from functools import lru_cache +from typing import Optional, Union + +import regex as re + +from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to + `None`, this will use the predefined model maximum length if a maximum length is required by one of the + truncation/padding parameters. If the model has no specific maximum input length (like XLNet) + truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. 
+ pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + + +@lru_cache +def bytes_to_unicode(): + """ + Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control + characters that the bpe code barfs on. The reversible bpe codes work on unicode strings. This means you need a large # + of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset + you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe + vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class MarkupLMTokenizer(PreTrainedTokenizer): + r""" + Construct a MarkupLM tokenizer. Based on byte-level Byte-Pair-Encoding (BPE). [`MarkupLMTokenizer`] can be used to + turn HTML strings into token-level `input_ids`, `attention_mask`, `token_type_ids`, `xpath_tags_seq` and + `xpath_subs_seq`. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `"<s>"`): + The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + </Tip> + + eos_token (`str`, *optional*, defaults to `"</s>"`): + The end of sequence token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + </Tip> + + sep_token (`str`, *optional*, defaults to `"</s>"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. 
+ cls_token (`str`, *optional*, defaults to `"<s>"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `"<mask>"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows treating the leading word just as any + other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.) + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + merges_file, + tags_dict, + errors="replace", + bos_token="<s>", + eos_token="</s>", + sep_token="</s>", + cls_token="<s>", + unk_token="<unk>", + pad_token="<pad>", + mask_token="<mask>", + add_prefix_space=False, + max_depth=50, + max_width=1000, + pad_width=1001, + pad_token_label=-100, + only_label_first_subword=True, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behaves like a normal word, i.e. it will
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + + self.tags_dict = tags_dict + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + # additional properties + self.max_depth = max_depth + self.max_width = max_width + self.pad_width = pad_width + self.unk_tag_id = len(self.tags_dict) + self.pad_tag_id = self.unk_tag_id + 1 + self.pad_xpath_tags_seq = [self.pad_tag_id] * self.max_depth + self.pad_xpath_subs_seq = [self.pad_width] * self.max_depth + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tags_dict=tags_dict, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + max_depth=max_depth, + max_width=max_width, + pad_width=pad_width, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + **kwargs, + ) + + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + def get_xpath_seq(self, xpath): + """ + Given the xpath expression of one particular node (like "/html/body/div/li[1]/div/span[2]"), return a list of + tag IDs and corresponding subscripts, taking into account max depth. 
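+ + For illustration, with a toy `tags_dict = {"html": 0, "body": 1, "div": 2, "li": 3}` and `max_depth=6` (made-up + values, not the real config), the xpath "/html/body/div/li[1]" would yield tag IDs `[0, 1, 2, 3]` and subscripts + `[0, 0, 0, 1]`, each padded on the right with `pad_tag_id` / `pad_width` up to `max_depth`.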
+ """ + xpath_tags_list = [] + xpath_subs_list = [] + + xpath_units = xpath.split("/") + for unit in xpath_units: + if not unit.strip(): + continue + name_subs = unit.strip().split("[") + tag_name = name_subs[0] + sub = 0 if len(name_subs) == 1 else int(name_subs[1][:-1]) + xpath_tags_list.append(self.tags_dict.get(tag_name, self.unk_tag_id)) + xpath_subs_list.append(min(self.max_width, sub)) + + xpath_tags_list = xpath_tags_list[: self.max_depth] + xpath_subs_list = xpath_subs_list[: self.max_depth] + xpath_tags_list += [self.pad_tag_id] * (self.max_depth - len(xpath_tags_list)) + xpath_subs_list += [self.pad_width] * (self.max_depth - len(xpath_subs_list)) + + return xpath_tags_list, xpath_subs_list + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + logger.warning( + "MarkupLM now does not support generative tasks, decoding is experimental and subject to change." 
+ ) + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str, str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + # save vocab_file + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + # save merge_file + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and + adding special tokens. A RoBERTa sequence has the following format: + - single sequence: `<s> X </s>` + - pair of sequences: `<s> A </s></s> B </s>` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def build_xpath_tags_with_special_tokens( + self, xpath_tags_0: list[int], xpath_tags_1: Optional[list[int]] = None + ) -> list[int]: + pad = [self.pad_xpath_tags_seq] + if xpath_tags_1 is None or len(xpath_tags_1) == 0: + return pad + xpath_tags_0 + pad + return pad + xpath_tags_0 + pad + xpath_tags_1 + pad + + def build_xpath_subs_with_special_tokens( + self, xpath_subs_0: list[int], xpath_subs_1: Optional[list[int]] = None + ) -> list[int]: + pad = [self.pad_xpath_subs_seq] + if xpath_subs_1 is None or len(xpath_subs_1) == 0: + return pad + xpath_subs_0 + pad + return pad + xpath_subs_0 + pad + xpath_subs_1 + pad + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. 
+ token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0] + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None, + xpaths: Optional[Union[list[list[int]], list[list[list[int]]]]] = None, + node_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with node-level xpaths and optional labels. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (nodes of a single example or questions of a batch of examples) or a list of list of strings (batch of + nodes). + text_pair (`list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + xpaths (`list[list[int]]`, `list[list[list[int]]]`): + Node-level xpaths. + node_labels (`list[int]`, `list[list[int]]`, *optional*): + Node-level integer labels (for token classification tasks). + """ + + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... 
empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = nodes + if not _is_valid_text_input(text): + raise ValueError("text input must be of type `str` (single example) or `list[str]` (batch of examples).") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "Nodes must be of type `list[str]` (single pretokenized example), " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be nodes + if not isinstance(text, (list, tuple)): + raise ValueError( + "Nodes must be of type `list[str]` (single pretokenized example), " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + nodes = text if text_pair is None else text_pair + assert xpaths is not None, "You must provide corresponding xpaths" + if is_batched: + assert len(nodes) == len(xpaths), "You must provide nodes and xpaths for an equal amount of examples" + for nodes_example, xpaths_example in zip(nodes, xpaths): + assert len(nodes_example) == len(xpaths_example), "You must provide as many nodes as there are xpaths" + else: + assert len(nodes) == len(xpaths), "You must provide as many nodes as there are xpaths" + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}."
+ ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + xpaths: Optional[list[list[list[int]]]] = None, + node_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + 
return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + xpaths: Optional[list[list[list[int]]]] = None, + node_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + batch_outputs = self._batch_prepare_for_model( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_text_or_text_pairs, + is_pair: Optional[bool] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens. 
+ + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for idx, example in enumerate(zip(batch_text_or_text_pairs, xpaths)): + batch_text_or_text_pair, xpaths_example = example + outputs = self.prepare_for_model( + batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair, + batch_text_or_text_pair[1] if is_pair else None, + xpaths_example, + node_labels=node_labels[idx] if node_labels is not None else None, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING) + def encode( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> list[int]: + encoded_inputs = self.encode_plus( + text=text, + text_pair=text_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[int]] = None, + add_special_tokens: bool = 
True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`list[str]` or `list[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (nodes of a single example) or a + list of list of strings (nodes of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + xpaths=xpaths, + text_pair=text_pair, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. 
" + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + return self.prepare_for_model( + text=text, + text_pair=text_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens, + truncates sequences if overflowing while taking into account the special tokens and manages a moving window + (with user defined stride) for overflowing tokens. Please Note, for *text_pair* different than `None` and + *truncation_strategy = longest_first* or `True`, it is not possible to return overflowing tokens. Such a + combination of arguments will raise an error. + + Node-level `xpaths` are turned into token-level `xpath_tags_seq` and `xpath_subs_seq`. If provided, node-level + `node_labels` are turned into token-level `labels`. The node label is used for the first token of the node, + while remaining tokens are labeled with -100, such that they will be ignored by the loss function. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`list[str]` or `list[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (nodes of a single example) or a + list of list of strings (nodes of a batch of examples). 
+ """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + tokens = [] + pair_tokens = [] + xpath_tags_seq = [] + xpath_subs_seq = [] + pair_xpath_tags_seq = [] + pair_xpath_subs_seq = [] + labels = [] + + if text_pair is None: + if node_labels is None: + # CASE 1: web page classification (training + inference) + CASE 2: token classification (inference) + for word, xpath in zip(text, xpaths): + if len(word) < 1: # skip empty nodes + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + xpath_tags_list, xpath_subs_list = self.get_xpath_seq(xpath) + xpath_tags_seq.extend([xpath_tags_list] * len(word_tokens)) + xpath_subs_seq.extend([xpath_subs_list] * len(word_tokens)) + else: + # CASE 2: token classification (training) + for word, xpath, label in zip(text, xpaths, node_labels): + if len(word) < 1: # skip empty nodes + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + xpath_tags_list, xpath_subs_list = self.get_xpath_seq(xpath) + xpath_tags_seq.extend([xpath_tags_list] * len(word_tokens)) + xpath_subs_seq.extend([xpath_subs_list] * len(word_tokens)) + if self.only_label_first_subword: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1)) + else: + labels.extend([label] * len(word_tokens)) + else: + # CASE 3: web page question answering (inference) + # text = question + # text_pair = nodes + tokens = self.tokenize(text) + xpath_tags_seq = [self.pad_xpath_tags_seq for _ in range(len(tokens))] + xpath_subs_seq = [self.pad_xpath_subs_seq for _ in range(len(tokens))] + + for word, xpath in zip(text_pair, xpaths): + if len(word) < 1: # skip empty nodes + continue + word_tokens = self.tokenize(word) + pair_tokens.extend(word_tokens) + xpath_tags_list, xpath_subs_list = self.get_xpath_seq(xpath) + pair_xpath_tags_seq.extend([xpath_tags_list] * len(word_tokens)) + pair_xpath_subs_seq.extend([xpath_subs_list] * len(word_tokens)) + + # Create ids + pair_ids + ids = self.convert_tokens_to_ids(tokens) + pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." 
+ ) + + # Compute the total size of the returned encodings + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + overflowing_xpath_tags_seq = [] + overflowing_xpath_subs_seq = [] + overflowing_labels = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ( + ids, + xpath_tags_seq, + xpath_subs_seq, + pair_ids, + pair_xpath_tags_seq, + pair_xpath_subs_seq, + labels, + overflowing_tokens, + overflowing_xpath_tags_seq, + overflowing_xpath_subs_seq, + overflowing_labels, + ) = self.truncate_sequences( + ids, + xpath_tags_seq=xpath_tags_seq, + xpath_subs_seq=xpath_subs_seq, + pair_ids=pair_ids, + pair_xpath_tags_seq=pair_xpath_tags_seq, + pair_xpath_subs_seq=pair_xpath_subs_seq, + labels=labels, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["overflowing_xpath_tags_seq"] = overflowing_xpath_tags_seq + encoded_inputs["overflowing_xpath_subs_seq"] = overflowing_xpath_subs_seq + encoded_inputs["overflowing_labels"] = overflowing_labels + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + xpath_tags_ids = self.build_xpath_tags_with_special_tokens(xpath_tags_seq, pair_xpath_tags_seq) + xpath_subs_ids = self.build_xpath_subs_with_special_tokens(xpath_subs_seq, pair_xpath_subs_seq) + if labels: + labels = [self.pad_token_label] + labels + [self.pad_token_label] + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + xpath_tags_ids = xpath_tags_seq + pair_xpath_tags_seq if pair else xpath_tags_seq + xpath_subs_ids = xpath_subs_seq + pair_xpath_subs_seq if pair else xpath_subs_seq + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["xpath_tags_seq"] = xpath_tags_ids + encoded_inputs["xpath_subs_seq"] = xpath_subs_ids + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + if labels: + encoded_inputs["labels"] = labels + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + 
encoded_inputs,
+                max_length=max_length,
+                padding=padding_strategy.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def truncate_sequences(
+        self,
+        ids: list[int],
+        xpath_tags_seq: list[list[int]],
+        xpath_subs_seq: list[list[int]],
+        pair_ids: Optional[list[int]] = None,
+        pair_xpath_tags_seq: Optional[list[list[int]]] = None,
+        pair_xpath_subs_seq: Optional[list[list[int]]] = None,
+        labels: Optional[list[int]] = None,
+        num_tokens_to_remove: int = 0,
+        truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+        stride: int = 0,
+    ) -> tuple:
+        """
+        Truncates a sequence pair in place following the strategy.
+
+        Args:
+            ids (`list[int]`):
+                Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize`
+                and `convert_tokens_to_ids` methods.
+            xpath_tags_seq (`list[list[int]]`):
+                XPath tag IDs of the first sequence.
+            xpath_subs_seq (`list[list[int]]`):
+                XPath subscript IDs of the first sequence.
+            pair_ids (`list[int]`, *optional*):
+                Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+                and `convert_tokens_to_ids` methods.
+            pair_xpath_tags_seq (`list[list[int]]`, *optional*):
+                XPath tag IDs of the second sequence.
+            pair_xpath_subs_seq (`list[list[int]]`, *optional*):
+                XPath subscript IDs of the second sequence.
+            labels (`list[int]`, *optional*):
+                Token-level labels of the first sequence, truncated together with `ids`.
+            num_tokens_to_remove (`int`, *optional*, defaults to 0):
+                Number of tokens to remove using the truncation strategy.
+            truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+                The strategy to follow for truncation. Can be:
+
+                - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will truncate
+                  token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+                  batch of pairs) is provided.
+                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'do_not_truncate'`: No truncation (i.e., can output a batch with sequence lengths greater than the
+                  model's maximum admissible input size).
+            stride (`int`, *optional*, defaults to 0):
+                If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+                sequence returned. The value of this argument defines the number of additional tokens.
+
+        Returns:
+            `tuple`: The truncated `ids`, xpath sequences, `pair_ids`, `labels`, followed by the lists of overflowing
+            tokens, overflowing xpath sequences and overflowing labels. Note: the *longest_first* strategy returns an
+            empty list of overflowing tokens if a pair of sequences (or a batch of pairs) is provided.
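+
+        Example (illustrative sketch; the token ids and xpath lists below are made up, and `labels` is passed as an
+        empty list because this method expects a list there):
+
+        ```python
+        >>> ids = [101, 102, 103, 104, 105]
+        >>> tags = [[0] * 50] * 5  # one (padded) xpath tag list per token
+        >>> subs = [[1001] * 50] * 5
+        >>> out = tokenizer.truncate_sequences(
+        ...     ids, xpath_tags_seq=tags, xpath_subs_seq=subs, labels=[], num_tokens_to_remove=2, stride=1
+        ... )
+        >>> out[0]  # truncated ids
+        [101, 102, 103]
+        >>> out[7]  # overflowing ids: the 2 removed tokens plus 1 stride token
+        [103, 104, 105]
+        ```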
+ """ + if num_tokens_to_remove <= 0: + return ids, xpath_tags_seq, xpath_subs_seq, pair_ids, pair_xpath_tags_seq, pair_xpath_subs_seq, [], [], [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + overflowing_xpath_tags_seq = [] + overflowing_xpath_subs_seq = [] + overflowing_labels = [] + if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( + truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None + ): + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + overflowing_xpath_tags_seq = xpath_tags_seq[-window_len:] + overflowing_xpath_subs_seq = xpath_subs_seq[-window_len:] + ids = ids[:-num_tokens_to_remove] + xpath_tags_seq = xpath_tags_seq[:-num_tokens_to_remove] + xpath_subs_seq = xpath_subs_seq[:-num_tokens_to_remove] + labels = labels[:-num_tokens_to_remove] + else: + error_msg = ( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + ) + if truncation_strategy == TruncationStrategy.ONLY_FIRST: + error_msg = ( + error_msg + "Please select another truncation strategy than " + f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." + ) + logger.error(error_msg) + elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: + logger.warning( + "Be aware, overflowing tokens are not returned for the setting you have chosen," + f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " + "truncation strategy. So the returned list will always be empty even if some " + "tokens have been removed." + ) + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + ids = ids[:-1] + xpath_tags_seq = xpath_tags_seq[:-1] + xpath_subs_seq = xpath_subs_seq[:-1] + labels = labels[:-1] + else: + pair_ids = pair_ids[:-1] + pair_xpath_tags_seq = pair_xpath_tags_seq[:-1] + pair_xpath_subs_seq = pair_xpath_subs_seq[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + overflowing_xpath_tags_seq = pair_xpath_tags_seq[-window_len:] + overflowing_xpath_subs_seq = pair_xpath_subs_seq[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + pair_xpath_tags_seq = pair_xpath_tags_seq[:-num_tokens_to_remove] + pair_xpath_subs_seq = pair_xpath_subs_seq[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." 
+                )
+
+        return (
+            ids,
+            xpath_tags_seq,
+            xpath_subs_seq,
+            pair_ids,
+            pair_xpath_tags_seq,
+            pair_xpath_subs_seq,
+            labels,
+            overflowing_tokens,
+            overflowing_xpath_tags_seq,
+            overflowing_xpath_subs_seq,
+            overflowing_labels,
+        )
+
+    def _pad(
+        self,
+        encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """
+        Pad encoded inputs (on left/right and up to a predefined length or the max length in the batch).
+
+        Args:
+            encoded_inputs:
+                Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            padding_strategy: PaddingStrategy to use for padding.
+
+                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+                - PaddingStrategy.MAX_LENGTH: Pad to the max length
+                - PaddingStrategy.DO_NOT_PAD: Do not pad
+
+                The tokenizer padding side is defined in `self.padding_side`:
+
+                - 'left': pads on the left of the sequences
+                - 'right': pads on the right of the sequences
+            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side:
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+        """
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(required_input)
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+        # Initialize attention mask if not present.
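+        # Illustrative sketch of the padding below (made-up numbers): with required_input of
+        # length 3, max_length=5 and pad_to_multiple_of=None, `difference` is 2. Right padding
+        # then appends two pad_token_id entries to input_ids, two 0s to attention_mask, and two
+        # pad_xpath_*_seq entries to the xpath sequences, so every returned list keeps the same
+        # length.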
+        if return_attention_mask and "attention_mask" not in encoded_inputs:
+            encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+        if needs_to_be_padded:
+            difference = max_length - len(required_input)
+            padding_side = padding_side if padding_side is not None else self.padding_side
+            if padding_side == "right":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = (
+                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+                    )
+                if "xpath_tags_seq" in encoded_inputs:
+                    encoded_inputs["xpath_tags_seq"] = (
+                        encoded_inputs["xpath_tags_seq"] + [self.pad_xpath_tags_seq] * difference
+                    )
+                if "xpath_subs_seq" in encoded_inputs:
+                    encoded_inputs["xpath_subs_seq"] = (
+                        encoded_inputs["xpath_subs_seq"] + [self.pad_xpath_subs_seq] * difference
+                    )
+                if "labels" in encoded_inputs:
+                    encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+            elif padding_side == "left":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+                        "token_type_ids"
+                    ]
+                if "xpath_tags_seq" in encoded_inputs:
+                    encoded_inputs["xpath_tags_seq"] = [self.pad_xpath_tags_seq] * difference + encoded_inputs[
+                        "xpath_tags_seq"
+                    ]
+                if "xpath_subs_seq" in encoded_inputs:
+                    encoded_inputs["xpath_subs_seq"] = [self.pad_xpath_subs_seq] * difference + encoded_inputs[
+                        "xpath_subs_seq"
+                    ]
+                if "labels" in encoded_inputs:
+                    encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+            else:
+                raise ValueError("Invalid padding side: " + str(padding_side))
+
+        return encoded_inputs
+
+
+__all__ = ["MarkupLMTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm_fast.py b/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4033ef319ff8cbec11f6c29d075e59b741710513
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/markuplm/tokenization_markuplm_fast.py
@@ -0,0 +1,929 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fast tokenization class for MarkupLM. It overrides two methods of the slow tokenizer class, namely
+_batch_encode_plus and _encode_plus, in which the Rust tokenizer is used.
+"""
+
+import json
+from functools import lru_cache
+from typing import Optional, Union
+
+from tokenizers import processors
+
+from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
+from ...tokenization_utils_base import (
+    ENCODE_KWARGS_DOCSTRING,
+    AddedToken,
+    BatchEncoding,
+    EncodedInput,
+    PreTokenizedInput,
+    TextInput,
+    TextInputPair,
+    TruncationStrategy,
+)
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_markuplm import MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, MarkupLMTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+@lru_cache
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on. The reversible bpe codes work on unicode strings. This means you need a
+    large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token
+    dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say,
+    32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return the set of symbol pairs in a word. Word is represented as a tuple of symbols (symbols being
+    variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a MarkupLM tokenizer. Based on byte-level Byte-Pair-Encoding (BPE).
+
+    [`MarkupLMTokenizerFast`] can be used to turn HTML strings into token-level `input_ids`, `attention_mask`,
+    `token_type_ids`, `xpath_tags_seq` and `xpath_subs_seq`. This tokenizer inherits from
+    [`PreTrainedTokenizerFast`] which contains most of the main methods.
+
+    Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just as any
+            other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = MarkupLMTokenizer
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tags_dict,
+        tokenizer_file=None,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        max_depth=50,
+        max_width=1000,
+        pad_width=1001,
+        pad_token_label=-100,
+        only_label_first_subword=True,
+        trim_offsets=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # The mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        super().__init__(
+            vocab_file=vocab_file,
+            merges_file=merges_file,
+            tags_dict=tags_dict,
+            tokenizer_file=tokenizer_file,
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            trim_offsets=trim_offsets,
+            max_depth=max_depth,
+            max_width=max_width,
+            pad_width=pad_width,
+            pad_token_label=pad_token_label,
+            only_label_first_subword=only_label_first_subword,
+            **kwargs,
+        )
+        if trim_offsets:
+            # Not implemented yet, because we need to chain two post processors which is not possible yet
+            # We need to wait for https://github.com/huggingface/tokenizers/pull/1005
+            # With `trim_offsets=False` we don't need to add `processors.ByteLevel(trim_offsets=False)`
+            # because it's not doing anything
+            raise NotImplementedError(
+                "`trim_offsets=True` is not implemented for MarkupLMTokenizerFast. Please set it to False."
+            )
+
+        self.tags_dict = tags_dict
+
+        tokenizer_component = "post_processor"
+        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+        if tokenizer_component_instance:
+            state = json.loads(tokenizer_component_instance.__getstate__())
+
+            # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+            if "sep" in state:
+                state["sep"] = tuple(state["sep"])
+            if "cls" in state:
+                state["cls"] = tuple(state["cls"])
+
+            changes_to_apply = False
+
+            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+                state["add_prefix_space"] = add_prefix_space
+                changes_to_apply = True
+
+            if changes_to_apply:
+                component_class = getattr(processors, state.pop("type"))
+                new_value = component_class(**state)
+                setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+        # additional properties
+        self.max_depth = max_depth
+        self.max_width = max_width
+        self.pad_width = pad_width
+        self.unk_tag_id = len(self.tags_dict)
+        self.pad_tag_id = self.unk_tag_id + 1
+        self.pad_xpath_tags_seq = [self.pad_tag_id] * self.max_depth
+        self.pad_xpath_subs_seq = [self.pad_width] * self.max_depth
+        self.pad_token_label = pad_token_label
+        self.only_label_first_subword = only_label_first_subword
+
+    def get_xpath_seq(self, xpath):
+        """
+        Given the xpath expression of one particular node (like "/html/body/div/li[1]/div/span[2]"), return a list of
+        tag IDs and corresponding subscripts, taking into account max depth.
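+
+        Example (illustrative sketch; the tag IDs depend on the `tags_dict` the tokenizer was loaded with, so only
+        the subscripts and lengths are shown):
+
+        ```python
+        >>> tags, subs = tokenizer.get_xpath_seq("/html/body/div/li[1]/div/span[2]")
+        >>> len(tags) == len(subs) == tokenizer.max_depth  # both lists are padded/truncated to max_depth
+        True
+        >>> subs[:6]  # subscripts for html, body, div, li[1], div, span[2]; missing subscripts default to 0
+        [0, 0, 0, 1, 0, 2]
+        ```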
+ """ + xpath_tags_list = [] + xpath_subs_list = [] + + xpath_units = xpath.split("/") + for unit in xpath_units: + if not unit.strip(): + continue + name_subs = unit.strip().split("[") + tag_name = name_subs[0] + sub = 0 if len(name_subs) == 1 else int(name_subs[1][:-1]) + xpath_tags_list.append(self.tags_dict.get(tag_name, self.unk_tag_id)) + xpath_subs_list.append(min(self.max_width, sub)) + + xpath_tags_list = xpath_tags_list[: self.max_depth] + xpath_subs_list = xpath_subs_list[: self.max_depth] + xpath_tags_list += [self.pad_tag_id] * (self.max_depth - len(xpath_tags_list)) + xpath_subs_list += [self.pad_width] * (self.max_depth - len(xpath_subs_list)) + + return xpath_tags_list, xpath_subs_list + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None, + xpaths: Optional[Union[list[list[int]], list[list[list[int]]]]] = None, + node_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with nodes, xpaths and optional labels. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). + text_pair (`list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + xpaths (`list[list[int]]`, `list[list[list[int]]]`): + Node-level xpaths. Each bounding box should be normalized to be on a 0-1000 scale. + node_labels (`list[int]`, `list[list[int]]`, *optional*): + Node-level integer labels (for token classification tasks). + """ + + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = nodes + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `list[str]` (batch of examples). 
") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "Nodes must be of type `list[str]` (single pretokenized example), " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be nodes + if not isinstance(text, (list, tuple)): + raise ValueError( + "Nodes must be of type `list[str]` (single pretokenized example), " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + nodes = text if text_pair is None else text_pair + assert xpaths is not None, "You must provide corresponding xpaths" + if is_batched: + assert len(nodes) == len(xpaths), "You must provide nodes and xpaths for an equal amount of examples" + for nodes_example, xpaths_example in zip(nodes, xpaths): + assert len(nodes_example) == len(xpaths_example), "You must provide as many nodes as there are xpaths" + else: + assert len(nodes) == len(xpaths), "You must provide as many nodes as there are xpaths" + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + xpaths: Optional[list[list[list[int]]]] = None, + node_labels: Optional[Union[list[int], list[list[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: 
Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + xpaths=xpaths, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: + batched_input = [(text, pair)] if pair else [text] + encodings = self._tokenizer.encode_batch( + batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs + ) + + return encodings[0].tokens + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, MARKUPLM_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + xpaths: Optional[list[list[int]]] = None, + node_labels: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`list[str]` or `list[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). 
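+
+        Example (illustrative sketch of the preferred `__call__` API, assuming `tokenizer` is an instantiated
+        `MarkupLMTokenizerFast` and using made-up nodes and xpaths):
+
+        ```python
+        >>> nodes = ["hello", "world"]
+        >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[2]/div/span"]
+        >>> encoding = tokenizer(nodes, xpaths=xpaths, return_tensors="pt")
+        >>> encoding["input_ids"].shape[0]  # a leading batch axis is prepended
+        1
+        ```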
+ """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + xpaths=xpaths, + text_pair=text_pair, + node_labels=node_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + xpaths: Optional[list[list[list[int]]]] = None, + node_labels: Optional[list[list[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})") + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + ) + + if is_pair: + batch_text_or_text_pairs = [([text], text_pair) for text, text_pair in batch_text_or_text_pairs] + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=True, # we set this to True as MarkupLM always expects pretokenized inputs + ) + + # Convert encoding to dict + # `Tokens` is a tuple of (list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]], + # list[EncodingFast]) with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=True + if node_labels is not None + else return_offsets_mapping, # we use offsets to create the labels + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the 
additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0]: + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + + # create the token-level xpaths tags and subscripts + xpath_tags_seq = [] + xpath_subs_seq = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + xpath_tags_seq_example = [] + xpath_subs_seq_example = [] + for id, sequence_id, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_encodings[batch_index].sequence_ids, + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if is_pair and sequence_id == 0: + xpath_tags_seq_example.append(self.pad_xpath_tags_seq) + xpath_subs_seq_example.append(self.pad_xpath_subs_seq) + else: + xpath_tags_list, xpath_subs_list = self.get_xpath_seq(xpaths[original_index][word_id]) + xpath_tags_seq_example.extend([xpath_tags_list]) + xpath_subs_seq_example.extend([xpath_subs_list]) + else: + if id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]: + xpath_tags_seq_example.append(self.pad_xpath_tags_seq) + xpath_subs_seq_example.append(self.pad_xpath_subs_seq) + else: + raise ValueError("Id not recognized") + xpath_tags_seq.append(xpath_tags_seq_example) + xpath_subs_seq.append(xpath_subs_seq_example) + + sanitized_tokens["xpath_tags_seq"] = xpath_tags_seq + sanitized_tokens["xpath_subs_seq"] = xpath_subs_seq + + # optionally, create the labels + if node_labels is not None: + labels = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + labels_example = [] + for id, offset, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_tokens["offset_mapping"][batch_index], + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if self.only_label_first_subword: + if offset[0] == 0: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels_example.append(node_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + else: + labels_example.append(node_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + labels.append(labels_example) + + sanitized_tokens["labels"] = labels + # finally, remove offsets if the user didn't want them + if not 
return_offsets_mapping:
+            del sanitized_tokens["offset_mapping"]
+
+        return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+    def _encode_plus(
+        self,
+        text: Union[TextInput, PreTokenizedInput],
+        text_pair: Optional[PreTokenizedInput] = None,
+        xpaths: Optional[list[list[int]]] = None,
+        node_labels: Optional[list[int]] = None,
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[bool] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        # make it a batched input
+        # 2 options:
+        # 1) only text, in which case text must be a list of str
+        # 2) text + text_pair, in which case text = str and text_pair a list of str
+        batched_input = [(text, text_pair)] if text_pair else [text]
+        batched_xpaths = [xpaths]
+        batched_node_labels = [node_labels] if node_labels is not None else None
+        batched_output = self._batch_encode_plus(
+            batched_input,
+            is_pair=bool(text_pair is not None),
+            xpaths=batched_xpaths,
+            node_labels=batched_node_labels,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        # If return_tensors is None, we can remove the leading batch axis
+        # Overflowing tokens are returned as a batch of output so we keep them in that case
+        if return_tensors is None and not return_overflowing_tokens:
+            batched_output = BatchEncoding(
+                {
+                    key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
+                    for key, value in batched_output.items()
+                },
+                batched_output.encodings,
+            )
+
+        self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+        return batched_output
+
+    def _pad(
+        self,
+        encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """
+        Pad encoded inputs (on left/right and up to a predefined length or the max length in the batch).
+
+        Args:
+            encoded_inputs:
+                Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            padding_strategy: PaddingStrategy to use for padding.
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "xpath_tags_seq" in encoded_inputs: + encoded_inputs["xpath_tags_seq"] = ( + encoded_inputs["xpath_tags_seq"] + [self.pad_xpath_tags_seq] * difference + ) + if "xpath_subs_seq" in encoded_inputs: + encoded_inputs["xpath_subs_seq"] = ( + encoded_inputs["xpath_subs_seq"] + [self.pad_xpath_subs_seq] * difference + ) + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "xpath_tags_seq" in encoded_inputs: + encoded_inputs["xpath_tags_seq"] = [self.pad_xpath_tags_seq] * difference + encoded_inputs[ + "xpath_tags_seq" + ] + if "xpath_subs_seq" in encoded_inputs: + encoded_inputs["xpath_subs_seq"] = [self.pad_xpath_subs_seq] * difference + encoded_inputs[ + "xpath_subs_seq" + ] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if 
"special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + return encoded_inputs + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A RoBERTa sequence has the following format: + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["MarkupLMTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mbart50/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mbart50/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cd8c28da631fdd44cc09458380527b6323d044 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mbart50/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_mbart50 import *
+    from .tokenization_mbart50_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50.py b/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50.py
new file mode 100644
index 0000000000000000000000000000000000000000..413beaa03a83e4eddb9eb5a050d85c347ecbe2b5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50.py
@@ -0,0 +1,359 @@
+# coding=utf-8
+# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]  # fmt: skip
+
+
+@requires(backends=("sentencepiece",))
+class MBart50Tokenizer(PreTrainedTokenizer):
+    """
+    Construct an MBart50 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        src_lang (`str`, *optional*):
+            A string representing the source language.
+        tgt_lang (`str`, *optional*):
+            A string representing the target language.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
+            things, to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Examples:
+
+    ```python
+    >>> from transformers import MBart50Tokenizer
+
+    >>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
+    >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+    >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> # model(**model_inputs) should work
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: list[int] = []
+    suffix_tokens: list[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        src_lang=None,
+        tgt_lang=None,
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        # The mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) or []
+        kwargs["additional_special_tokens"] += [
+            code for code in FAIRSEQ_LANGUAGE_CODES if code not in kwargs["additional_special_tokens"]
+        ]
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+
+        super().__init__(
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+        self._src_lang = src_lang if src_lang is not None else "en_XX"
+        self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def __getstate__(self) -> dict:
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d: dict) -> None:
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def get_vocab(self) -> dict:
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> list[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Converts a token (str) to an id using the vocab."""
+        if token in self.fairseq_tokens_to_ids:
+            return self.fairseq_tokens_to_ids[token]
+        spm_id = self.sp_model.PieceToId(token)
+
+        # Need to return unknown token if the SP model returned 0
+        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) to a token (str) using the vocab."""
+        if index in self.fairseq_ids_to_tokens:
+            return self.fairseq_ids_to_tokens[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) to a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An MBART-50 sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `[src_lang_code] X [eos]` + - `labels`: (for decoder) `[tgt_lang_code] X [eos]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "en_XX", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "ro_RO", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang: str) -> None: + """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.lang_code_to_id[src_lang] + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.lang_code_to_id[tgt_lang] + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + +__all__ = ["MBart50Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..985b0929f87c5516a2edd8e02c6aaaa8926d496e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py @@ -0,0 +1,258 @@ +# coding=utf-8 +# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import processors
+
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+    from .tokenization_mbart50 import MBart50Tokenizer
+else:
+    MBart50Tokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]  # fmt: skip
+
+
+class MBart50TokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" MBART tokenizer for mBART-50 (backed by HuggingFace's *tokenizers* library). Based on
+    [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        src_lang (`str`, *optional*):
+            A string representing the source language.
+        tgt_lang (`str`, *optional*):
+            A string representing the target language.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+
+    Examples:
+
+    ```python
+    >>> from transformers import MBart50TokenizerFast
+
+    >>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
+    >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+    >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> # model(**model_inputs) should work
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = MBart50Tokenizer
+
+    prefix_tokens: list[int] = []
+    suffix_tokens: list[int] = []
+
+    def __init__(
+        self,
+        vocab_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        tokenizer_file=None,
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        **kwargs,
+    ):
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) or []
+        kwargs["additional_special_tokens"] += [
+            code for code in FAIRSEQ_LANGUAGE_CODES if code not in kwargs["additional_special_tokens"]
+        ]
+
+        super().__init__(
+            vocab_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            tokenizer_file=tokenizer_file,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+
+        self.lang_code_to_id = {
+            lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
+        }
+
+        self._src_lang = src_lang if src_lang is not None else "en_XX"
+        self.tgt_lang = tgt_lang
+        self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. The special tokens depend on calling set_lang.
+
+        An MBART-50 sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder) `[src_lang_code] X [eos]`
+        - `labels`: (for decoder) `[tgt_lang_code] X [eos]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "en_XX", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "ro_RO", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang: str) -> None: + """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.convert_tokens_to_ids(src_lang) + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.convert_tokens_to_ids(tgt_lang) + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["MBart50TokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09d805dc1bf67d12cd7855d961a51d9461f03431 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_megatron_bert import * + from .modeling_megatron_bert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/configuration_megatron_bert.py b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/configuration_megatron_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..1505388e2925788fddcf4a40613da1ca9d13a82d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/configuration_megatron_bert.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2021- NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MEGATRON_BERT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MegatronBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MegatronBertModel`]. It is used to instantiate a + MEGATRON_BERT model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the MEGATRON_BERT + [nvidia/megatron-bert-uncased-345m](https://huggingface.co/nvidia/megatron-bert-uncased-345m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 29056): + Vocabulary size of the MEGATRON_BERT model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`MegatronBertModel`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`MegatronBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ + Examples: + + ```python + >>> from transformers import MegatronBertConfig, MegatronBertModel + + >>> # Initializing a MEGATRON_BERT google-bert/bert-base-uncased style configuration + >>> configuration = MegatronBertConfig() + + >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration + >>> model = MegatronBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "megatron-bert" + + def __init__( + self, + vocab_size=29056, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + + +__all__ = ["MegatronBertConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/modeling_megatron_bert.py b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/modeling_megatron_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a75c0f575aca45c38134054c27c135bc9691c13b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -0,0 +1,1649 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
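+# Layer-norm placement at a glance (informal sketch of what the classes below
+# implement; see MegatronBertAttention, MegatronBertLayer and MegatronBertEncoder):
+#
+#     standard BERT:  x = LayerNorm(x + sublayer(x))   # LN inside each residual block
+#     Megatron-BERT:  x = x + sublayer(LayerNorm(x))   # LN on each sublayer input...
+#                     h = LayerNorm(h)                 # ...plus one final LN in the encoder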
+"""PyTorch MegatronBERT model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_megatron_bert import MegatronBertConfig + + +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class MegatronBertEmbeddings(nn.Module): + """Construct 
the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + + # In Megatron, layer-norm is applied after the 1st dropout. + # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + # Megatron BERT moves that layer norm after the drop-out (and to each layer). 
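+        # (The stock-BERT call below therefore stays disabled; normalization instead
+        # happens in each MegatronBertAttention/MegatronBertLayer `ln` and in the
+        # final `MegatronBertEncoder.ln`.)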
+ # embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert +class MegatronBertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = 
cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in MegatronBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +# Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to MegatronBertAttention below. 
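+# Recap of the score computation in MegatronBertSelfAttention.forward above
+# (informal notation; d = attention_head_size):
+#
+#     scores = Q @ K^T                      # plus relative-position terms, if configured
+#     scores = scores / sqrt(d) + mask      # additive mask: 0 to keep, large negative to drop
+#     probs  = dropout(softmax(scores))     # optionally multiplied by head_mask
+#     ctx    = probs @ V                    # then reshaped to (batch, seq, all_head_size)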
+class MegatronBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return residual + hidden_states + + +# Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm. +class MegatronBertAttention(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.self = MegatronBertSelfAttention(config, layer_idx=layer_idx) + self.output = MegatronBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + ln_outputs = self.ln(hidden_states) + self_outputs = self.self( + ln_outputs, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->MegatronBert +class MegatronBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertOutput. Moved LayerNorm to MegatronBertLayer below. 
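+# Together with MegatronBertLayer.feed_forward_chunk below, MegatronBertIntermediate
+# and MegatronBertOutput implement the pre-LN feed-forward residual (sketch):
+#
+#     ffn(x) = x + W_out(act(W_in(LayerNorm(x))))   # dropout applied after W_out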
+class MegatronBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return input_tensor + hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm. +class MegatronBertLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = MegatronBertAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise TypeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = MegatronBertAttention(config, layer_idx=layer_idx) + self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.intermediate = MegatronBertIntermediate(config) + self.output = MegatronBertOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + return (layer_output,) + outputs + + def feed_forward_chunk(self, attention_output): + ln_output = self.ln(attention_output) + intermediate_output = self.intermediate(ln_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class MegatronBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = 
nn.ModuleList([MegatronBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + + # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one + # is simply the final LN (Transformer's BERT has it attached to each hidden layer). + self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + output_attentions, + cache_position, + ) + + # Because we moved the layer-norm at the end of the hidden layer, we have non-normali- + # zed data here. If that's really needed, we must apply LN to match Transformer's BERT. + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Finalize the hidden states. 
+ hidden_states = self.ln(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->MegatronBert +class MegatronBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MegatronBert +class MegatronBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MegatronBert +class MegatronBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = MegatronBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
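+        # (i.e. logits = hidden_states @ E^T + bias, where E is the tied word-embedding
+        # matrix; `_tie_weights` below re-links `decoder.bias` to `self.bias` so the
+        # bias is resized together with the vocabulary.)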
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MegatronBert +class MegatronBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MegatronBertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->MegatronBert +class MegatronBertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->MegatronBert +class MegatronBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MegatronBertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@auto_docstring +class MegatronBertPreTrainedModel(PreTrainedModel): + config: MegatronBertConfig + load_tf_weights = load_tf_weights_in_megatron_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, MegatronBertLMPredictionHead): + module.bias.data.zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`MegatronBertForPreTraining`]. + """ +) +# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert +class MegatronBertForPreTrainingOutput(ModelOutput): + r""" + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+        before SoftMax).
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    prediction_logits: Optional[torch.FloatTensor] = None
+    seq_relationship_logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class MegatronBertModel(MegatronBertPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder, the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `True`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = MegatronBertEmbeddings(config)
+        self.encoder = MegatronBertEncoder(config)
+
+        self.pooler = MegatronBertPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
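+        # A note on shapes (paraphrasing the behavior of `get_extended_attention_mask` inherited from the
+        # `ModuleUtilsMixin` base class): a 2D padding mask of shape [batch_size, seq_length] such as [[1, 1, 0]]
+        # is broadcast to [batch_size, 1, 1, seq_length] and converted to an additive mask, roughly
+        # (1.0 - mask) * torch.finfo(dtype).min, so masked positions get a large negative score before the softmax.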
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention,
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    MegatronBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
+    `next sentence prediction (classification)` head.
+    """
+)
+class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder"]
+
+    def __init__(self, config, add_binary_head=True):
+        r"""
+        add_binary_head (`bool`, *optional*, defaults to `True`):
+            Whether or not to add a binary head.
+        """
+        super().__init__(config)
+
+        self.bert = MegatronBertModel(config)
+        self.cls = MegatronBertPreTrainingHeads(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        next_sentence_label: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MegatronBertForPreTrainingOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+            (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+            - 0 indicates sequence B is a continuation of sequence A,
+            - 1 indicates sequence B is a random sequence.
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MegatronBertForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m") + >>> model = MegatronBertForPreTraining.from_pretrained("nvidia/megatron-bert-cased-345m") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return MegatronBertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + MegatronBert Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +class MegatronBertForCausalLM(MegatronBertPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `MegatronBertForCausalLM` as a standalone, add `is_decoder=True.`") + + self.bert = MegatronBertModel(config, add_pooling_layer=False) + self.cls = MegatronBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, MegatronBertForCausalLM, MegatronBertConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m") + >>> model = MegatronBertForCausalLM.from_pretrained("nvidia/megatron-bert-cased-345m", is_decoder=True) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + 
logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class MegatronBertForMaskedLM(MegatronBertPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `MegatronBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.bert = MegatronBertModel(config, add_pooling_layer=False) + self.cls = MegatronBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@auto_docstring( + custom_intro=""" + MegatronBert Model with a `next sentence prediction (classification)` head on top. + """ +) +class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = MegatronBertModel(config) + self.cls = MegatronBertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. 
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MegatronBertForNextSentencePrediction
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
+        >>> model = MegatronBertForNextSentencePrediction.from_pretrained("nvidia/megatron-bert-cased-345m")
+
+        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+
+        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+        >>> logits = outputs.logits
+        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
+        ```"""
+
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+                " `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        seq_relationship_scores = self.cls(pooled_output)
+
+        next_sentence_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+        if not return_dict:
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+        return NextSentencePredictorOutput(
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    MegatronBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+    pooled output), e.g. for GLUE tasks.
+    """
+)
+class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = MegatronBertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
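+
+        Example (an illustrative sketch, assuming the `nvidia/megatron-bert-cased-345m` checkpoint used by the other
+        examples in this file; the classification head is randomly initialized until fine-tuned, so the predicted
+        class is not meaningful):
+
+        ```python
+        >>> from transformers import AutoTokenizer, MegatronBertForSequenceClassification
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
+        >>> model = MegatronBertForSequenceClassification.from_pretrained("nvidia/megatron-bert-cased-345m")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+        >>> predicted_class_id = int(logits.argmax(dim=-1))
+        ```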
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = MegatronBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MegatronBertForTokenClassification(MegatronBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = MegatronBertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
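+
+        Example (an illustrative sketch, following the pattern of the other examples in this file and assuming the
+        `nvidia/megatron-bert-cased-345m` checkpoint; the token-classification head is randomly initialized until
+        fine-tuned):
+
+        ```python
+        >>> from transformers import AutoTokenizer, MegatronBertForTokenClassification
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
+        >>> model = MegatronBertForTokenClassification.from_pretrained("nvidia/megatron-bert-cased-345m")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)
+        >>> predicted_token_class_ids = logits.argmax(dim=-1)
+        ```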
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = MegatronBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + 
start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MegatronBertForCausalLM", + "MegatronBertForMaskedLM", + "MegatronBertForMultipleChoice", + "MegatronBertForNextSentencePrediction", + "MegatronBertForPreTraining", + "MegatronBertForQuestionAnswering", + "MegatronBertForSequenceClassification", + "MegatronBertForTokenClassification", + "MegatronBertModel", + "MegatronBertPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/megatron_gpt2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/megatron_gpt2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/venv/lib/python3.13/site-packages/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py b/venv/lib/python3.13/site-packages/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py new file mode 100644 index 0000000000000000000000000000000000000000..ab866834a2c974469f6eed9693732f0526b8e84d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py @@ -0,0 +1,925 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib.util +import json +import os +import re +import sys +import types + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from packaging import version + +from transformers import AutoTokenizer, GPT2Config +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME +from transformers.utils import check_torch_load_is_safe + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--convert_checkpoint_from_megatron_to_transformers", + action="store_true", + help=( + "If True, convert a Megatron checkpoint to a Transformers checkpoint. " + "If False, convert a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + return parser + + +def add_megatron_checkpoint_args(parser): + parser.add_argument( + "--target_tensor_model_parallel_size", + type=int, + default=1, + help=( + "The tensor model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_pipeline_model_parallel_size", + type=int, + default=1, + help=( + "The pipeline model parallel size of the converted checkpoint. 
" + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_data_parallel_size", + type=int, + default=1, + help=( + "The data parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_params_dtype", + type=str, + default="fp32", + help=( + "The dtype of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--make_vocab_size_divisible_by", + type=int, + default=128, + help=( + "Pad the vocab size to be divisible by this value. " + "This is added for computational efficiency reasons. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--use_distributed_optimizer", + action="store_true", + help=( + "If True, use the distributed optimizer. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="10GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + + +# The simple map of names for "automated" rules. +megatron_to_transformers = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", +} +transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_4h_to_h.weight", + # deprecated + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "attention.dense.weight", + # transformers layers to split across tp ranks + "attn.c_attn.weight", + "attn.c_attn.bias", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "mlp.c_proj.weight", +] + + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). 
+ if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val: + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def megatron_to_transformers_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions + of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: + https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 If param is the weight tensor of the + self-attention block, the returned tensor will have to be transposed one more time to be read by HuggingFace GPT2. + This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def transformers_to_megatron_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM checkpoint versions. Input + is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version + 1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the + self-attention block, the param needs to be already transposed before calling this function. + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + # Input is [num_splits * num_heads * hidden_size, :] + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def merge_transformers_sharded_states(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. 
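+    Shard files are expected to follow the classic `pytorch_model-00001-of-00002.bin`-style naming scheme (see the
+    `checkpoint_path` construction below).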
+
+    Args:
+        path (str): the path to the sharded checkpoints
+        num_checkpoints (int): the number of checkpoints to merge
+    """
+    state_dict = {}
+    for i in range(1, num_checkpoints + 1):
+        checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
+        check_torch_load_is_safe()
+        current_chunk = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        state_dict.update(current_chunk)
+    return state_dict
+
+
+def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
+    """
+    Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline
+    parallel size and pipeline parallel rank.
+
+    Args:
+        args (argparse.Namespace): the arguments to the script
+        tp_size (int): the tensor parallel size
+        pp_size (int): the pipeline parallel size
+        pp_rank (int): the pipeline parallel rank
+    """
+    tp_state_dicts = []
+    for i in range(tp_size):
+        sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
+        for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
+            checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
+            if os.path.isfile(checkpoint_path):
+                break
+        check_torch_load_is_safe()
+        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        tp_state_dicts.append(state_dict)
+    return tp_state_dicts
+
+
+def get_element_from_dict_by_path(d, path):
+    """
+    Get element from dictionary by path. If element is not present, recursively add empty dictionaries.
+
+    Args:
+        d (dict): the dictionary to get the element from
+        path (str): the path to the element, with keys delimited by "."
+    """
+    path = path.split(".")
+    for k in path:
+        if k not in d:
+            d[k] = {}
+        d = d[k]
+    return d
+
+
+def convert_checkpoint_from_megatron_to_transformers(args):
+    """
+    Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints
+    with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards
+    using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of
+    `convert_megatron_gpt2_checkpoint.py`.
+
+    Args:
+        args (argparse.Namespace): the arguments to the script
+    """
+    # Load Megatron-LM checkpoint arguments from the state dict
+    sub_dirs = os.listdir(args.load_path)
+    possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"]
+    for sub_dir in possible_sub_dirs:
+        if sub_dir in sub_dirs:
+            rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0]
+            rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
+            break
+    print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
+    check_torch_load_is_safe()
+    state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True)
+    megatron_args = state_dict.get("args", None)
+    if megatron_args is None:
+        raise ValueError(
+            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
+            " that contain all the megatron arguments, because it reads the model architecture and the tensor and"
+            " pipeline model parallel sizes from the checkpoint instead of the user having to specify them manually."
+            " Please save the Megatron-LM checkpoint along with all the megatron arguments to use this utility."
+ ) + + # Create Transformers GPT2 config from Megatron-LM arguments + if megatron_args is not None: + if megatron_args.bias_gelu_fusion: + activation_function = "gelu_fast" + elif megatron_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + vocab_size = ( + megatron_args.padded_vocab_size + if getattr(megatron_args, "orig_vocab_size", None) is None + else megatron_args.orig_vocab_size + ) + print(vocab_size) + + config = GPT2Config( + vocab_size=vocab_size, + n_positions=megatron_args.max_position_embeddings, + n_embd=megatron_args.hidden_size, + n_layer=megatron_args.num_layers, + n_head=megatron_args.num_attention_heads, + n_inner=megatron_args.ffn_hidden_size, + activation_function=activation_function, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=vocab_size - 1, + eos_token_id=vocab_size - 1, + architectures=["GPT2LMHeadModel"], + ) + + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + dtype = torch.float32 + # The regex to extract layer names. + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + + # Convert and store the position embeddings. + position_embeddings = get_element_from_dict_by_path( + tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight" + ) + output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. + word_embeddings = torch.cat( + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + word_embeddings = word_embeddings[:vocab_size].to(dtype) + output_state_dict["transformer.wte.weight"] = word_embeddings + + # Transformer Layers + print("Converting transformer layers") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + n_positions = config.n_positions + num_layers = config.num_hidden_layers // pp_size + + for pp_rank in range(pp_size): + if pp_size > 0: + print(f"Converting pipeline parallel rank {pp_rank}") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank) + + # The transformer. + path = ( + "model.language_model.transformer" + if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model") + else "model.language_model.encoder" + ) + # Extract the layers. + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + # Match the name. + m = layer_re.match(key) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + pp_rank * num_layers + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. 
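+            # Illustrative mapping (hypothetical numbers): with 24 layers split over pp_size=2 (num_layers=12),
+            # the Megatron key "layers.3.mlp.dense_h_to_4h.weight" on pp_rank=1 gets layer_idx = 3 + 1 * 12 = 15
+            # and lands under the Transformers prefix "transformer.h.15".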
+            layer_name = f"transformer.h.{layer_idx}"
+
+            if op_name + "." + weight_or_bias not in tensor_parallel_params:
+                params = val.to(dtype)
+            else:
+                dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0
+                params = torch.cat(
+                    [val]
+                    + [
+                        get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key]
+                        for tp_rank in range(1, tp_size)
+                    ],
+                    dim=dim,
+                ).to(dtype)
+
+            # For layernorm(s), simply store the layer norm.
+            if op_name.endswith("layernorm"):
+                ln_name = "ln_1" if op_name.startswith("input") else "ln_2"
+                output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params
+
+            # Transpose the QKV matrix.
+            elif (
+                op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+            ) and weight_or_bias == "weight":
+                # Insert a tensor of 1x1xDxD bias.
+                causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=dtype)).view(
+                    1, 1, n_positions, n_positions
+                )
+                output_state_dict[layer_name + ".attn.bias"] = causal_mask
+
+                # Insert a "dummy" tensor for masked_bias.
+                masked_bias = torch.tensor(-1e4, dtype=dtype)
+                output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias
+
+                out_val = megatron_to_transformers_fix_query_key_value_ordering(
+                    params,
+                    checkpoint_version,
+                    3,
+                    heads,
+                    hidden_size_per_head,
+                )
+                # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
+                out_val = out_val.transpose(0, 1).contiguous()
+                # Store.
+                output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val
+
+            # Transpose the bias.
+            elif (
+                op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+            ) and weight_or_bias == "bias":
+                out_val = megatron_to_transformers_fix_query_key_value_ordering(
+                    params, checkpoint_version, 3, heads, hidden_size_per_head
+                )
+                # Store. No change of shape.
+                output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val
+
+            # Transpose the weights.
+            elif weight_or_bias == "weight":
+                out_name = megatron_to_transformers[op_name]
+                output_state_dict[layer_name + out_name + "weight"] = params.transpose(0, 1)
+
+            # Copy the bias.
+            elif weight_or_bias == "bias":
+                out_name = megatron_to_transformers[op_name]
+                output_state_dict[layer_name + out_name + "bias"] = params
+
+    if config.n_layer != (layer_idx + 1):
+        raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}")
+
+    # The final layernorm.
+    print("Converting final layernorm")
+    params = get_element_from_dict_by_path(tp_state_dicts[0], str(path))
+    output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype)
+    output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype)
+
+    # For the LM head, transformers expects the weight matrix to be tied to the word embeddings.
+    print("Converting LM head")
+    output_state_dict["lm_head.weight"] = word_embeddings.to(dtype)
+
+    # It should be done!
+    print("Conversion from Megatron-LM to Transformers is done!")
+
+    # Print the structure of converted state dict.
+    if args.print_checkpoint_structure:
+        recursive_print(None, output_state_dict)
+
+    # Add tokenizer class info to config (see https://github.com/huggingface/transformers/issues/13906)
+    if args.tokenizer_name is None:
+        tokenizer_name = "openai-community/gpt2"
+    else:
+        tokenizer_name = args.tokenizer_name
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    tokenizer_class = type(tokenizer).__name__
+    config.tokenizer_class = tokenizer_class
+
+    # Store the config to file.
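+    # (`save_pretrained` writes `config.json`, including the `tokenizer_class` recorded above, to `args.save_path`)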
+ print("Saving config") + config.save_pretrained(args.save_path) + + # Save tokenizer based on args + if args.tokenizer_name is not None: + print(f"Adding {tokenizer_class} tokenizer files") + tokenizer.save_pretrained(args.save_path) + + # Store the state_dict to file. + max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + state_dict_split = split_torch_state_dict_into_shards(output_state_dict, max_shard_size=max_shard_size) + shards = index = None + for tensors in state_dict_split.filename_to_tensors.values(): + shards = {tensor: state_dict[tensor] for tensor in tensors} + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Save the model + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + +def convert_checkpoint_from_transformers_to_megatron(args): + """ + Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable + tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers + which can have multiple shards. + + Args: + args (argparse.Namespace): the arguments to the script + + """ + os.makedirs(args.save_path, exist_ok=True) + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + megatron_exists = importlib.util.find_spec("megatron") is not None + if megatron_exists: + from megatron.core import package_info + + if version.parse(package_info.__version__) >= version.parse("0.6.0"): + from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding + else: + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + + else: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. 
Exiting.") + exit(1) + + # load the transformers model state dict and config + sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")] + if len(sub_dirs) == 1: + checkpoint_name = "pytorch_model.bin" + check_torch_load_is_safe() + state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu", weights_only=True) + else: + num_checkpoints = len(sub_dirs) - 1 + state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints) + + config = GPT2Config.from_pretrained(args.load_path) + + # Saving the tracker file + tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt") + with open(tracker_filepath, "w") as f: + f.write("release") + + # create `release` dir in args.load_path + release_dir = os.path.join(args.save_path, "release") + os.makedirs(release_dir, exist_ok=True) + + # megatron args + megatron_args = { + "orig_vocab_size": config.vocab_size, + "max_position_embeddings": config.n_positions, + "hidden_size": config.n_embd, + "num_layers": config.n_layer, + "num_attention_heads": config.n_head, + "ffn_hidden_size": config.n_inner, + "tensor_model_parallel_size": args.target_tensor_model_parallel_size, + "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size, + "data_parallel_size": args.target_data_parallel_size, + "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by, + "rank": 0, + "tokenizer_type": "GPT2BPETokenizer", + } + + if config.activation_function == "gelu": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_fast": + megatron_args["bias_gelu_fusion"] = True + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_new": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = True + + margs = types.SimpleNamespace() + for k, v in megatron_args.items(): + setattr(margs, k, v) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + setattr(margs, "params_dtype", dtype) + + # save dummy optim state dict + dummy_optim_state_dict = {} + dummy_optim_state_dict["optimizer"] = { + "step": 0, + "param_groups": [ + { + "lr": 0.0, + "beta1": 0.0, + "beta2": 0.0, + "eps": 0.0, + "weight_decay": 0.0, + "correct_bias": False, + "params": [], + } + ], + } + if args.use_distributed_optimizer: + for i in range(args.target_pipeline_model_parallel_size): + for j in range(args.target_tensor_model_parallel_size): + for k in range(args.target_data_parallel_size): + if args.target_pipeline_model_parallel_size == 1: + checkpoint_dir = f"mp_rank_{j:02d}_{k:03d}" + else: + checkpoint_dir = f"mp_rank_{j:02d}_{i:03d}_{k:03d}" + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save( + dummy_optim_state_dict, + os.path.join(checkpoint_dir, "optim.pt"), + ) + + # Convert. 
+ print("Converting") + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + # Embedding layer + print("converting embedding layer") + pos_embedding = state_dict["transformer.wpe.weight"].to(dtype) + word_embedding = state_dict["transformer.wte.weight"].to(dtype) + orig_vocab_size = config.vocab_size + padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs) + setattr(margs, "padded_vocab_size", padded_vocab_size) + # Cut out extra padding we don't need + if orig_vocab_size > padded_vocab_size: + full_word_embed = word_embedding[0:padded_vocab_size, :] + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < padded_vocab_size: + padding_size = padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1))) + # Same size! + else: + full_word_embed = word_embedding + + # Split into new tensor model parallel sizes + out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + pos_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.position_embeddings" + ) + pos_emb_dict["weight"] = pos_embedding + + word_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.word_embeddings" + ) + word_emb_dict["weight"] = out_word_embed[i].clone() + + # Transformer layers + print("converting transformer layers") + if config.num_attention_heads % args.target_tensor_model_parallel_size != 0: + raise ValueError( + f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of tensor parallelism" + f" ({args.target_tensor_model_parallel_size})" + ) + + if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0: + raise ValueError( + f"Number of layers ({config.num_hidden_layers}) must be divisible by number of pipeline parallelism" + f" ({args.target_pipeline_model_parallel_size})" + ) + + num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size + + layer_re = re.compile(r"transformer.h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + for pp_rank in range(args.target_pipeline_model_parallel_size): + layer_offset = pp_rank * num_layers + if pp_rank > 0: + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + for layer in range(num_layers): + pp_layer_id = layer + layer_offset + layers_to_copy = [ + layer_name for layer_name in state_dict if layer_name.startswith(f"transformer.h.{pp_layer_id}.") + ] + + for layer_name in layers_to_copy: + m = layer_re.match(layer_name) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + _ = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + params = state_dict[layer_name].to(dtype) + # handle layernorm + if op_name.startswith("ln"): + out_name = "input_layernorm" if op_name.endswith("1") else "post_attention_layernorm" + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention K, V, Q weights + elif op_name.startswith("attn.c_attn") and weight_or_bias == "weight": + # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. 
+ params = params.transpose(0, 1).contiguous() + + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention K, V, Q bias + elif op_name.startswith("attn.c_attn") and weight_or_bias == "bias": + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention and mlp weights + elif weight_or_bias == "weight": + out_name = transformers_to_megatron.get(op_name) + if out_name is None: + continue + params = params.transpose(0, 1) + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention and mlp bias + elif weight_or_bias == "bias": + out_name = transformers_to_megatron.get(op_name) + if out_name is None: + continue + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # skip + else: + continue + + if op_name + "." + weight_or_bias in tensor_parallel_params: + dim = 1 if op_name in ["attn.c_proj", "mlp.c_proj"] else 0 + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = ( + params[i].clone() if (op_name + "." + weight_or_bias in tensor_parallel_params) else params + ) + + if pp_rank == args.target_pipeline_model_parallel_size - 1: + # handle final layernorm + for weight_or_bias in ["weight", "bias"]: + params = state_dict[f"transformer.ln_f.{weight_or_bias}"].to(dtype) + layer_name = f"final_layernorm.{weight_or_bias}" + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = params + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head") + params_dict["weight"] = out_word_embed[i].clone() + + # saving the state dict as per the tp_rank and pp_rank + for tp_rank in range(args.target_tensor_model_parallel_size): + output_state_dict[tp_rank]["checkpoint_version"] = 3.0 + output_state_dict[tp_rank]["args"] = margs + checkpoint_dir = ( + f"mp_rank_{tp_rank:02d}" + if args.target_pipeline_model_parallel_size == 1 + else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" + ) + if args.use_distributed_optimizer: + checkpoint_name = "model_rng.pt" + else: + checkpoint_name = "model_optim_rng.pt" + output_state_dict[tp_rank]["optimizer"] = dummy_optim_state_dict["optimizer"] + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if args.print_checkpoint_structure: + print( + f"Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank" + f" {pp_rank}:" + ) + recursive_print(None, output_state_dict[tp_rank]) + torch.save(output_state_dict[tp_rank], checkpoint_path) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_megatron_checkpoint_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + if args.convert_checkpoint_from_megatron_to_transformers: + 
convert_checkpoint_from_megatron_to_transformers(args) + else: + convert_checkpoint_from_transformers_to_megatron(args) + + +if __name__ == "__main__": + main() diff --git a/venv/lib/python3.13/site-packages/transformers/models/mistral3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mistral3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11a9fcbdc4ed3935dc549e3cd1992b7550d9bbad --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mistral3/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mistral3 import * + from .modeling_mistral3 import * + from .processing_mistral3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mistral3/configuration_mistral3.py b/venv/lib/python3.13/site-packages/transformers/models/mistral3/configuration_mistral3.py new file mode 100644 index 0000000000000000000000000000000000000000..90ae36abcec371cc0871452005981628e5219413 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mistral3/configuration_mistral3.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class Mistral3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Mistral3ForConditionalGeneration`]. It is used to instantiate an + Mistral3 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `PixtralVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MistralConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 10): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + multimodal_projector_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the multimodal projector. + spatial_merge_size (`int`, *optional*, defaults to 2): + The downsampling factor for the spatial merge operation. + + Example: + + ```python + >>> from transformers import Mistral3ForConditionalGeneration, Mistral3Config, PixtralVisionConfig, MistralConfig + + >>> # Initializing a Pixtral-vision config + >>> vision_config = PixtralVisionConfig() + + >>> # Initializing a Mistral config + >>> text_config = MistralConfig() + + >>> # Initializing a Mistral3 configuration + >>> configuration = Mistral3Config(vision_config, text_config) + + >>> # Initializing a model from the mistral3.1 configuration + >>> model = Mistral3ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral3" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + is_composition = True + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=10, + projector_hidden_act="gelu", + vision_feature_layer=-1, + multimodal_projector_bias=False, + spatial_merge_size=2, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "pixtral") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["pixtral"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=1540, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + head_dim=64, + hidden_act="gelu", + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "mistral") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["mistral"]( + attention_dropout=0.0, + head_dim=128, + hidden_act="silu", + hidden_size=5120, + initializer_range=0.02, + intermediate_size=32768, + max_position_embeddings=131072, + model_type="mistral", + num_attention_heads=32, + num_hidden_layers=40, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_theta=1000000000.0, + sliding_window=None, + use_cache=True, + vocab_size=131072, + ) + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + self.spatial_merge_size = spatial_merge_size + + super().__init__(**kwargs) 
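As a minimal sketch (assuming a transformers install that ships this model), the dict-to-config coercion in the `__init__` above behaves like this; the hidden sizes here are arbitrary illustration values:

```python
from transformers import Mistral3Config

# Plain dicts are accepted: __init__ looks up "model_type" in CONFIG_MAPPING
# and builds the corresponding sub-config objects.
cfg = Mistral3Config(
    vision_config={"model_type": "pixtral", "hidden_size": 1024},
    text_config={"model_type": "mistral", "hidden_size": 5120},
)
print(type(cfg.vision_config).__name__)  # PixtralVisionConfig
print(type(cfg.text_config).__name__)    # MistralConfig
```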
+ + +__all__ = ["Mistral3Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mistral3/modeling_mistral3.py b/venv/lib/python3.13/site-packages/transformers/models/mistral3/modeling_mistral3.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2a53a54352bdb543318d20e094efdc5e2559b8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mistral3/modeling_mistral3.py @@ -0,0 +1,536 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ..auto import AutoModel +from .configuration_mistral3 import Mistral3Config + + +@use_kernel_forward_from_hub("RMSNorm") +class Mistral3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mistral3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, config: Mistral3Config): + super().__init__() + self.config = config + + hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = self.config.vision_config.patch_size + self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in 
image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class Mistral3MultiModalProjector(nn.Module): + def __init__(self, config: Mistral3Config): + super().__init__() + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps) + self.patch_merger = Mistral3PatchMerger(config) + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Mistral3 causal language model (or autoregressive) outputs. + """ +) +class Mistral3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Mistral3 outputs, with hidden states and attentions. 
+ """ +) +class Mistral3ModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring +class Mistral3PreTrainedModel(PreTrainedModel): + config: Mistral3Config + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +@auto_docstring( + custom_intro=""" + The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class Mistral3Model(Mistral3PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: Mistral3Config): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = Mistral3MultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_sizes (`torch.Tensor`, *optional*): + Tensor containing the image sizes as returned by the processor. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. 
+ image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features.squeeze(0), split_sizes) + return image_features + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Mistral3ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: 
+ image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return Mistral3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The MISTRAL3 model which consists of a vision backbone and a language model. + """ +) +class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Mistral3Config): + super().__init__(config) + self.model = Mistral3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + **kwargs, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> Union[tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return Mistral3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel 
values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mistral3/modular_mistral3.py b/venv/lib/python3.13/site-packages/transformers/models/mistral3/modular_mistral3.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc499d21453ef10540378c5e5e46bdced64f627 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mistral3/modular_mistral3.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...processing_utils import Unpack +from ...utils import logging +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, + TransformersKwargs, +) +from ..mistral.modeling_mistral import MistralRMSNorm +from .configuration_mistral3 import Mistral3Config + + +logger = logging.get_logger(__name__) + + +class Mistral3RMSNorm(MistralRMSNorm): + pass + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, config: Mistral3Config): + super().__init__() + self.config = config + + hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = self.config.vision_config.patch_size + self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class Mistral3MultiModalProjector(nn.Module): + def __init__(self, config: Mistral3Config): + super().__init__() + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps) + self.patch_merger = Mistral3PatchMerger(config) + # We have hidden_size * the number of vision feature layers + 
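+        # (Illustrative note) e.g. vision_feature_layer=-1 keeps the projector input width at the
+        # vision hidden_size, while vision_feature_layer=[-1, -2] doubles it.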
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast): + pass + + +class Mistral3PreTrainedModel(LlavaPreTrainedModel): + pass + + +class Mistral3Model(LlavaModel): + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_sizes (`torch.Tensor`, *optional*): + Tensor containing the image sizes as returned by the processor. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. 
+ image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features.squeeze(0), split_sizes) + return image_features + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Mistral3ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return Mistral3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): + def 
get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + **kwargs, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return Mistral3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + +__all__ = [ + "Mistral3Model", + "Mistral3PreTrainedModel", + "Mistral3ForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mixtral/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mixtral/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ca36bacbee79261b8fe19e1df5cd2fb91b7344 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mixtral/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mixtral import * + from .modeling_mixtral import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mixtral/configuration_mixtral.py b/venv/lib/python3.13/site-packages/transformers/models/mixtral/configuration_mixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..35ae3f10dbc67a8040def0e641d73009c0505098 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mixtral/configuration_mixtral.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mixtral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MixtralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an + Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. + + [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B) + [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MixtralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. 
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, defaults to `None`, in which case a plain causal
+            mask is used instead of a sliding-window mask.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; can also be interpreted as the `top-k` routing parameter.
+        num_local_experts (`int`, *optional*, defaults to 8):
+            Number of experts per Sparse MLP layer.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also allow the model
+            to output the auxiliary load balancing loss. See the [Switch Transformers
+            paper](https://huggingface.co/papers/2101.03961) for more details.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.0):
+            Amount of noise to add to the router.
+
+    ```python
+    >>> from transformers import MixtralModel, MixtralConfig
+
+    >>> # Initializing a Mixtral 8x7B style configuration
+    >>> configuration = MixtralConfig()
+
+    >>> # Initializing a model from the Mixtral 8x7B style configuration
+    >>> model = MixtralModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "mixtral"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.block_sparse_moe.gate": "colwise_rep",  # we need to replicate here to correctly route experts
+        "layers.*.block_sparse_moe.experts.*.w1": "colwise",
+        "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
+        "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=1e6,
+        sliding_window=None,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=8,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["MixtralConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mixtral/modeling_mixtral.py b/venv/lib/python3.13/site-packages/transformers/models/mixtral/modeling_mixtral.py
new file mode 100644
index 0000000000000000000000000000000000000000..affeb2f1491bf500a0e46406a7f18c224ef9a667
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mixtral/modeling_mixtral.py
@@ -0,0 +1,689 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_mixtral.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
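As a quick sanity check on the configuration file above, the GQA geometry implied by its defaults can be derived in a few lines. This is a minimal sketch, assuming only that a `transformers` build exposing this `MixtralConfig` is importable; the printed values follow directly from the docstring defaults (32 query heads, 8 KV heads, `head_dim` falling back to `hidden_size // num_attention_heads`):

```python
# Hedged sketch: inspect the GQA layout implied by MixtralConfig's defaults.
from transformers import MixtralConfig

config = MixtralConfig()  # defaults: hidden_size=4096, 32 heads, 8 KV heads

# head_dim defaults to None, so the model derives it from hidden_size
head_dim = config.head_dim or config.hidden_size // config.num_attention_heads

# number of query heads sharing each key/value head (GQA group size)
kv_groups = config.num_attention_heads // config.num_key_value_heads

print(head_dim, kv_groups, config.max_position_embeddings)  # 128 4 131072
```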
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.utils.generic import check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+    GenericForQuestionAnswering,
+    GenericForSequenceClassification,
+    GenericForTokenClassification,
+    GradientCheckpointingLayer,
+)
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder
+from .configuration_mixtral import MixtralConfig
+
+
+class MixtralBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: MixtralConfig):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralSparseMoeBlock(nn.Module):
+    """
+    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It is faster
+    since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of
+    tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance or (2) sets
+    the capacity factor to the number of experts and thus wastes computation and memory on padding.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+        # Jitter parameters
+        self.jitter_noise = config.router_jitter_noise
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        if self.training and self.jitter_noise > 0:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+        for expert_idx in expert_hit:
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class MixtralRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        MixtralRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class MixtralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = 
eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MixtralDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MixtralAttention(config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MixtralRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: MixtralConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class MixtralPreTrainedModel(PreTrainedModel): + config: MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), + "hidden_states": MixtralDecoderLayer, + "attentions": MixtralAttention, + } + + +@auto_docstring +class MixtralModel(MixtralPreTrainedModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MixtralRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the 
decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = 
torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel): + pass + + +class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel): + pass + + +class MixtralForQuestionAnswering(GenericForQuestionAnswering, MixtralPreTrainedModel): + pass + + +__all__ = [ + "MixtralForCausalLM", + "MixtralForQuestionAnswering", + "MixtralModel", + "MixtralPreTrainedModel", + "MixtralForSequenceClassification", + "MixtralForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mixtral/modular_mixtral.py b/venv/lib/python3.13/site-packages/transformers/models/mixtral/modular_mixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..744f8c1321dc9fe51e702b81dea677e0deb4e00f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mixtral/modular_mixtral.py @@ -0,0 +1,467 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
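The modular file below re-declares the same sparse-MoE blocks that the generated file above contains, so this is a convenient point for a standalone sketch of the top-2 routing arithmetic performed at the start of `MixtralSparseMoeBlock.forward`. This is a toy illustration in plain `torch`; `tokens`, `num_experts`, and `top_k` are illustrative values, not library API:

```python
# Hedged sketch of Mixtral's top-2 router: softmax over expert logits,
# keep the top-k weights, renormalize them, and build the expert mask.
import torch
import torch.nn.functional as F

tokens, num_experts, top_k = 5, 8, 2
router_logits = torch.randn(tokens, num_experts)  # stand-in for self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=-1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen k

# expert_mask[e, k, t] == 1 iff expert e is the k-th choice for token t,
# mirroring the permute(2, 1, 0) in the module above
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

print(selected_experts)                 # which 2 experts each token is routed to
print(routing_weights.sum(dim=-1))      # each row sums to 1.0 after renormalization
```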
+"""PyTorch Mixtral model.""" + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, + MistralPreTrainedModel, + MistralRMSNorm, + MistralRotaryEmbedding, +) +from .configuration_mixtral import MixtralConfig + + +logger = logging.get_logger(__name__) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralRMSNorm(MistralRMSNorm): + pass + + +class MixtralAttention(MistralAttention): + pass + + +class MixtralDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MixtralAttention(config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MixtralRotaryEmbedding(MistralRotaryEmbedding): + pass + + +class MixtralPreTrainedModel(MistralPreTrainedModel): + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_record_outputs = { + "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), + "hidden_states": MixtralDecoderLayer, + "attentions": MixtralAttention, + } + + +class MixtralModel(MistralModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = 
mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class MixtralForCausalLM(MistralForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MixtralForSequenceClassification(MistralForSequenceClassification): + pass + + +class MixtralForTokenClassification(MistralForTokenClassification): + pass + + +class MixtralForQuestionAnswering(MistralForQuestionAnswering): + pass + + +__all__ = [ + "MixtralForCausalLM", + "MixtralForQuestionAnswering", + "MixtralModel", + "MixtralPreTrainedModel", + "MixtralForSequenceClassification", + "MixtralForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mlcd/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mlcd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3556177d5eb8e493aa5e051785b980b0e10d9c4a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mlcd/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
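Before moving on to the MLCD files, it may help to make the behaviour of `load_balancing_loss_func` from the two Mixtral files above concrete. The following is a toy sketch of the statistic it computes for a single layer with no attention mask; names like `tokens` are illustrative, and under perfectly uniform routing the value approaches `top_k`:

```python
# Hedged sketch: the Switch-Transformer-style balance statistic is
# (fraction of tokens per expert) * (mean router probability per expert),
# summed and scaled by num_experts, so unbalanced routing is penalized.
import torch
import torch.nn.functional as F

tokens, num_experts, top_k = 100, 8, 2
gate_logits = torch.randn(tokens, num_experts)

probs = F.softmax(gate_logits, dim=-1)
_, selected = torch.topk(probs, top_k, dim=-1)

tokens_per_expert = F.one_hot(selected, num_experts).float().mean(dim=0)  # (top_k, num_experts)
router_prob_per_expert = probs.mean(dim=0)                                # (num_experts,)

aux = (tokens_per_expert * router_prob_per_expert.unsqueeze(0)).sum() * num_experts
print(aux)  # roughly top_k (here ~2) when routing is close to uniform
```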
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mlcd import * + from .modeling_mlcd import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mlcd/configuration_mlcd.py b/venv/lib/python3.13/site-packages/transformers/models/mlcd/configuration_mlcd.py new file mode 100644 index 0000000000000000000000000000000000000000..f28a5f1a7cab3d453c3bf07caea9fa9fd25a593b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mlcd/configuration_mlcd.py @@ -0,0 +1,117 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mlcd.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig + + +class MLCDVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the vision encoder of the MLCD + [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1664): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 1024): + Dimensionality of text and vision projection layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 336): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import MLCDVisionConfig, MLCDVisionModel + + >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration + >>> configuration = MLCDVisionConfig() + + >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration + >>> model = MLCDVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mlcd_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1664, + intermediate_size=8192, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_groups=1, + num_channels=3, + image_size=336, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_groups = num_key_value_groups + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + +__all__ = ["MLCDVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mlcd/modeling_mlcd.py b/venv/lib/python3.13/site-packages/transformers/models/mlcd/modeling_mlcd.py new file mode 100644 index 0000000000000000000000000000000000000000..5379c5313726dd2da0401b766580b569e354e277 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mlcd/modeling_mlcd.py @@ -0,0 +1,611 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mlcd.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, torch_int +from .configuration_mlcd import MLCDVisionConfig + + +class MLCDMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MLCDRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: + """ + Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. + + Args: + num_patches_height (int): Number of patches in the height dimension. + num_patches_width (int): Number of patches in the width dimension. + + Returns: + torch.Tensor: Rotary positional embeddings for the given grid size. 
+ """ + # Generate position IDs for height and width dimensions + hpos_ids = ( + torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) + ) + + # Flatten and stack the position IDs + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + + # Generate the full rotary positional embeddings for the maximum grid size + max_grid_size = max(num_patches_height, num_patches_width) + seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + + # Select and flatten the embeddings based on the position IDs + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + +class MLCDVisionEmbeddings(nn.Module): + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + # patch_embeds -> shape = [batch, width, grid, grid] + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class MLCDAttention(nn.Module): + """Multi-headed attention with RoPE. Refer to papers: + - Attention is all you need: + https://huggingface.co/papers/1706.03762 + - RoFormer: Enhanced Transformer with Rotary Position Embedding: + https://huggingface.co/papers/2104.09864 + """ + + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.num_key_value_groups = config.num_key_value_groups + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + batch_size, seq_length = hidden_states.shape[:-1] + + # Each of shape: [batch_size, seq_length, num_heads, head_dim] + query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + + # Apply positional embeddings + cos = position_embeddings[0].unsqueeze(0).float() + sin = position_embeddings[1].unsqueeze(0).float() + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + # Each of shape: [batch_size, num_heads, seq_length, head_dim] + query_states = query_states.permute(0, 2, 1, 3).contiguous() + key_states = key_states.permute(0, 2, 1, 3).contiguous() + value_states = value_states.permute(0, 2, 1, 3).contiguous() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + 
value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scale,
+            is_causal=self.is_causal,
+            **kwargs,
+        )
+
+        attn_output = attn_output.permute(1, 0, 2, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
+        attn_output = attn_output.view(seq_length, batch_size, -1)  # [seq_length, batch_size, embedding_dim]
+        attn_output = self.out_proj(attn_output)
+        attn_output = attn_output.permute(1, 0, 2).contiguous()  # [batch_size, seq_length, embedding_dim]
+        return attn_output, attn_weights
+
+
+class MLCDEncoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: MLCDVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = MLCDAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLCDMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(batch, seq_len, embed_dim)`.
+                Represents the hidden states from the previous layer or the input embeddings.
+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+                A tuple `(cos, sin)` of rotary position embedding components, each of shape `(seq_len, head_dim)`.
+                These are applied to the query and key states in the attention mechanism.
+            attention_mask (`torch.FloatTensor`):
+                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class MLCDEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+    [`MLCDEncoderLayer`].
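+
+    The rotary `position_embeddings` tuple consumed by every layer is computed once per forward pass
+    by [`MLCDVisionTransformer`] from the input's patch grid and shared across all layers.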
+
+    Args:
+        config: MLCDVisionConfig
+    """
+
+    def __init__(self, config: MLCDVisionConfig):
+        """Overwrite the dummy `MLCDConfig` with `MLCDVisionConfig`."""
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds: torch.FloatTensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+                A tuple `(cos, sin)` of rotary position embedding components, each of shape `(seq_len, head_dim)`.
+                These are applied to the query and key states in the attention mechanism.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
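+
+        Example (illustrative sketch; the tiny config and the random rotary angles below are
+        arbitrary stand-ins for values normally produced by [`MLCDVisionTransformer`]):
+
+        ```python
+        >>> import torch
+        >>> from transformers import MLCDVisionConfig
+        >>> from transformers.models.mlcd.modeling_mlcd import MLCDEncoder
+
+        >>> config = MLCDVisionConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4)
+        >>> encoder = MLCDEncoder(config)
+        >>> inputs_embeds = torch.randn(1, 6, config.hidden_size)
+        >>> head_dim = config.hidden_size // config.num_attention_heads
+        >>> angles = torch.randn(6, head_dim)  # rotary angles, one row per position; random purely for illustration
+        >>> outputs = encoder(inputs_embeds, position_embeddings=(angles.cos(), angles.sin()))
+        >>> outputs.last_hidden_state.shape
+        torch.Size([1, 6, 32])
+        ```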
+ """ + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class MLCDVisionTransformer(nn.Module): + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = MLCDVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = MLCDEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2) + self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2)) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + num_patches_height = pixel_values.shape[-2] // self.config.patch_size + num_patches_width = pixel_values.shape[-1] // self.config.patch_size + rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width) + rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device) + rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + 
encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring
+class MLCDPreTrainedModel(PreTrainedModel):
+    config: MLCDVisionConfig
+    base_model_prefix = "mlcd"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        factor = self.config.initializer_factor
+        if isinstance(module, MLCDVisionEmbeddings):
+            factor = self.config.initializer_factor
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+        elif isinstance(module, MLCDAttention):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
+            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+        elif isinstance(module, MLCDMLP):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+            nn.init.normal_(module.fc1.weight, std=fc_std)
+            nn.init.normal_(module.fc2.weight, std=in_proj_std)
+        elif isinstance(module, MLCDVisionTransformer):
+            factor = self.config.initializer_factor
+            pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
+            nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@auto_docstring(
+    custom_intro="""
+    The vision model from MLCD without any head or projection on top.
+    """
+)
+class MLCDVisionModel(MLCDPreTrainedModel):
+    config: MLCDVisionConfig
+    main_input_name = "pixel_values"
+    _no_split_modules = ["MLCDEncoderLayer"]
+
+    def __init__(self, config: MLCDVisionConfig):
+        super().__init__(config)
+        self.vision_model = MLCDVisionTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        Example:
+
+        ```python
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from transformers import AutoProcessor, MLCDVisionModel
+        >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+        >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...
outputs = model(**inputs, output_attentions=True) + + >>> features = outputs.last_hidden_state + >>> print(f"Extracted features shape: {features.shape}") + >>> print(f"Number of attention layers: {len(outputs.attentions)}") + >>> print(f"Attention shape: {outputs.attentions[0].shape}") + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mlcd/modular_mlcd.py b/venv/lib/python3.13/site-packages/transformers/models/mlcd/modular_mlcd.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc18ab2b1c8f123be84c502d534c6000882a683 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mlcd/modular_mlcd.py @@ -0,0 +1,529 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, logging +from ..clip.modeling_clip import ( + CLIPMLP, + CLIPAttention, + CLIPEncoder, + CLIPEncoderLayer, + CLIPVisionEmbeddings, + CLIPVisionModel, + CLIPVisionTransformer, +) +from ..llama.modeling_llama import eager_attention_forward +from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision + + +logger = logging.get_logger(__name__) + + +class MLCDVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the vision encoder of the MLCD + [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1664): + Dimensionality of the encoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 1024): + Dimensionality of text and vision projection layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 336): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import MLCDVisionConfig, MLCDVisionModel + + >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration + >>> configuration = MLCDVisionConfig() + + >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration + >>> model = MLCDVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mlcd_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1664, + intermediate_size=8192, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_groups=1, + num_channels=3, + image_size=336, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_groups = num_key_value_groups + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + +class MLCDMLP(CLIPMLP): + pass + + +class MLCDRotaryEmbedding(VisionRotaryEmbedding): + def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: + """ + Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. + + Args: + num_patches_height (int): Number of patches in the height dimension. + num_patches_width (int): Number of patches in the width dimension. 
+ + Returns: + torch.Tensor: Rotary positional embeddings for the given grid size. + """ + # Generate position IDs for height and width dimensions + hpos_ids = ( + torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) + ) + + # Flatten and stack the position IDs + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + + # Generate the full rotary positional embeddings for the maximum grid size + max_grid_size = max(num_patches_height, num_patches_width) + seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + + # Select and flatten the embeddings based on the position IDs + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + +class MLCDVisionEmbeddings(CLIPVisionEmbeddings): + def __init__(self, config: MLCDVisionConfig): + super().__init__(config) + del self.position_embedding + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + # patch_embeds -> shape = [batch, width, grid, grid] + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + return embeddings + + +class MLCDAttention(CLIPAttention): + """Multi-headed attention with RoPE. Refer to papers: + - Attention is all you need: + https://huggingface.co/papers/1706.03762 + - RoFormer: Enhanced Transformer with Rotary Position Embedding: + https://huggingface.co/papers/2104.09864 + """ + + def __init__(self, config: MLCDVisionConfig): + super().__init__(config) + self.num_key_value_groups = config.num_key_value_groups + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size, seq_length = hidden_states.shape[:-1] + + # Each of shape: [batch_size, seq_length, num_heads, head_dim] + query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + + # Apply positional embeddings + cos = position_embeddings[0].unsqueeze(0).float() + sin = position_embeddings[1].unsqueeze(0).float() + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + # Each of shape: [batch_size, num_heads, seq_length, head_dim] + query_states = query_states.permute(0, 2, 1, 3).contiguous() + key_states = key_states.permute(0, 2, 1, 3).contiguous() + value_states = value_states.permute(0, 2, 1, 3).contiguous() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + 
attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scale,
+            is_causal=self.is_causal,
+            **kwargs,
+        )
+
+        attn_output = attn_output.permute(1, 0, 2, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
+        attn_output = attn_output.view(seq_length, batch_size, -1)  # [seq_length, batch_size, embedding_dim]
+        attn_output = self.out_proj(attn_output)
+        attn_output = attn_output.permute(1, 0, 2).contiguous()  # [batch_size, seq_length, embedding_dim]
+        return attn_output, attn_weights
+
+
+class MLCDEncoderLayer(CLIPEncoderLayer):
+    def __init__(self, config: MLCDVisionConfig):
+        super().__init__(config)
+        self.self_attn = MLCDAttention(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(batch, seq_len, embed_dim)`.
+                Represents the hidden states from the previous layer or the input embeddings.
+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+                A tuple `(cos, sin)` of rotary position embedding components, each of shape `(seq_len, head_dim)`.
+                These are applied to the query and key states in the attention mechanism.
+            attention_mask (`torch.FloatTensor`):
+                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class MLCDEncoder(CLIPEncoder):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+    [`MLCDEncoderLayer`].
+
+    Args:
+        config: MLCDVisionConfig
+    """
+
+    def __init__(self, config: MLCDVisionConfig):
+        """Overwrite the dummy `MLCDConfig` with `MLCDVisionConfig`."""
+        super().__init__(config)
+
+    def forward(
+        self,
+        inputs_embeds: torch.FloatTensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+                A tuple `(cos, sin)` of rotary position embedding components, each of shape `(seq_len, head_dim)`.
+                These are applied to the query and key states in the attention mechanism.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            layer_outputs = encoder_layer(
+                hidden_states=hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+        )
+
+
+class MLCDVisionTransformer(CLIPVisionTransformer):
+    def __init__(self, config: MLCDVisionConfig):
+        super().__init__(config)
+        self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
+        self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        num_patches_height = pixel_values.shape[-2] // self.config.patch_size
+        num_patches_width = pixel_values.shape[-1] // self.config.patch_size
+        rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
+        rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
+        rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            position_embeddings=position_embeddings,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring
+class MLCDPreTrainedModel(PreTrainedModel):
+    config: MLCDVisionConfig
+    base_model_prefix = "mlcd"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        factor = self.config.initializer_factor
+        if isinstance(module, MLCDVisionEmbeddings):
+            factor = self.config.initializer_factor
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+        elif isinstance(module, MLCDAttention):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
+            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+        elif isinstance(module, MLCDMLP):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+            nn.init.normal_(module.fc1.weight, std=fc_std)
+            nn.init.normal_(module.fc2.weight, std=in_proj_std)
+        elif isinstance(module, MLCDVisionTransformer):
+            factor = self.config.initializer_factor
+            pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
+            nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class MLCDVisionModel(CLIPVisionModel):
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        Example:
+
+        ```python
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from transformers import AutoProcessor, MLCDVisionModel
+        >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+        >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> with
torch.no_grad(): + ... outputs = model(**inputs, output_attentions=True) + + >>> features = outputs.last_hidden_state + >>> print(f"Extracted features shape: {features.shape}") + >>> print(f"Number of attention layers: {len(outputs.attentions)}") + >>> print(f"Attention shape: {outputs.attentions[0].shape}") + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +__all__ = [ + "MLCDVisionConfig", + "MLCDPreTrainedModel", + "MLCDVisionModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mluke/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mluke/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..521db20088e161bcfed5ca6ad016ebc5d50e18d7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mluke/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_mluke import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mluke/tokenization_mluke.py b/venv/lib/python3.13/site-packages/transformers/models/mluke/tokenization_mluke.py new file mode 100644 index 0000000000000000000000000000000000000000..15f4db53287a7c64d8b156578c63abd9b5feba3e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mluke/tokenization_mluke.py @@ -0,0 +1,1641 @@ +# coding=utf-8 +# Copyright 2021 Studio Ousia and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for mLUKE.""" + +import itertools +import json +import os +from collections.abc import Mapping +from shutil import copyfile +from typing import Any, Optional, Union + +import numpy as np +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + PaddingStrategy, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, + to_py_obj, +) +from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +EntitySpan = tuple[int, int] +EntitySpanInput = list[EntitySpan] +Entity = str +EntityInput = list[Entity] + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "entity_vocab_file": "entity_vocab.json"} + + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **entity_ids** -- List of entity ids to be fed to a model. 
+
+              [What are input IDs?](../glossary#input-ids)
+
+            - **entity_position_ids** -- List of entity positions in the input sequence to be fed to a model.
+
+            - **entity_token_type_ids** -- List of entity token type ids to be fed to a model (when
+              `return_token_type_ids=True` or if *"entity_token_type_ids"* is in `self.model_input_names`).
+
+              [What are token type IDs?](../glossary#token-type-ids)
+
+            - **entity_attention_mask** -- List of indices specifying which entities should be attended to by the model
+              (when `return_attention_mask=True` or if *"entity_attention_mask"* is in `self.model_input_names`).
+
+              [What are attention masks?](../glossary#attention-mask)
+
+            - **entity_start_positions** -- List of the start positions of entities in the word token sequence (when
+              `task="entity_span_classification"`).
+            - **entity_end_positions** -- List of the end positions of entities in the word token sequence (when
+              `task="entity_span_classification"`).
+            - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+              regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+            - **length** -- The length of the inputs (when `return_length=True`)
+
+"""
+
+
+@requires(backends=("sentencepiece",))
+class MLukeTokenizer(PreTrainedTokenizer):
+    """
+    Adapted from [`XLMRobertaTokenizer`] and [`LukeTokenizer`]. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        entity_vocab_file (`str`):
+            Path to the entity vocabulary file.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        task (`str`, *optional*):
+            Task for which you want to prepare sequences. One of `"entity_classification"`,
+            `"entity_pair_classification"`, or `"entity_span_classification"`. If you specify this argument, the entity
+            sequence is automatically created based on the given entity span(s).
+        max_entity_length (`int`, *optional*, defaults to 32):
+            The maximum length of `entity_ids`.
+        max_mention_length (`int`, *optional*, defaults to 30):
+            The maximum number of tokens inside an entity span.
+        entity_token_1 (`str`, *optional*, defaults to `<ent>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            `task` is set to `"entity_classification"` or `"entity_pair_classification"`.
+        entity_token_2 (`str`, *optional*, defaults to `<ent2>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            `task` is set to `"entity_pair_classification"`.
+        additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+            Additional special tokens used by the tokenizer.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        entity_vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        task=None,
+        max_entity_length=32,
+        max_mention_length=30,
+        entity_token_1="<ent>",
+        entity_token_2="<ent2>",
+        entity_unk_token="[UNK]",
+        entity_pad_token="[PAD]",
+        entity_mask_token="[MASK]",
+        entity_mask2_token="[MASK2]",
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        # we add 2 special tokens for downstream tasks
+        # for more information about lstrip and rstrip, see https://github.com/huggingface/transformers/pull/2778
+        entity_token_1 = (
+            AddedToken(entity_token_1, lstrip=False, rstrip=False)
+            if isinstance(entity_token_1, str)
+            else entity_token_1
+        )
+        entity_token_2 = (
+            AddedToken(entity_token_2, lstrip=False, rstrip=False)
+            if isinstance(entity_token_2, str)
+            else entity_token_2
+        )
+        additional_special_tokens = kwargs.pop("additional_special_tokens", [])
+        additional_special_tokens += [entity_token_1, entity_token_2]
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
+        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+
+        with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle:
+            self.entity_vocab = json.load(entity_vocab_handle)
+        for entity_special_token in [entity_unk_token, entity_pad_token, entity_mask_token, entity_mask2_token]:
+            if entity_special_token not in self.entity_vocab:
+                raise ValueError(
+                    f"Specified entity special token ``{entity_special_token}`` is not found in entity_vocab. "
+                    f"Probably an incorrect entity vocab file is loaded: {entity_vocab_file}."
+                )
+        self.entity_unk_token_id = self.entity_vocab[entity_unk_token]
+        self.entity_pad_token_id = self.entity_vocab[entity_pad_token]
+        self.entity_mask_token_id = self.entity_vocab[entity_mask_token]
+        self.entity_mask2_token_id = self.entity_vocab[entity_mask2_token]
+
+        self.task = task
+        if task is None or task == "entity_span_classification":
+            self.max_entity_length = max_entity_length
+        elif task == "entity_classification":
+            self.max_entity_length = 1
+        elif task == "entity_pair_classification":
+            self.max_entity_length = 2
+        else:
+            raise ValueError(
+                f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification',"
+                " 'entity_span_classification'] only."
+            )
+
+        self.max_mention_length = max_mention_length
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            task=task,
+            max_entity_length=max_entity_length,
+            max_mention_length=max_mention_length,
+            entity_token_1=entity_token_1,
+            entity_token_2=entity_token_2,
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
+
+    @property
+    # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.vocab_size
+    def vocab_size(self):
+        return len(self.sp_model) + self.fairseq_offset + 1  # Add the <mask> token
+
+    # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.get_vocab
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._tokenize
+    def _tokenize(self, text: str) -> list[str]:
+        # TODO check if the t5/llama PR also applies here
+        return self.sp_model.encode(text, out_type=str)
+
+    # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        if token in self.fairseq_tokens_to_ids:
+            return self.fairseq_tokens_to_ids[token]
+        spm_id = self.sp_model.PieceToId(token)
+
+        # Need to return unknown token if the SP model returned 0
+        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        if index in self.fairseq_ids_to_tokens:
+            return self.fairseq_ids_to_tokens[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) into a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.__call__
+    def __call__(
+        self,
+        text: Union[TextInput, list[TextInput]],
+        text_pair: Optional[Union[TextInput, list[TextInput]]] = None,
+        entity_spans: Optional[Union[EntitySpanInput, list[EntitySpanInput]]] = None,
+        entity_spans_pair: Optional[Union[EntitySpanInput, list[EntitySpanInput]]] = None,
+        entities: Optional[Union[EntityInput, list[EntityInput]]] = None,
+        entities_pair: Optional[Union[EntityInput, list[EntityInput]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] =
None, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences, depending on the task you want to prepare them for. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + text_pair (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + entity_spans (`list[tuple[int, int]]`, `list[list[tuple[int, int]]]`, *optional*): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify + `"entity_classification"` or `"entity_pair_classification"` as the `task` argument in the constructor, + the length of each sequence must be 1 or 2, respectively. If you specify `entities`, the length of each + sequence must be equal to the length of each sequence of `entities`. + entity_spans_pair (`list[tuple[int, int]]`, `list[list[tuple[int, int]]]`, *optional*): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify the + `task` argument in the constructor, this argument is ignored. If you specify `entities_pair`, the + length of each sequence must be equal to the length of each sequence of `entities_pair`. + entities (`list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the `task` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of `entity_spans`. If you specify + `entity_spans` without specifying this argument, the entity sequence or the batch of entity sequences + is automatically constructed by filling it with the [MASK] entity. + entities_pair (`list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the `task` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of `entity_spans_pair`. 
If you specify + `entity_spans_pair` without specifying this argument, the entity sequence or the batch of entity + sequences is automatically constructed by filling it with the [MASK] entity. + max_entity_length (`int`, *optional*): + The maximum length of `entity_ids`. + """ + # Input type checking for clearer error + is_valid_single_text = isinstance(text, str) + is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + if not (is_valid_single_text or is_valid_batch_text): + raise ValueError("text input must be of type `str` (single example) or `list[str]` (batch).") + + is_valid_single_text_pair = isinstance(text_pair, str) + is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( + len(text_pair) == 0 or isinstance(text_pair[0], str) + ) + if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair): + raise ValueError("text_pair input must be of type `str` (single example) or `list[str]` (batch).") + + is_batched = bool(isinstance(text, (list, tuple))) + + if is_batched: + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + if entities is None: + batch_entities_or_entities_pairs = None + else: + batch_entities_or_entities_pairs = ( + list(zip(entities, entities_pair)) if entities_pair is not None else entities + ) + + if entity_spans is None: + batch_entity_spans_or_entity_spans_pairs = None + else: + batch_entity_spans_or_entity_spans_pairs = ( + list(zip(entity_spans, entity_spans_pair)) if entity_spans_pair is not None else entity_spans + ) + + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + entities=entities, + entities_pair=entities_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._encode_plus + def _encode_plus( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entity_spans: Optional[EntitySpanInput] = None, + 
entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + + # prepare_for_model will create the attention_mask and token_type_ids + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._batch_encode_plus + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[list[TextInput], list[TextInputPair]], + batch_entity_spans_or_entity_spans_pairs: Optional[ + Union[list[EntitySpanInput], list[tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[list[EntityInput], list[tuple[EntityInput, EntityInput]]] + ] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] 
= None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + # input_ids is a list of tuples (one for each example in the batch) + input_ids = [] + entity_ids = [] + entity_token_spans = [] + for index, text_or_text_pair in enumerate(batch_text_or_text_pairs): + if not isinstance(text_or_text_pair, (list, tuple)): + text, text_pair = text_or_text_pair, None + else: + text, text_pair = text_or_text_pair + + entities, entities_pair = None, None + if batch_entities_or_entities_pairs is not None: + entities_or_entities_pairs = batch_entities_or_entities_pairs[index] + if entities_or_entities_pairs: + if isinstance(entities_or_entities_pairs[0], str): + entities, entities_pair = entities_or_entities_pairs, None + else: + entities, entities_pair = entities_or_entities_pairs + + entity_spans, entity_spans_pair = None, None + if batch_entity_spans_or_entity_spans_pairs is not None: + entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index] + if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance( + entity_spans_or_entity_spans_pairs[0], list + ): + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs + else: + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + input_ids.append((first_ids, second_ids)) + entity_ids.append((first_entity_ids, second_entity_ids)) + entity_token_spans.append((first_entity_token_spans, second_entity_token_spans)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + batch_entity_ids_pairs=entity_ids, + batch_entity_token_spans_pairs=entity_token_spans, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._check_entity_input_format + def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]): + if not isinstance(entity_spans, list): + raise TypeError("entity_spans should be given as a list") + elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple): + raise 
ValueError( + "entity_spans should be given as a list of tuples containing the start and end character indices" + ) + + if entities is not None: + if not isinstance(entities, list): + raise ValueError("If you specify entities, they should be given as a list") + + if len(entities) > 0 and not isinstance(entities[0], str): + raise ValueError("If you specify entities, they should be given as a list of entity names") + + if len(entities) != len(entity_spans): + raise ValueError("If you specify entities, entities and entity_spans must be the same length") + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._create_input_sequence + def _create_input_sequence( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + **kwargs, + ) -> tuple[list, list, list, list, list, list]: + def get_input_ids(text): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + + def get_input_ids_and_entity_token_spans(text, entity_spans): + if entity_spans is None: + return get_input_ids(text), None + + cur = 0 + input_ids = [] + entity_token_spans = [None] * len(entity_spans) + + split_char_positions = sorted(frozenset(itertools.chain(*entity_spans))) + char_pos2token_pos = {} + + for split_char_position in split_char_positions: + orig_split_char_position = split_char_position + if ( + split_char_position > 0 and text[split_char_position - 1] == " " + ): # whitespace should be prepended to the following token + split_char_position -= 1 + if cur != split_char_position: + input_ids += get_input_ids(text[cur:split_char_position]) + cur = split_char_position + char_pos2token_pos[orig_split_char_position] = len(input_ids) + + input_ids += get_input_ids(text[cur:]) + + entity_token_spans = [ + (char_pos2token_pos[char_start], char_pos2token_pos[char_end]) for char_start, char_end in entity_spans + ] + + return input_ids, entity_token_spans + + first_ids, second_ids = None, None + first_entity_ids, second_entity_ids = None, None + first_entity_token_spans, second_entity_token_spans = None, None + + if self.task is None: + if entity_spans is None: + first_ids = get_input_ids(text) + else: + self._check_entity_input_format(entities, entity_spans) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + if entities is None: + first_entity_ids = [self.entity_mask_token_id] * len(entity_spans) + else: + first_entity_ids = [self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities] + + if text_pair is not None: + if entity_spans_pair is None: + second_ids = get_input_ids(text_pair) + else: + self._check_entity_input_format(entities_pair, entity_spans_pair) + + second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans( + text_pair, entity_spans_pair + ) + if entities_pair is None: + second_entity_ids = [self.entity_mask_token_id] * len(entity_spans_pair) + else: + second_entity_ids = [ + self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities_pair + ] + + elif self.task == "entity_classification": + if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be a list containing a single tuple " + "containing the start and end character indices of an entity" + ) + 
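+            # For "entity_classification" the single target entity is represented by the entity
+            # [MASK] token, and its token span is wrapped in-place with the first additional special
+            # token (<ent>): inserting at the span end first keeps the start index valid, and the
+            # stored span grows by two to cover both inserted markers.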
first_entity_ids = [self.entity_mask_token_id] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + # add special tokens to input ids + entity_token_start, entity_token_end = first_entity_token_spans[0] + first_ids = ( + first_ids[:entity_token_end] + [self.additional_special_tokens_ids[0]] + first_ids[entity_token_end:] + ) + first_ids = ( + first_ids[:entity_token_start] + + [self.additional_special_tokens_ids[0]] + + first_ids[entity_token_start:] + ) + first_entity_token_spans = [(entity_token_start, entity_token_end + 2)] + + elif self.task == "entity_pair_classification": + if not ( + isinstance(entity_spans, list) + and len(entity_spans) == 2 + and isinstance(entity_spans[0], tuple) + and isinstance(entity_spans[1], tuple) + ): + raise ValueError( + "Entity spans should be provided as a list of two tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + head_span, tail_span = entity_spans + first_entity_ids = [self.entity_mask_token_id, self.entity_mask2_token_id] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + head_token_span, tail_token_span = first_entity_token_spans + token_span_with_special_token_ids = [ + (head_token_span, self.additional_special_tokens_ids[0]), + (tail_token_span, self.additional_special_tokens_ids[1]), + ] + if head_token_span[0] < tail_token_span[0]: + first_entity_token_spans[0] = (head_token_span[0], head_token_span[1] + 2) + first_entity_token_spans[1] = (tail_token_span[0] + 2, tail_token_span[1] + 4) + token_span_with_special_token_ids = reversed(token_span_with_special_token_ids) + else: + first_entity_token_spans[0] = (head_token_span[0] + 2, head_token_span[1] + 4) + first_entity_token_spans[1] = (tail_token_span[0], tail_token_span[1] + 2) + + for (entity_token_start, entity_token_end), special_token_id in token_span_with_special_token_ids: + first_ids = first_ids[:entity_token_end] + [special_token_id] + first_ids[entity_token_end:] + first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:] + + elif self.task == "entity_span_classification": + if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be provided as a list of tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + first_entity_ids = [self.entity_mask_token_id] * len(entity_spans) + + else: + raise ValueError(f"Task {self.task} not supported") + + return ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._batch_prepare_for_model + def _batch_prepare_for_model( + self, + batch_ids_pairs: list[tuple[list[int], None]], + batch_entity_ids_pairs: list[tuple[Optional[list[int]], Optional[list[int]]]], + batch_entity_token_spans_pairs: list[tuple[Optional[list[tuple[int, int]]], Optional[list[tuple[int, int]]]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: 
Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[str] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
+        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+        manages a moving window (with user-defined stride) for overflowing tokens.
+
+
+        Args:
+            batch_ids_pairs: list of tokenized input ids or input ids pairs
+            batch_entity_ids_pairs: list of entity ids or entity ids pairs
+            batch_entity_token_spans_pairs: list of entity spans or entity spans pairs
+            max_entity_length: The maximum length of the entity sequence.
+        """
+
+        batch_outputs = {}
+        for input_ids, entity_ids, entity_token_span_pairs in zip(
+            batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs
+        ):
+            first_ids, second_ids = input_ids
+            first_entity_ids, second_entity_ids = entity_ids
+            first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs
+            outputs = self.prepare_for_model(
+                first_ids,
+                second_ids,
+                entity_ids=first_entity_ids,
+                pair_entity_ids=second_entity_ids,
+                entity_token_spans=first_entity_token_spans,
+                pair_entity_token_spans=second_entity_token_spans,
+                add_special_tokens=add_special_tokens,
+                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
+                truncation=truncation_strategy.value,
+                max_length=max_length,
+                max_entity_length=max_entity_length,
+                stride=stride,
+                pad_to_multiple_of=None,  # we pad in batch afterward
+                padding_side=None,  # we pad in batch afterward
+                return_attention_mask=False,  # we pad in batch afterward
+                return_token_type_ids=return_token_type_ids,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length,
+                return_tensors=None,  # We convert the whole batch to tensors at the end
+                prepend_batch_axis=False,
+                verbose=verbose,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                batch_outputs[key].append(value)
+
+        batch_outputs = self.pad(
+            batch_outputs,
+            padding=padding_strategy.value,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_attention_mask=return_attention_mask,
+        )
+
+        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+        return batch_outputs
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.prepare_for_model
+    def prepare_for_model(
+        self,
+        ids: list[int],
+        pair_ids: Optional[list[int]] = None,
+        entity_ids: Optional[list[int]] = None,
+        pair_entity_ids: Optional[list[int]] = None,
+        entity_token_spans: Optional[list[tuple[int, int]]] = None,
+        pair_entity_token_spans: Optional[list[tuple[int, int]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        max_entity_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: 
Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids, entity ids and entity spans, or a pair of sequences of input ids, entity ids
+        and entity spans, so that it can be used by the model. It adds special tokens, truncates sequences if
+        overflowing while taking into account the special tokens and manages a moving window (with user-defined
+        stride) for overflowing tokens. Please note that for *pair_ids* different from `None` and *truncation_strategy
+        = longest_first* or `True`, it is not possible to return overflowing tokens. Such a combination of arguments
+        will raise an error.
+
+        Args:
+            ids (`list[int]`):
+                Tokenized input ids of the first sequence.
+            pair_ids (`list[int]`, *optional*):
+                Tokenized input ids of the second sequence.
+            entity_ids (`list[int]`, *optional*):
+                Entity ids of the first sequence.
+            pair_entity_ids (`list[int]`, *optional*):
+                Entity ids of the second sequence.
+            entity_token_spans (`list[tuple[int, int]]`, *optional*):
+                Entity spans of the first sequence.
+            pair_entity_token_spans (`list[tuple[int, int]]`, *optional*):
+                Entity spans of the second sequence.
+            max_entity_length (`int`, *optional*):
+                The maximum length of the entity sequence.
+        """
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        # Compute lengths
+        pair = bool(pair_ids is not None)
+        len_ids = len(ids)
+        len_pair_ids = len(pair_ids) if pair else 0
+
+        if return_token_type_ids and not add_special_tokens:
+            raise ValueError(
+                "Asking to return token_type_ids while setting add_special_tokens to False "
+                "results in an undefined behavior. Please set add_special_tokens to True or "
+                "set return_token_type_ids to None."
+            )
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )
+
+        # Load from model defaults
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Compute the total size of the returned word encodings
+        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+        # Truncation: Handle max sequence length and max_entity_length
+        overflowing_tokens = []
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+            # truncate words up to max_length
+            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                ids,
+                pair_ids=pair_ids,
+                num_tokens_to_remove=total_len - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+
+        if return_overflowing_tokens:
+            encoded_inputs["overflowing_tokens"] = overflowing_tokens
+            encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+        # Add special tokens
+        if add_special_tokens:
+            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+            entity_token_offset = 1  # 1 * <s> token
+            pair_entity_token_offset = len(ids) + 3  # 1 * <s> token & 2 * </s> tokens
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+            entity_token_offset = 0
+            pair_entity_token_offset = len(ids)
+
+        # Build output dictionary
+        encoded_inputs["input_ids"] = sequence
+        if return_token_type_ids:
+            encoded_inputs["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+        # Set max entity length
+        if not max_entity_length:
+            max_entity_length = self.max_entity_length
+
+        if entity_ids is not None:
+            total_entity_len = 0
+            num_invalid_entities = 0
+            valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)]
+            valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)]
+
+            total_entity_len += len(valid_entity_ids)
+            num_invalid_entities += len(entity_ids) - len(valid_entity_ids)
+
+            valid_pair_entity_ids, valid_pair_entity_token_spans = None, None
+            if pair_entity_ids is not None:
+                valid_pair_entity_ids = [
+                    ent_id
+                    for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans)
+                    if span[1] <= len(pair_ids)
+                ]
+                valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)]
+                total_entity_len += len(valid_pair_entity_ids)
+                num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids)
+
+            if num_invalid_entities != 0:
+                logger.warning(
+                    f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+                    " truncation of input tokens"
+                )
+
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
+                # truncate entities up to max_entity_length
+                valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences(
+                    valid_entity_ids,
+                    pair_ids=valid_pair_entity_ids,
+                    num_tokens_to_remove=total_entity_len - max_entity_length,
+                    truncation_strategy=truncation_strategy,
+                    stride=stride,
+                )
+                valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)]
+                if 
valid_pair_entity_token_spans is not None: + valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)] + + if return_overflowing_tokens: + encoded_inputs["overflowing_entities"] = overflowing_entities + encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length + + final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids + encoded_inputs["entity_ids"] = list(final_entity_ids) + entity_position_ids = [] + entity_start_positions = [] + entity_end_positions = [] + for token_spans, offset in ( + (valid_entity_token_spans, entity_token_offset), + (valid_pair_entity_token_spans, pair_entity_token_offset), + ): + if token_spans is not None: + for start, end in token_spans: + start += offset + end += offset + position_ids = list(range(start, end))[: self.max_mention_length] + position_ids += [-1] * (self.max_mention_length - end + start) + entity_position_ids.append(position_ids) + entity_start_positions.append(start) + entity_end_positions.append(end - 1) + + encoded_inputs["entity_position_ids"] = entity_position_ids + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = entity_start_positions + encoded_inputs["entity_end_positions"] = entity_end_positions + + if return_token_type_ids: + encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.pad + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + list[BatchEncoding], + dict[str, EncodedInput], + dict[str, list[EncodedInput]], + list[dict[str, EncodedInput]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with + `self.padding_side`, `self.pad_token_id` and `self.pad_token_type_id`) .. note:: If the `encoded_inputs` passed + are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless + you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the + specific device of your tensors however. + + Args: + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`): + Tokenized inputs. 
Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of
+                tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str,
+                list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+                collate function. Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or
+                TensorFlow tensors), see the note above for the return type.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            max_entity_length (`int`, *optional*):
+                The maximum length of the entity sequence.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute. [What are attention
+                masks?](../glossary#attention-mask)
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
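+
+        Example (an illustrative sketch added for this write-up, not part of the upstream docstring; it assumes the
+        published `studio-ousia/mluke-base` checkpoint is available):
+
+        ```python
+        >>> from transformers import MLukeTokenizer
+
+        >>> tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
+        >>> encodings = [
+        ...     tokenizer("Tokyo is the capital of Japan", entity_spans=[(0, 5)]),
+        ...     tokenizer("Paris", entity_spans=[(0, 5)]),
+        ... ]
+        >>> # Entity-aware keys such as `entity_ids` are padded together with `input_ids`,
+        >>> # so the result can be used directly in a Dataloader collate function.
+        >>> batch = tokenizer.pad(encodings, padding=True, return_tensors="pt")
+        ```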
+ """ + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]} + + # The model's main input name, usually `input_ids`, has be passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + if not required_input: + if return_attention_mask: + encoded_inputs["attention_mask"] = [] + return encoded_inputs + + # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + index = 0 + while len(required_input[index]) == 0: + index += 1 + if index < len(required_input): + first_element = required_input[index][0] + # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. + if not isinstance(first_element, (int, list, tuple)): + if is_tf_tensor(first_element): + return_tensors = "tf" if return_tensors is None else return_tensors + elif is_torch_tensor(first_element): + return_tensors = "pt" if return_tensors is None else return_tensors + elif isinstance(first_element, np.ndarray): + return_tensors = "np" if return_tensors is None else return_tensors + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." 
+ ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + if max_entity_length is None: + max_entity_length = self.max_entity_length + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + if any(len(v) != batch_size for v in encoded_inputs.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + max_entity_length = ( + max(len(inputs) for inputs in encoded_inputs["entity_ids"]) if "entity_ids" in encoded_inputs else 0 + ) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = {k: v[i] for k, v in encoded_inputs.items()} + outputs = self._pad( + inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._pad + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + max_entity_length: The maximum length of the entity sequence. + padding_strategy: PaddingStrategy to use for padding. + + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. 
+ Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + entities_provided = bool("entity_ids" in encoded_inputs) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + if entities_provided: + max_entity_length = len(encoded_inputs["entity_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if ( + entities_provided + and max_entity_length is not None + and pad_to_multiple_of is not None + and (max_entity_length % pad_to_multiple_of != 0) + ): + max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ( + len(encoded_inputs["input_ids"]) != max_length + or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length) + ) + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + padding_side = padding_side if padding_side is not None else self.padding_side + if entities_provided: + entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if entities_provided: + encoded_inputs["entity_attention_mask"] = ( + encoded_inputs["entity_attention_mask"] + [0] * entity_difference + ) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference + if entities_provided: + encoded_inputs["entity_token_type_ids"] = ( + encoded_inputs["entity_token_type_ids"] + [0] * entity_difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + if entities_provided: + encoded_inputs["entity_ids"] = ( + encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference + ) + encoded_inputs["entity_position_ids"] = ( + encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference + ) + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = ( + encoded_inputs["entity_start_positions"] + [0] * entity_difference + ) + encoded_inputs["entity_end_positions"] = ( + encoded_inputs["entity_end_positions"] + [0] * entity_difference + ) + + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if entities_provided: + encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[ + "entity_attention_mask" + ] + if "token_type_ids" in 
encoded_inputs:
+                    encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"]
+                    if entities_provided:
+                        encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[
+                            "entity_token_type_ids"
+                        ]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
+                if entities_provided:
+                    encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
+                        "entity_ids"
+                    ]
+                    encoded_inputs["entity_position_ids"] = [
+                        [-1] * self.max_mention_length
+                    ] * entity_difference + encoded_inputs["entity_position_ids"]
+                    if self.task == "entity_span_classification":
+                        encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_start_positions"
+                        ]
+                        encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_end_positions"
+                        ]
+            else:
+                raise ValueError("Invalid padding side: " + str(padding_side))
+
+        return encoded_inputs
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str, str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        entity_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
+        )
+
+        with open(entity_vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        return out_vocab_file, entity_vocab_file
+
+    # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. An XLM-RoBERTa sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
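+
+        Example (an illustrative sketch added for this write-up; the ids are arbitrary placeholders):
+
+        ```python
+        >>> from transformers import MLukeTokenizer
+
+        >>> tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
+        >>> # single sequence: <s> X </s>
+        >>> single = tokenizer.build_inputs_with_special_tokens([10, 11])
+        >>> # pair of sequences: <s> A </s></s> B </s>
+        >>> pair = tokenizer.build_inputs_with_special_tokens([10, 11], [12, 13])
+        >>> len(pair) - len([10, 11]) - len([12, 13])  # four special tokens are added for a pair
+        4
+        ```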
+ """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["MLukeTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72257c2ad40072337696a98ad8ada72dcf05b539 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_mm_grounding_dino import *
+    from .modeling_mm_grounding_dino import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py
new file mode 100644
index 0000000000000000000000000000000000000000..e49ccde7e2d79427a38b9907ab12016ef4c50814
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py
@@ -0,0 +1,303 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_mm_grounding_dino.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class MMGroundingDinoConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MMGroundingDinoModel`]. It is used to instantiate a
+    MM Grounding DINO model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the MM Grounding DINO tiny architecture
+    [openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. If neither this nor `backbone` is set, a small Swin Transformer
+            configuration is used by default (see `__init__` below).
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `BertConfig`):
+            The config object or dictionary of the text backbone.
+        num_queries (`int`, *optional*, defaults to 900):
+            Number of object queries, i.e. detection slots. This is the maximal number of objects
+            [`MMGroundingDinoModel`] can detect in a single image.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of encoder layers.
+        encoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        auxiliary_loss (`bool`, *optional*, defaults to `False`):
+            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+        position_embedding_type (`str`, *optional*, defaults to `"sine"`):
+            Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
+        num_feature_levels (`int`, *optional*, defaults to 4):
+            The number of input feature levels.
+        encoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the encoder.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        two_stage (`bool`, *optional*, defaults to `True`):
+            Whether to apply a two-stage deformable DETR, where the region proposals are also generated by a variant of
+            Grounding DINO, which are further fed into the decoder for iterative bounding box refinement.
+        class_cost (`float`, *optional*, defaults to 1.0):
+            Relative weight of the classification error in the Hungarian matching cost.
+ bbox_cost (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + bbox_loss_coefficient (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + disable_custom_kernels (`bool`, *optional*, defaults to `False`): + Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom + kernels are not supported by PyTorch ONNX export. + max_text_len (`int`, *optional*, defaults to 256): + The maximum length of the text input. + text_enhancer_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the text enhancer. + fusion_droppath (`float`, *optional*, defaults to 0.1): + The droppath ratio for the fusion module. + fusion_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the fusion module. + embedding_init_target (`bool`, *optional*, defaults to `True`): + Whether to initialize the target with Embedding weights. + query_dim (`int`, *optional*, defaults to 4): + The dimension of the query vector. + positional_embedding_temperature (`float`, *optional*, defaults to 20): + The temperature for Sine Positional Embedding that is used together with vision backbone. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. 
+ + Examples: + + ```python + >>> from transformers import MMGroundingDinoConfig, MMGroundingDinoModel + + >>> # Initializing a MM Grounding DINO configuration + >>> configuration = MMGroundingDinoConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MMGroundingDinoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mm-grounding-dino" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + text_config=None, + num_queries=900, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + auxiliary_loss=False, + position_embedding_type="sine", + num_feature_levels=4, + encoder_n_points=4, + decoder_n_points=4, + two_stage=True, + class_cost=1.0, + bbox_cost=5.0, + giou_cost=2.0, + bbox_loss_coefficient=5.0, + giou_loss_coefficient=2.0, + focal_alpha=0.25, + disable_custom_kernels=False, + # other parameters + max_text_len=256, + text_enhancer_dropout=0.0, + fusion_droppath=0.1, + fusion_dropout=0.0, + embedding_init_target=True, + query_dim=4, + positional_embedding_temperature=20, + init_std=0.02, + layer_norm_eps=1e-5, + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.") + backbone_config = CONFIG_MAPPING["swin"]( + window_size=7, + image_size=224, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + if text_config is None: + text_config = {} + logger.info("text_config is None. 
Initializing the text config with default values (`BertConfig`).") + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.num_queries = num_queries + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + # deformable attributes + self.num_feature_levels = num_feature_levels + self.encoder_n_points = encoder_n_points + self.decoder_n_points = decoder_n_points + self.two_stage = two_stage + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.focal_alpha = focal_alpha + self.disable_custom_kernels = disable_custom_kernels + # Text backbone + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "bert") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["bert"]() + + self.text_config = text_config + self.max_text_len = max_text_len + + # Text Enhancer + self.text_enhancer_dropout = text_enhancer_dropout + # Fusion + self.fusion_droppath = fusion_droppath + self.fusion_dropout = fusion_dropout + # Others + self.embedding_init_target = embedding_init_target + self.query_dim = query_dim + self.positional_embedding_temperature = positional_embedding_temperature + self.init_std = init_std + self.layer_norm_eps = layer_norm_eps + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @property + def sub_configs(self): + sub_configs = {} + backbone_config = getattr(self, "backbone_config", None) + text_config = getattr(self, "text_config", None) + if isinstance(backbone_config, PretrainedConfig): + sub_configs["backbone_config"] = type(backbone_config) + if isinstance(text_config, PretrainedConfig): + sub_configs["text_config"] = type(self.text_config) + return sub_configs + + +__all__ = ["MMGroundingDinoConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d498323de4dc3b60dcf8c3c8ac02f25d8c3ed1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -0,0 +1,2605 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_mm_grounding_dino.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...file_utils import ModelOutput, is_timm_available, requires_backends +from ...integrations import use_kernel_forward_from_hub +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import meshgrid +from ...utils import auto_docstring +from ...utils.backbone_utils import load_backbone +from ..auto.modeling_auto import AutoModel +from .configuration_mm_grounding_dino import MMGroundingDinoConfig + + +if is_timm_available(): + from timm import create_model + + +class MMGroundingDinoContrastiveEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.max_text_len = config.max_text_len + self.bias = nn.Parameter(torch.tensor(0.0)) + + def forward( + self, + vision_hidden_state: torch.FloatTensor, + text_hidden_state: torch.FloatTensor, + text_token_mask: torch.BoolTensor, + ) -> torch.FloatTensor: + res = vision_hidden_state @ text_hidden_state.transpose(-1, -2) + res = res / math.sqrt(vision_hidden_state.shape[-1]) + res = res + self.bias + res.masked_fill_(~text_token_mask[:, None, :], float("-inf")) + + # padding to max_text_len + new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device) + new_res[..., : res.shape[-1]] = res + + return new_res + + +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: list[tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes_list): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, 
level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class MMGroundingDinoLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, config): + super().__init__() + + embedding_dim = config.d_model // 2 + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +class MMGroundingDinoMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, config: MMGroundingDinoConfig, num_heads: int, n_points: int): + super().__init__() + + self.attn = MultiScaleDeformableAttention() + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in MMGroundingDinoMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + # Ignore copy + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + output = self.output_proj(output) + + return output, attention_weights + + +class MMGroundingDinoBiMultiHeadAttention(nn.Module): + def __init__(self, config): + super().__init__() + + vision_dim = text_dim = config.d_model + embed_dim = config.encoder_ffn_dim // 2 + num_heads = config.encoder_attention_heads // 2 + dropout = config.fusion_dropout + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = 
embed_dim // num_heads + self.vision_dim = vision_dim + self.text_dim = text_dim + + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by `num_heads` (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + ) + self.scale = self.head_dim ** (-0.5) + self.dropout = dropout + + self.vision_proj = nn.Linear(self.vision_dim, self.embed_dim) + self.text_proj = nn.Linear(self.text_dim, self.embed_dim) + self.values_vision_proj = nn.Linear(self.vision_dim, self.embed_dim) + self.values_text_proj = nn.Linear(self.text_dim, self.embed_dim) + + self.out_vision_proj = nn.Linear(self.embed_dim, self.vision_dim) + self.out_text_proj = nn.Linear(self.embed_dim, self.text_dim) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + vision_features: torch.FloatTensor, + text_features: torch.FloatTensor, + vision_attention_mask: Optional[torch.BoolTensor] = None, + text_attention_mask: Optional[torch.BoolTensor] = None, + ) -> tuple[tuple[torch.FloatTensor, torch.FloatTensor], tuple[torch.FloatTensor, torch.FloatTensor]]: + """Image-to-text and text-to-image cross-attention + + Args: + vision_features (`torch.FloatTensor` of shape `(batch_size, vision_sequence_length, hidden_dim)`): + Projected flattened image features generated by the vision backbone. + text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_dim)`): + Projected text features generated by the text encoder. + vision_attention_mask (`torch.BoolTensor`, **optional**): + Attention mask for image-to-text cross-attention. False for real tokens and True for padding tokens. + text_attention_mask (`torch.BoolTensor`, **optional**): + Attention mask for text-to-image cross-attention. False for real tokens and True for padding tokens. + + Returns: + `tuple(tuple(torch.FloatTensor), tuple(torch.FloatTensor))` where each inner tuple comprises an attention + output and weights: + - **vision_attn_output** (`torch.FloatTensor` of shape `(batch_size, vision_sequence_length, hidden_din)`) + -- + Output of the image-to-text cross-attention layer. + - **vision_attn_weights** (`torch.FloatTensor` of shape `(batch_size, num_heads, vision_sequence_length, + vision_sequence_length)`) -- + Attention weights of the image-to-text cross-attention layer. + - **text_attn_output** (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_dim)`) -- + Output of the text-to-image cross-attention layer. + - **text_attn_weights** (`torch.FloatTensor` of shape `(batch_size, num_heads, text_sequence_length, + text_sequence_length)`) -- + Attention weights of the text-to-image cross-attention layer. 
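A minimal shape-check sketch for this fusion attention, assuming a default `MMGroundingDinoConfig` and importing the classes directly from the generated files shown in this diff (the attention module itself is not part of the public top-level API):

```python
import torch

from transformers.models.mm_grounding_dino.configuration_mm_grounding_dino import MMGroundingDinoConfig
from transformers.models.mm_grounding_dino.modeling_mm_grounding_dino import MMGroundingDinoBiMultiHeadAttention

config = MMGroundingDinoConfig()  # defaults: d_model=256, encoder_attention_heads=8, encoder_ffn_dim=2048
fusion_attn = MMGroundingDinoBiMultiHeadAttention(config)

vision = torch.randn(2, 100, config.d_model)  # (batch, vision_sequence_length, d_model)
text = torch.randn(2, 16, config.d_model)     # (batch, text_sequence_length, d_model)

# no padding in either modality, so both attention masks are left as None
(vision_out, vision_weights), (text_out, text_weights) = fusion_attn(vision, text)
print(vision_out.shape)  # torch.Size([2, 100, 256])
print(text_out.shape)    # torch.Size([2, 16, 256])
```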
+ """ + batch_size, tgt_len, _ = vision_features.size() + + vision_query_states = self.vision_proj(vision_features) * self.scale + vision_query_states = self._reshape(vision_query_states, tgt_len, batch_size) + + text_key_states = self.text_proj(text_features) + text_key_states = self._reshape(text_key_states, -1, batch_size) + + vision_value_states = self.values_vision_proj(vision_features) + vision_value_states = self._reshape(vision_value_states, -1, batch_size) + + text_value_states = self.values_text_proj(text_features) + text_value_states = self._reshape(text_value_states, -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + + vision_query_states = vision_query_states.view(*proj_shape) + text_key_states = text_key_states.view(*proj_shape) + vision_value_states = vision_value_states.view(*proj_shape) + text_value_states = text_value_states.view(*proj_shape) + + src_len = text_key_states.size(1) + attn_weights = torch.bmm(vision_query_states, text_key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt + + if attn_weights.size() != (batch_size * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + attn_weights = attn_weights - attn_weights.max() + # Do not increase -50000/50000, data type half has quite limited range + attn_weights = torch.clamp(attn_weights, min=-50000, max=50000) + + attn_weights_transposed = attn_weights.transpose(1, 2) + text_attn_weights = attn_weights_transposed - torch.max(attn_weights_transposed, dim=-1, keepdim=True)[0] + + # Do not increase -50000/50000, data type half has quite limited range + text_attn_weights = torch.clamp(text_attn_weights, min=-50000, max=50000) + + # mask vision for language + if vision_attention_mask is not None: + vision_attention_mask = ( + vision_attention_mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) + ) + text_attn_weights.masked_fill_(vision_attention_mask, float("-inf")) + + text_attn_weights = text_attn_weights.softmax(dim=-1) + + # mask language for vision + if text_attention_mask is not None: + text_attention_mask = text_attention_mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) + attn_weights.masked_fill_(text_attention_mask, float("-inf")) + vision_attn_weights = attn_weights.softmax(dim=-1) + + vision_attn_probs = F.dropout(vision_attn_weights, p=self.dropout, training=self.training) + text_attn_probs = F.dropout(text_attn_weights, p=self.dropout, training=self.training) + + vision_attn_output = torch.bmm(vision_attn_probs, text_value_states) + text_attn_output = torch.bmm(text_attn_probs, vision_value_states) + + if vision_attn_output.size() != (batch_size * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`vision_attn_output` should be of size {(batch_size, self.num_heads, tgt_len, self.head_dim)}, but is {vision_attn_output.size()}" + ) + + if text_attn_output.size() != (batch_size * self.num_heads, src_len, self.head_dim): + raise ValueError( + f"`text_attn_output` should be of size {(batch_size, self.num_heads, src_len, self.head_dim)}, but is {text_attn_output.size()}" + ) + + vision_attn_output = vision_attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim) + vision_attn_output = vision_attn_output.transpose(1, 2) + vision_attn_output = vision_attn_output.reshape(batch_size, tgt_len, self.embed_dim) + + text_attn_output = text_attn_output.view(batch_size, self.num_heads, src_len, 
self.head_dim) + text_attn_output = text_attn_output.transpose(1, 2) + text_attn_output = text_attn_output.reshape(batch_size, src_len, self.embed_dim) + + vision_attn_output = self.out_vision_proj(vision_attn_output) + text_attn_output = self.out_text_proj(text_attn_output) + + return (vision_attn_output, vision_attn_weights), (text_attn_output, text_attn_weights) + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class MMGroundingDinoDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class MMGroundingDinoFusionLayer(nn.Module): + def __init__(self, config): + super().__init__() + drop_path = config.fusion_droppath + + # pre layer norm + self.layer_norm_vision = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.layer_norm_text = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.attn = MMGroundingDinoBiMultiHeadAttention(config) + + # add layer scale for training stability + self.drop_path = MMGroundingDinoDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + init_values = 1e-4 + self.vision_param = nn.Parameter(init_values * torch.ones(config.d_model), requires_grad=True) + self.text_param = nn.Parameter(init_values * torch.ones(config.d_model), requires_grad=True) + + def forward( + self, + vision_features: torch.FloatTensor, + text_features: torch.FloatTensor, + attention_mask_vision: Optional[torch.BoolTensor] = None, + attention_mask_text: Optional[torch.BoolTensor] = None, + ) -> tuple[tuple[torch.FloatTensor, torch.FloatTensor], tuple[torch.FloatTensor, torch.FloatTensor]]: + """Image and text features fusion + + Args: + vision_features (`torch.FloatTensor` of shape `(batch_size, vision_sequence_length, hidden_dim)`): + Projected flattened image features generated by the vision backbone. + text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_dim)`): + Projected text features generated by the text encoder. + attention_mask_vision (`torch.BoolTensor`, **optional**): + Attention mask for image-to-text cross-attention. False for real tokens and True for padding tokens. 
+ attention_mask_text (`torch.BoolTensor`, **optional**): + Attention mask for text-to-image cross-attention. False for real tokens and True for padding tokens. + + Returns: + `tuple(tuple(torch.FloatTensor), tuple(torch.FloatTensor))` where each inner tuple comprises an enhanced + feature and attention output and weights: + - **vision_features** (`torch.FloatTensor` of shape `(batch_size, vision_sequence_length, vision_dim)`) -- + Updated vision features with attention output from image-to-text cross-attention layer. + - **vision_attn_weights** (`torch.FloatTensor` of shape `(batch_size, num_heads, vision_sequence_length, + vision_sequence_length)`) -- + Attention weights of the image-to-text cross-attention layer. + - **text_features** (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, text_dim)`) -- + Updated text features with attention output from text-to-image cross-attention layer. + - **text_attn_weights** (`torch.FloatTensor` of shape `(batch_size, num_heads, text_sequence_length, + text_sequence_length)`) -- + Attention weights of the text-to-image cross-attention layer. + """ + vision_features = self.layer_norm_vision(vision_features) + text_features = self.layer_norm_text(text_features) + (delta_v, vision_attn), (delta_t, text_attn) = self.attn( + vision_features, + text_features, + vision_attention_mask=attention_mask_vision, + text_attention_mask=attention_mask_text, + ) + vision_features = vision_features + self.drop_path(self.vision_param * delta_v) + text_features = text_features + self.drop_path(self.text_param * delta_t) + + return (vision_features, vision_attn), (text_features, text_attn) + + +@auto_docstring +class MMGroundingDinoPreTrainedModel(PreTrainedModel): + config: MMGroundingDinoConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module): + std = self.config.init_std + + if isinstance(module, MMGroundingDinoLearnedPositionEmbedding): + nn.init.uniform_(module.row_embeddings.weight) + nn.init.uniform_(module.column_embeddings.weight) + elif isinstance(module, MMGroundingDinoMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + elif isinstance(module, MMGroundingDinoBiMultiHeadAttention): + nn.init.xavier_uniform_(module.vision_proj.weight) + module.vision_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(module.text_proj.weight) + module.text_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(module.values_vision_proj.weight) + module.values_vision_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(module.values_text_proj.weight) + module.values_text_proj.bias.data.fill_(0) + 
nn.init.xavier_uniform_(module.out_vision_proj.weight) + module.out_vision_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(module.out_text_proj.weight) + module.out_text_proj.bias.data.fill_(0) + elif isinstance(module, MMGroundingDinoFusionLayer): + module.vision_param.data.fill_(1e-4) + module.text_param.data.fill_(1e-4) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MMGroundingDinoMLPPredictionHead): + nn.init.constant_(module.layers[-1].weight.data, 0) + nn.init.constant_(module.layers[-1].bias.data, 0) + + if hasattr(module, "reference_points") and not self.config.two_stage: + nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) + nn.init.constant_(module.reference_points.bias.data, 0.0) + if hasattr(module, "level_embed"): + nn.init.normal_(module.level_embed) + if isinstance(module, MMGroundingDinoContrastiveEmbedding): + nn.init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MMGroundingDinoDecoder): + module.gradient_checkpointing = value + + +class MMGroundingDinoFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `MMGroundingDinoFrozenBatchNorm2d`. 
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = MMGroundingDinoFrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class MMGroundingDinoConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by MMGroundingDinoFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + + if config.use_timm_backbone: + requires_backends(self, ["timm"]) + backbone = create_model( + config.backbone, + pretrained=config.use_pretrained_backbone, + features_only=True, + **config.backbone_kwargs, + ) + else: + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = ( + self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels + ) + + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + + if "resnet" in backbone_model_type: + for name, parameter in self.model.named_parameters(): + if config.use_timm_backbone: + if "layer2" not in name and "layer3" not in name and "layer4" not in name: + parameter.requires_grad_(False) + else: + if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name: + parameter.requires_grad_(False) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class MMGroundingDinoConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. + """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the MMGroundingDinoEncoder. 
This class extends BaseModelOutput, due to: + - vision and text last hidden states + - vision and text intermediate hidden states + """ +) +class MMGroundingDinoEncoderOutput(ModelOutput): + r""" + last_hidden_state_vision (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the vision encoder. + last_hidden_state_text (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the text encoder. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the vision embeddings + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the vision encoder at the + output of each layer plus the initial embedding outputs. + text_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the text embeddings + one for the output of each layer) + of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the text encoder at the output of + each layer plus the initial embedding outputs. + """ + + last_hidden_state_vision: Optional[torch.FloatTensor] = None + last_hidden_state_text: Optional[torch.FloatTensor] = None + vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None + text_hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + + +class MMGroundingDinoMultiheadAttention(nn.Module): + """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`.""" + + def __init__(self, config, num_attention_heads=None): + super().__init__() + if config.hidden_size % num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({num_attention_heads})" + ) + + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(config.hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in MMGroundingDinoModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class MMGroundingDinoTextEnhancerLayer(nn.Module): + """Vanilla Transformer with text embeddings as input""" + + def __init__(self, config): + super().__init__() + self.self_attn = MMGroundingDinoMultiheadAttention( + config, num_attention_heads=config.encoder_attention_heads // 2 + ) + + # Implementation of Feedforward model + self.fc1 = nn.Linear(config.d_model, config.encoder_ffn_dim // 2) + self.fc2 = nn.Linear(config.encoder_ffn_dim // 2, config.d_model) + + self.layer_norm_before = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.layer_norm_after = nn.LayerNorm(config.d_model, config.layer_norm_eps) + + self.activation = ACT2FN[config.activation_function] + self.num_heads = config.encoder_attention_heads // 2 + self.dropout = config.text_enhancer_dropout + + def with_pos_embed(self, hidden_state: Tensor, position_embeddings: Optional[Tensor]): + return hidden_state if position_embeddings is None else hidden_state + position_embeddings + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_masks: Optional[torch.BoolTensor] = None, + position_embeddings: Optional[torch.FloatTensor] = None, + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: + """Text self-attention to enhance projection of text features generated by + the text encoder (AutoModel based on text_config) within MMGroundingDinoEncoderLayer + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`): + Text features generated by the text encoder. + attention_masks (`torch.BoolTensor`, *optional*): + Attention mask for text self-attention. False for real tokens and True for padding tokens. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings to be added to the hidden states. + + Returns: + `tuple(torch.FloatTensor)` comprising two elements: + - **hidden_states** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) -- + Output of the text self-attention layer. + - **attention_weights** (`torch.FloatTensor` of shape `(batch_size, num_heads, sequence_length, + sequence_length)`) -- + Attention weights of the text self-attention layer. 
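A small sketch of calling this layer in isolation, assuming a default config and hypothetical text features; the all-True self-attention mask means every token may attend to every other, since the `(1.0 - attention_masks) * torch.finfo(dtype).min` conversion in the body below turns it into an all-zero additive mask:

```python
import torch

from transformers.models.mm_grounding_dino.configuration_mm_grounding_dino import MMGroundingDinoConfig
from transformers.models.mm_grounding_dino.modeling_mm_grounding_dino import MMGroundingDinoTextEnhancerLayer

config = MMGroundingDinoConfig()
layer = MMGroundingDinoTextEnhancerLayer(config)

text = torch.randn(2, 16, config.d_model)                  # (batch, text_sequence_length, d_model)
self_attn_mask = torch.ones(2, 16, 16, dtype=torch.bool)   # (batch, seq, seq), all tokens visible

hidden, weights = layer(hidden_states=text, attention_masks=self_attn_mask)
print(hidden.shape)   # torch.Size([2, 16, 256])
print(weights.shape)  # torch.Size([2, 4, 16, 16]) -- encoder_attention_heads // 2 heads
```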
+ """ + + # repeat attn mask + if attention_masks.dim() == 3 and attention_masks.shape[0] == hidden_states.shape[0]: + # batch_size, num_queries, num_keys + attention_masks = attention_masks[:, None, :, :] + attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1) + + dtype = hidden_states.dtype + attention_masks = attention_masks.to(dtype=dtype) # fp16 compatibility + attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min + + queries = keys = self.with_pos_embed(hidden_states, position_embeddings) + attention_output, attention_weights = self.self_attn( + queries=queries, + keys=keys, + values=hidden_states, + attention_mask=attention_masks, + output_attentions=True, + ) + attention_output = nn.functional.dropout(attention_output, p=self.dropout, training=self.training) + hidden_states = hidden_states + attention_output + hidden_states = self.layer_norm_before(hidden_states) + + residual = hidden_states + hidden_states = self.activation(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = hidden_states + residual + hidden_states = self.layer_norm_after(hidden_states) + + return hidden_states, attention_weights + + +class MMGroundingDinoDeformableLayer(nn.Module): + def __init__(self, config: MMGroundingDinoConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = MMGroundingDinoMultiscaleDeformableAttention( + config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Input to the layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Attention mask. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings, to be added to `hidden_states`. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes of the backbone feature maps. + spatial_shapes_list (`list[tuple[int, int]]`, *optional*): + Spatial shapes of the backbone feature maps (but as list for export compatibility). + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. 
+ hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, attn_weights + + +# Based on https://github.com/IDEA-Research/MMGroundingDino/blob/2b62f419c292ca9c518daae55512fabc3fead4a4/MMGroundingDino/models/MMGroundingDino/utils.py#L24 +def get_sine_pos_embed( + pos_tensor: torch.Tensor, num_pos_feats: int = 128, temperature: int = 10000, exchange_xy: bool = True +) -> Tensor: + """ + Generate sine position embeddings from a position tensor. + + Args: + pos_tensor (torch.Tensor): + Tensor containing positions. Shape: [..., n]. + num_pos_feats (`int`, *optional*, defaults to 128): + Projected shape for each float in the tensor. + temperature (`int`, *optional*, defaults to 10000): + Temperature in the sine/cosine function. + exchange_xy (`bool`, *optional*, defaults to `True`): + Exchange pos x and pos y. For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. + + Returns: + position_embeddings (torch.Tensor): shape: [..., n * hidden_size]. 
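A quick sketch of the resulting shape, using a hypothetical batch of normalized (x, y) positions and the default arguments:

```python
import torch

from transformers.models.mm_grounding_dino.modeling_mm_grounding_dino import get_sine_pos_embed

# hypothetical box centers normalized to [0, 1]: shape (batch, num_queries, 2)
positions = torch.rand(2, 5, 2)

embeddings = get_sine_pos_embed(positions, num_pos_feats=128, temperature=10000, exchange_xy=True)
print(embeddings.shape)  # torch.Size([2, 5, 256]) -- last dim is n * num_pos_feats = 2 * 128
```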
+ """ + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) + + def sine_func(x: torch.Tensor): + sin_x = x * scale / dim_t + sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2) + return sin_x + + pos_tensor = pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1) + position_embeddings = [sine_func(x) for x in pos_tensor] + if exchange_xy: + position_embeddings[0], position_embeddings[1] = position_embeddings[1], position_embeddings[0] + position_embeddings = torch.cat(position_embeddings, dim=-1) + return position_embeddings + + +class MMGroundingDinoEncoderLayer(nn.Module): + def __init__(self, config) -> None: + super().__init__() + + self.d_model = config.d_model + + self.text_enhancer_layer = MMGroundingDinoTextEnhancerLayer(config) + self.fusion_layer = MMGroundingDinoFusionLayer(config) + self.deformable_layer = MMGroundingDinoDeformableLayer(config) + + def get_text_position_embeddings( + self, + text_features: Tensor, + text_position_embedding: Optional[torch.Tensor], + text_position_ids: Optional[torch.Tensor], + ) -> Tensor: + batch_size, seq_length, _ = text_features.shape + if text_position_embedding is None and text_position_ids is None: + text_position_embedding = torch.arange(seq_length, device=text_features.device) + text_position_embedding = text_position_embedding.float() + text_position_embedding = text_position_embedding.unsqueeze(0).unsqueeze(-1) + text_position_embedding = text_position_embedding.repeat(batch_size, 1, 1) + text_position_embedding = get_sine_pos_embed( + text_position_embedding, num_pos_feats=self.d_model, exchange_xy=False + ) + if text_position_ids is not None: + text_position_embedding = get_sine_pos_embed( + text_position_ids[..., None], num_pos_feats=self.d_model, exchange_xy=False + ) + + return text_position_embedding + + def forward( + self, + vision_features: Tensor, + vision_position_embedding: Tensor, + spatial_shapes: Tensor, + spatial_shapes_list: list[tuple[int, int]], + level_start_index: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + text_features: Optional[Tensor] = None, + text_attention_mask: Optional[Tensor] = None, + text_position_embedding: Optional[Tensor] = None, + text_self_attention_masks: Optional[Tensor] = None, + text_position_ids: Optional[Tensor] = None, + ): + text_position_embedding = self.get_text_position_embeddings( + text_features, text_position_embedding, text_position_ids + ) + + (vision_features, vision_fused_attn), (text_features, text_fused_attn) = self.fusion_layer( + vision_features=vision_features, + text_features=text_features, + attention_mask_vision=key_padding_mask, + attention_mask_text=text_attention_mask, + ) + + (text_features, text_enhanced_attn) = self.text_enhancer_layer( + hidden_states=text_features, + attention_masks=~text_self_attention_masks, # note we use ~ for mask here + position_embeddings=(text_position_embedding if text_position_embedding is not None else None), + ) + + (vision_features, vision_deformable_attn) = self.deformable_layer( + hidden_states=vision_features, + attention_mask=~key_padding_mask, + position_embeddings=vision_position_embedding, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + ) + + return ( + (vision_features, text_features), + (vision_fused_attn, text_fused_attn, 
text_enhanced_attn, vision_deformable_attn), + ) + + +class MMGroundingDinoEncoder(MMGroundingDinoPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a + [`MMGroundingDinoEncoderLayer`]. + + The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers. + + Args: + config: MMGroundingDinoConfig + """ + + def __init__(self, config: MMGroundingDinoConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([MMGroundingDinoEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """ + Get reference points for each feature map. + + Args: + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Valid ratios of each feature map. + device (`torch.device`): + Device on which to create the tensors. + Returns: + `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` + """ + reference_points_list = [] + for level, (height, width) in enumerate(spatial_shapes): + ref_y, ref_x = meshgrid( + torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + indexing="ij", + ) + # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36 + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + vision_features: Tensor, + vision_attention_mask: Tensor, + vision_position_embedding: Tensor, + spatial_shapes: Tensor, + spatial_shapes_list: list[tuple[int, int]], + level_start_index: Tensor, + valid_ratios=None, + text_features: Optional[Tensor] = None, + text_attention_mask: Optional[Tensor] = None, + text_position_embedding: Optional[Tensor] = None, + text_self_attention_masks: Optional[Tensor] = None, + text_position_ids: Optional[Tensor] = None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + vision_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 0 for pixel features that are real (i.e. **not masked**), + - 1 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + vision_position_embedding (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. 
+ spatial_shapes_list (`list[tuple[int, int]]`): + Spatial shapes of each feature map (but as list for export compatibility). + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`): + Flattened text features that are passed to the encoder. + text_attention_mask (`torch.Tensor` of shape `(batch_size, text_seq_len)`, *optional*): + Mask to avoid performing attention on padding text features. Mask values selected in `[0, 1]`: + - 0 for text features that are real (i.e. **not masked**), + - 1 for text features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + text_position_embedding (`torch.FloatTensor` of shape `(batch_size, text_seq_len)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + text_self_attention_masks (`torch.BoolTensor` of shape `(batch_size, text_seq_len, text_seq_len)`): + Masks to avoid performing attention between padding text features. Mask values selected in `[0, 1]`: + - 1 for text features that are real (i.e. **not masked**), + - 0 for text features that are padding (i.e. **masked**). + text_position_ids (`torch.LongTensor` of shape `(batch_size, num_queries)`): + Position ids for text features. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
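+
+        Example:
+
+        As a rough illustration of how `spatial_shapes` and `level_start_index` relate (a minimal sketch with
+        made-up feature-map sizes, mirroring the computation in [`MMGroundingDinoModel.forward`]):
+
+        ```python
+        >>> import torch
+
+        >>> spatial_shapes = torch.tensor([[100, 150], [50, 75]])  # (height, width) of each feature level
+        >>> level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+        >>> level_start_index  # offset of each level in the flattened feature sequence
+        tensor([    0, 15000])
+        ```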
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=vision_features.device) + + encoder_vision_states = () if output_hidden_states else None + encoder_text_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_attn_fused_text = () if output_attentions else None + all_attn_fused_vision = () if output_attentions else None + all_attn_enhanced_text = () if output_attentions else None + all_attn_deformable = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_vision_states += (vision_features,) + encoder_text_states += (text_features,) + + (vision_features, text_features), attentions = encoder_layer( + vision_features=vision_features, + vision_position_embedding=vision_position_embedding, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + key_padding_mask=vision_attention_mask, + reference_points=reference_points, + text_features=text_features, + text_attention_mask=text_attention_mask, + text_position_embedding=text_position_embedding, + text_self_attention_masks=text_self_attention_masks, + text_position_ids=text_position_ids, + ) + + if output_attentions: + all_attn_fused_vision += (attentions[0],) + all_attn_fused_text += (attentions[1],) + all_attn_enhanced_text += (attentions[2],) + all_attn_deformable += (attentions[3],) + + if output_hidden_states: + encoder_vision_states += (vision_features,) + encoder_text_states += (text_features,) + + if output_attentions: + all_attns = (all_attn_fused_vision, all_attn_fused_text, all_attn_enhanced_text, all_attn_deformable) + + if not return_dict: + enc_outputs = [vision_features, text_features, encoder_vision_states, encoder_text_states, all_attns] + return tuple(v for v in enc_outputs if v is not None) + return MMGroundingDinoEncoderOutput( + last_hidden_state_vision=vision_features, + last_hidden_state_text=text_features, + vision_hidden_states=encoder_vision_states, + text_hidden_states=encoder_text_states, + attentions=all_attns, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the MMGroundingDinoDecoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class MMGroundingDinoDecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + + +class MMGroundingDinoDecoderLayer(nn.Module): + def __init__(self, config: MMGroundingDinoConfig): + super().__init__() + self.embed_dim = config.d_model + + # self-attention + self.self_attn = MMGroundingDinoMultiheadAttention(config, num_attention_heads=config.decoder_attention_heads) + + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + # cross-attention text + self.encoder_attn_text = MMGroundingDinoMultiheadAttention( + config, num_attention_heads=config.decoder_attention_heads + ) + self.encoder_attn_text_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + # cross-attention + self.encoder_attn = MMGroundingDinoMultiscaleDeformableAttention( + config, + num_heads=config.decoder_attention_heads, + n_points=config.decoder_n_points, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + # feedforward neural networks + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, config.layer_norm_eps) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + vision_encoder_hidden_states: Optional[torch.Tensor] = None, + vision_encoder_attention_mask: Optional[torch.Tensor] = None, + text_encoder_hidden_states: Optional[torch.Tensor] = None, + text_encoder_attention_mask: Optional[torch.Tensor] = None, + self_attn_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + residual = hidden_states + + # Self Attention + queries = keys = self.with_pos_embed(hidden_states, position_embeddings) + hidden_states, self_attn_weights = self.self_attn( + queries=queries, + keys=keys, + values=hidden_states, + attention_mask=self_attn_mask, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention Text + queries = self.with_pos_embed(hidden_states, position_embeddings) + hidden_states, text_cross_attn_weights = self.encoder_attn_text( + queries=queries, + keys=text_encoder_hidden_states, + values=text_encoder_hidden_states, + attention_mask=text_encoder_attention_mask, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + hidden_states = self.encoder_attn_text_layer_norm(hidden_states) + + third_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + 
hidden_states=hidden_states, + attention_mask=vision_encoder_attention_mask, + encoder_hidden_states=vision_encoder_hidden_states, + encoder_attention_mask=vision_encoder_attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = third_residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, text_cross_attn_weights, cross_attn_weights) + + return outputs + + +class MMGroundingDinoDecoder(MMGroundingDinoPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MMGroundingDinoDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some tweaks for Grounding DINO: + + - `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass. + - it also returns a stack of intermediate outputs and reference points from all decoding layers. + + Args: + config: MMGroundingDinoConfig + """ + + def __init__(self, config: MMGroundingDinoConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.layers = nn.ModuleList([MMGroundingDinoDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.reference_points_head = MMGroundingDinoMLPPredictionHead( + config.query_dim // 2 * config.d_model, config.d_model, config.d_model, 2 + ) + self.gradient_checkpointing = False + + # hack implementation for iterative bounding box refinement as in two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.query_scale = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds, + vision_encoder_hidden_states, + vision_encoder_attention_mask=None, + text_encoder_hidden_states=None, + text_encoder_attention_mask=None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + valid_ratios=None, + self_attn_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + vision_encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden state from encoder related to vision feature map. + vision_encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. 
**not masked**),
+                - 0 for pixel features that are padding (i.e. **masked**).
+            text_encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`):
+                Last hidden state from encoder related to text features.
+            text_encoder_attention_mask (`torch.Tensor` of shape `(batch_size, text_seq_len)`, *optional*):
+                Mask to avoid performing attention on padding text features. Mask values selected in `[0, 1]`:
+                - 0 for text features that are real (i.e. **not masked**),
+                - 1 for text features that are padding (i.e. **masked**).
+            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+                Reference point in range `[0, 1]`, top-left (0, 0), bottom-right (1, 1), including padding area.
+            spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+                Spatial shapes of the feature maps.
+            spatial_shapes_list (`list[tuple[int, int]]`):
+                Spatial shapes of the feature maps (but as list for export compatibility).
+            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+                Indexes for the start of each feature level. In range `[0, sequence_length]`.
+            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+                Ratio of valid area in each feature level.
+            self_attn_mask (`torch.BoolTensor` of shape `(batch_size, text_seq_len)`):
+                Masks to avoid performing self-attention among the vision hidden states. Mask values selected in `[0, 1]`:
+                - 1 for queries that are real (i.e. **not masked**),
+                - 0 for queries that are padding (i.e. **masked**).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
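+
+        Example:
+
+        Between layers the decoder refines `reference_points` by adding the bbox head output in logit space and
+        mapping back through a sigmoid; a minimal standalone sketch of that update, with made-up numbers:
+
+        ```python
+        >>> import torch
+
+        >>> reference = torch.tensor([[0.50, 0.50, 0.20, 0.20]])  # normalized (center_x, center_y, width, height)
+        >>> delta = torch.tensor([[0.10, -0.10, 0.00, 0.00]])  # output of the bbox refinement head
+        >>> refined = (delta + torch.special.logit(reference, eps=1e-5)).sigmoid()
+        ```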
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_attns = () if output_attentions else None + all_cross_attns_vision = () if (output_attentions and vision_encoder_hidden_states is not None) else None + all_cross_attns_text = () if (output_attentions and text_encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + + if text_encoder_attention_mask is not None: + dtype = text_encoder_hidden_states.dtype + + text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :] + text_encoder_attention_mask = text_encoder_attention_mask.repeat( + 1, self.config.decoder_attention_heads, self.config.num_queries, 1 + ) + text_encoder_attention_mask = text_encoder_attention_mask.to(dtype=dtype) + text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(dtype).min + + for idx, decoder_layer in enumerate(self.layers): + num_coordinates = reference_points.shape[-1] + if num_coordinates == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + elif num_coordinates == 2: + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + else: + raise ValueError("Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + query_pos = get_sine_pos_embed(reference_points_input[:, :, 0, :], num_pos_feats=self.config.d_model // 2) + query_pos = self.reference_points_head(query_pos) + + # In original implementation they apply layer norm before outputting intermediate hidden states + # Though that's not through between layers so the layers use as input the output of the previous layer + # without layer norm + if output_hidden_states: + all_hidden_states += (self.layer_norm(hidden_states),) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + query_pos, + reference_points_input, + spatial_shapes, + level_start_index, + vision_encoder_hidden_states, + vision_encoder_attention_mask, + text_encoder_hidden_states, + text_encoder_attention_mask, + self_attn_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states=hidden_states, + position_embeddings=query_pos, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + vision_encoder_hidden_states=vision_encoder_hidden_states, + vision_encoder_attention_mask=vision_encoder_attention_mask, + text_encoder_hidden_states=text_encoder_hidden_states, + text_encoder_attention_mask=text_encoder_attention_mask, + self_attn_mask=self_attn_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = 
self.bbox_embed[idx](hidden_states) + num_coordinates = reference_points.shape[-1] + if num_coordinates == 4: + new_reference_points = tmp + torch.special.logit(reference_points, eps=1e-5) + new_reference_points = new_reference_points.sigmoid() + elif num_coordinates == 2: + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + torch.special.logit(reference_points, eps=1e-5) + new_reference_points = new_reference_points.sigmoid() + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}" + ) + reference_points = new_reference_points.detach() + + intermediate += (self.layer_norm(hidden_states),) + intermediate_reference_points += (reference_points,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if text_encoder_hidden_states is not None: + all_cross_attns_text += (layer_outputs[2],) + + if vision_encoder_hidden_states is not None: + all_cross_attns_vision += (layer_outputs[3],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if output_attentions: + all_attns += (all_self_attns, all_cross_attns_text, all_cross_attns_vision) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_reference_points, + all_hidden_states, + all_attns, + ] + if v is not None + ) + return MMGroundingDinoDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_attns, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Grounding DINO encoder-decoder model. + """ +) +class MMGroundingDinoModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + encoder_last_hidden_state_vision (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_last_hidden_state_text (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+    encoder_vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the vision embeddings + one for the output of each
+        layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the vision encoder at the
+        output of each layer plus the initial embedding outputs.
+    encoder_text_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the text embeddings + one for the output of each layer)
+        of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the text encoder at the output of
+        each layer plus the initial embedding outputs.
+    encoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+        Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
+        sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
+        weighted average in the text-vision attention, vision-text attention, text-enhancer (self-attention) and
+        multi-scale deformable attention heads.
+    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`):
+        Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes are picked as
+        region proposals in the first stage. Output of bounding box binary classification (i.e. foreground and
+        background).
+    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the first stage.
+    encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`):
+        Logits of top `config.num_queries` scoring bounding boxes in the first stage.
+    encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
+        Coordinates of top `config.num_queries` scoring bounding boxes in the first stage.
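+
+    As a hedged sketch of how the two-stage fields relate (mirroring the proposal selection inside
+    [`MMGroundingDinoModel.forward`]; `outputs` and `config` are assumed to be a model output and its config):
+
+    ```python
+    >>> # score each first-stage proposal by its best-matching text token, then keep the top `num_queries`
+    >>> topk_logits = outputs.enc_outputs_class.max(-1)[0]
+    >>> topk_proposals = torch.topk(topk_logits, config.num_queries, dim=1)[1]
+    ```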
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + init_reference_points: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + encoder_last_hidden_state_vision: Optional[torch.FloatTensor] = None + encoder_last_hidden_state_text: Optional[torch.FloatTensor] = None + encoder_vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_text_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + encoder_logits: Optional[torch.FloatTensor] = None + encoder_pred_boxes: Optional[torch.FloatTensor] = None + + +class MMGroundingDinoSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__(self, config): + super().__init__() + self.embedding_dim = config.d_model // 2 + self.temperature = config.positional_embedding_temperature + self.scale = 2 * math.pi + + def forward(self, pixel_values, pixel_mask): + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def build_position_encoding(config): + if config.position_embedding_type == "sine": + position_embedding = MMGroundingDinoSinePositionEmbedding(config) + elif config.position_embedding_type == "learned": + position_embedding = MMGroundingDinoLearnedPositionEmbedding(config) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") + + return position_embedding + + +# these correspond to [CLS], [SEP], . and ? +SPECIAL_TOKENS = [101, 102, 1012, 1029] + + +def generate_masks_with_special_tokens_and_transfer_map(input_ids: torch.LongTensor) -> tuple[Tensor, Tensor]: + """Generate attention mask between each pair of special tokens and positional ids. + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + Returns: + `tuple(torch.Tensor)` comprising attention mask between each special tokens and position_ids: + - **attention_mask** (`torch.BoolTensor` of shape `(batch_size, sequence_length, sequence_length)`) + - **position_ids** (`torch.LongTensor` of shape `(batch_size, sequence_length)`) + """ + batch_size, num_token = input_ids.shape + # special_tokens_mask: batch_size, num_token. 1 for special tokens. 
0 for normal tokens + special_tokens_mask = torch.zeros((batch_size, num_token), device=input_ids.device).bool() + for special_token in SPECIAL_TOKENS: + special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token) + + # idxs: each row is a list of indices of special tokens + idxs = torch.nonzero(special_tokens_mask) + + # generate attention mask and positional ids + attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(batch_size, 1, 1) + position_ids = torch.zeros((batch_size, num_token), device=input_ids.device) + previous_col = 0 + for i in range(idxs.shape[0]): + row, col = idxs[i] + if (col == 0) or (col == num_token - 1): + attention_mask[row, col, col] = True + position_ids[row, col] = 0 + else: + attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True + position_ids[row, previous_col + 1 : col + 1] = torch.arange( + 0, col - previous_col, device=input_ids.device + ) + + previous_col = col + + return attention_mask, position_ids.to(torch.long) + + +@auto_docstring( + custom_intro=""" + The bare Grounding DINO Model (consisting of a backbone and encoder-decoder Transformer) outputting raw + hidden-states without any specific head on top. + """ +) +class MMGroundingDinoModel(MMGroundingDinoPreTrainedModel): + def __init__(self, config: MMGroundingDinoConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = MMGroundingDinoConvEncoder(config) + position_embeddings = build_position_encoding(config) + self.backbone = MMGroundingDinoConvModel(backbone, position_embeddings) + + # Create input projection layers + num_backbone_outs = len(backbone.intermediate_channel_sizes) + input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = backbone.intermediate_channel_sizes[i] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1), + nn.GroupNorm(32, config.d_model), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, config.d_model), + ) + ) + in_channels = config.d_model + self.input_proj_vision = nn.ModuleList(input_proj_list) + + # Create text backbone + self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False) + self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model) + + if config.embedding_init_target or not config.two_stage: + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = MMGroundingDinoEncoder(config) + self.decoder = MMGroundingDinoDecoder(config) + + self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model)) + + self.enc_output = nn.Linear(config.d_model, config.d_model) + self.enc_output_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.encoder_output_bbox_embed = MMGroundingDinoMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + self.encoder_output_class_embed = MMGroundingDinoContrastiveEmbedding(config) + + self.post_init() + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + def get_valid_ratio(self, mask): + """Get 
the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) + valid_ratio_height = valid_height.float() / height + valid_ratio_width = valid_width.float() / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) + return valid_ratio + + def generate_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): + """Generate the encoder output proposals from encoded enc_output. + + Args: + enc_output (`torch.Tensor[batch_size, sequence_length, hidden_size]`): Output of the encoder. + padding_mask (`torch.Tensor[batch_size, sequence_length]`): Padding mask for `enc_output`. + spatial_shapes (`torch.Tensor[num_feature_levels, 2]`): Spatial shapes of the feature maps. + + Returns: + `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. + - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to + directly predict a bounding box. (without the need of a decoder) + - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse + sigmoid. + """ + batch_size = enc_output.shape[0] + proposals = [] + current_position = 0 + for level, (height, width) in enumerate(spatial_shapes): + mask_flatten_ = padding_mask[:, current_position : (current_position + height * width)] + mask_flatten_ = mask_flatten_.view(batch_size, height, width, 1) + valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = meshgrid( + torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device), + torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device), + indexing="ij", + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + width_height = torch.ones_like(grid) * 0.05 * (2.0**level) + proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) + proposals.append(proposal) + current_position += height * width + + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse sigmoid + output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + # assign each pixel as an object query + object_query = enc_output + object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0)) + object_query = object_query.masked_fill(~output_proposals_valid, float(0)) + object_query = self.enc_output_norm(self.enc_output(object_query)) + return object_query, output_proposals + + @auto_docstring + def forward( + self, + pixel_values: Tensor, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + pixel_mask: Optional[Tensor] = None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. 
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`BertTokenizer.__call__`] for details.
+        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`: 0 corresponds to a `sentence A` token, 1 corresponds to a `sentence B` token
+
+            [What are token type IDs?](../glossary#token-type-ids)
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoProcessor, AutoModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> text = "a cat."
+
+        >>> processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
+        >>> model = AutoModel.from_pretrained("IDEA-Research/grounding-dino-tiny")
+
+        >>> inputs = processor(images=image, text=text, return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 900, 256]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        text_token_mask = attention_mask.bool()  # just to avoid renaming everywhere
+
+        max_text_len = self.config.max_text_len
+        if text_self_attention_masks.shape[1] > max_text_len:
+            text_self_attention_masks = text_self_attention_masks[:, :max_text_len, :max_text_len]
+            position_ids = position_ids[:, :max_text_len]
+            input_ids = input_ids[:, :max_text_len]
+            token_type_ids = token_type_ids[:, :max_text_len]
+            text_token_mask = text_token_mask[:, :max_text_len]
+
+        # Extract text features from text backbone
+        text_outputs = self.text_backbone(
+            input_ids, text_self_attention_masks, token_type_ids, position_ids, return_dict=return_dict
+        )
+        text_features = text_outputs.last_hidden_state if return_dict else text_outputs[0]
+        text_features = self.text_projection(text_features)
+
+        batch_size, num_channels, height, width = pixel_values.shape
+        device = pixel_values.device
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones((batch_size, height, width), dtype=torch.long, device=device)
+
+        # Extract multi-scale feature maps with a shared channel dimension of `config.d_model` (cf Figure 4 in paper)
+        # First, send pixel_values + pixel_mask through the backbone to obtain the features,
+        # which is a list of tuples
+        vision_features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
+
+        # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
+        feature_maps = []
+        masks = []
+        for level, (source, mask) in enumerate(vision_features):
+            feature_maps.append(self.input_proj_vision[level](source))
+            masks.append(mask)
+
+        # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+        if self.config.num_feature_levels > len(feature_maps):
+            _len_sources = len(feature_maps)
+            for level in range(_len_sources, self.config.num_feature_levels):
+                if level == _len_sources:
+                    source = self.input_proj_vision[level](vision_features[-1][0])
+                else:
+                    source = self.input_proj_vision[level](feature_maps[-1])
+                mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
+                pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
+                feature_maps.append(source)
+                masks.append(mask)
+                position_embeddings_list.append(pos_l)
+
+        # Create queries
+        query_embeds = None
+        if self.config.embedding_init_target or self.config.two_stage:
+            query_embeds = self.query_position_embeddings.weight
+
+        # Prepare encoder inputs (by flattening)
+        source_flatten = []
+        mask_flatten = []
+        lvl_pos_embed_flatten = []
+        spatial_shapes_list = []
+        for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)):
+            batch_size, num_channels, height, width = source.shape
+            spatial_shape = (height, width)
+            spatial_shapes_list.append(spatial_shape)
+            source = source.flatten(2).transpose(1, 2)
+            mask = mask.flatten(1)
+            pos_embed = pos_embed.flatten(2).transpose(1, 2)
+            lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
+            lvl_pos_embed_flatten.append(lvl_pos_embed)
+            source_flatten.append(source)
+            mask_flatten.append(mask)
+        source_flatten = torch.cat(source_flatten, 1)
+        mask_flatten = torch.cat(mask_flatten, 1)
+        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
+        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+        valid_ratios = valid_ratios.float()
+
+        # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
+        # Also provide spatial_shapes, level_start_index and valid_ratios
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                vision_features=source_flatten,
+                vision_attention_mask=~mask_flatten,
+                vision_position_embedding=lvl_pos_embed_flatten,
+                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
+                level_start_index=level_start_index,
+                valid_ratios=valid_ratios,
+                text_features=text_features,
+                text_attention_mask=~text_token_mask,
+                text_position_embedding=None,
+                text_self_attention_masks=~text_self_attention_masks,
+                text_position_ids=position_ids,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a MMGroundingDinoEncoderOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, MMGroundingDinoEncoderOutput):
+            encoder_outputs = MMGroundingDinoEncoderOutput(
+                last_hidden_state_vision=encoder_outputs[0],
+                last_hidden_state_text=encoder_outputs[1],
+                vision_hidden_states=encoder_outputs[2] if output_hidden_states else None,
+                text_hidden_states=encoder_outputs[3] if output_hidden_states else None,
+                attentions=encoder_outputs[-1] if output_attentions else None,
+            )
+
+        # Fifth, prepare decoder inputs
+        topk_proposals = None
+        enc_outputs_class = None
+        enc_outputs_coord_logits = None
+        encoder_logits = None
+        encoder_pred_boxes = None
+        if self.config.two_stage:
+            object_query_embedding, output_proposals = self.generate_encoder_output_proposals(
+                encoder_outputs[0], ~mask_flatten, spatial_shapes
+            )
+
+            # hack implementation as in two-stage Deformable DETR
+            # apply a detection head to each pixel
(A.4 in paper) + # linear projection for bounding box binary classification (i.e. foreground and background) + enc_outputs_class = self.encoder_output_class_embed( + object_query_embedding, encoder_outputs[1], text_token_mask + ) + # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch) + delta_bbox = self.encoder_output_bbox_embed(object_query_embedding) + enc_outputs_coord_logits = delta_bbox + output_proposals + + # only keep top scoring `config.num_queries` proposals + topk = self.config.num_queries + topk_logits = enc_outputs_class.max(-1)[0] + topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] + topk_coords_logits = torch.gather( + enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ) + + topk_coords_logits = topk_coords_logits.detach() + reference_points = topk_coords_logits.sigmoid() + init_reference_points = reference_points + if query_embeds is not None: + target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + else: + target = torch.gather( + object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ).detach() + + # Set intermediate topk proposals (coords and class) for loss computation + encoder_pred_boxes = reference_points + encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask) + else: + target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid() + init_reference_points = reference_points + + decoder_outputs = self.decoder( + inputs_embeds=target, + vision_encoder_hidden_states=encoder_outputs[0], + vision_encoder_attention_mask=mask_flatten, + text_encoder_hidden_states=encoder_outputs[1], + text_encoder_attention_mask=~text_token_mask, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + self_attn_mask=None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [ + enc_outputs_class, + enc_outputs_coord_logits, + encoder_logits, + encoder_pred_boxes, + ] + if value is not None + ) + tuple_outputs = ( + (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs + enc_outputs + ) + + return tuple_outputs + + return MMGroundingDinoModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + init_reference_points=init_reference_points, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_last_hidden_state_vision=encoder_outputs.last_hidden_state_vision, + encoder_last_hidden_state_text=encoder_outputs.last_hidden_state_text, + encoder_vision_hidden_states=encoder_outputs.vision_hidden_states, + encoder_text_hidden_states=encoder_outputs.text_hidden_states, + encoder_attentions=encoder_outputs.attentions, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + encoder_logits=encoder_logits, + encoder_pred_boxes=encoder_pred_boxes, + ) + + +class MMGroundingDinoMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a 
bounding box w.r.t. an image.
+
+    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+
+    """
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`MMGroundingDinoForObjectDetection`].
+    """
+)
+class MMGroundingDinoObjectDetectionOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+        scale-invariant IoU loss.
+    loss_dict (`Dict`, *optional*):
+        A dictionary containing the individual losses. Useful for logging.
+    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+        Classification logits (including no-object) for all queries.
+    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~MMGroundingDinoProcessor.post_process_grounded_object_detection`] to retrieve the
+        unnormalized bounding boxes.
+    auxiliary_outputs (`list[Dict]`, *optional*):
+        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+        `pred_boxes`) for each decoder layer.
+    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+        Sequence of hidden-states at the output of the last layer of the decoder of the model.
+    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Initial reference points sent through the Transformer decoder.
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+        Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
+    encoder_last_hidden_state_vision (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        Sequence of hidden-states at the output of the last layer of the encoder of the model.
+    encoder_last_hidden_state_text (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the vision embeddings + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the vision encoder at the + output of each layer plus the initial embedding outputs. + encoder_text_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the text embeddings + one for the output of each layer) + of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the text encoder at the output of + each layer plus the initial embedding outputs. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes are picked as + region proposals in the first stage. Output of bounding box binary classification (i.e. foreground and + background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + encoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.two_stage=True`): + Logits of top `config.num_queries` scoring bounding boxes in the first stage. + encoder_pred_boxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Coordinates of top `config.num_queries` scoring bounding boxes in the first stage. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Encoded candidate labels sequence. Used in processor to post process object detection result. 
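+
+    As a hedged sketch (the supported path is
+    [`~MMGroundingDinoProcessor.post_process_grounded_object_detection`]), per-query detection scores can be
+    derived from `logits` with a sigmoid over the text dimension:
+
+    ```python
+    >>> probs = outputs.logits.sigmoid()  # (batch_size, num_queries, max_text_len)
+    >>> scores = probs.max(-1).values  # score of the best-matching text token for each query
+    ```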
+ """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + init_reference_points: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + encoder_last_hidden_state_vision: Optional[torch.FloatTensor] = None + encoder_last_hidden_state_text: Optional[torch.FloatTensor] = None + encoder_vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_text_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + encoder_logits: Optional[torch.FloatTensor] = None + encoder_pred_boxes: Optional[torch.FloatTensor] = None + input_ids: Optional[torch.LongTensor] = None + + +def build_label_maps(logits: torch.FloatTensor, input_ids: torch.LongTensor) -> tuple[torch.FloatTensor]: + """ + Computes a mapping between tokens and their corresponding labels, where `num_labels` is determined by the number of classes in the input prompt. + The function identifies segments of tokens between specific delimiter tokens and generates label maps for those segments. + Args: + logits (`torch.Tensor` of shape `(batch_size, seq_length, hidden_size)`): + The output logits from the model, where `hidden_size` corresponds to the dimension of the model's output features. + + input_ids (`torch.Tensor` of shape `(batch_size, seq_length)`): + The input token IDs corresponding to the input prompt. For example, given the prompt "fish. shark.", + `input_ids` might look like `[101, 3869, 1012, 11420, 1012, 102]` where each number corresponds to a token including special tokens. + Returns: + tuple: A tuple containing label maps for each instance in the batch. + - label_maps (tuple of `torch.Tensor`): + A tuple of tensors, where each tensor in the tuple corresponds to an instance in the batch. Each tensor + has shape `(num_labels, hidden_size)` and contains binary values (0 or 1), where `1` indicates the tokens + that are associated with a specific label (class) between delimiter tokens, and `0` elsewhere. + Example: + Given an input prompt "fish. shark." and corresponding `input_ids` as `[101, 3869, 1012, 11420, 1012, 102]`: + - The function identifies the tokens for "fish" (IDs `[3869]`) and "shark" (IDs `[11420]`). + - The function then constructs label maps for these tokens, where each label map indicates which tokens + correspond to which label between the delimiter tokens (e.g., between the period `.`). + - The output is a tuple of label maps, one for each instance in the batch. + Note: + - `SPECIAL_TOKENS` should be a predefined list of tokens that are considered special (e.g., `[CLS]`, `[SEP]`, etc.). 
+ """ + max_seq_len = logits.shape[-1] + # Add [PAD] token to the list of special tokens + delimiter_tokens = torch.tensor(SPECIAL_TOKENS + [0], device=input_ids.device) + + delimiter_token_masks = torch.isin(input_ids, delimiter_tokens) + label_groups = torch.cumsum(delimiter_token_masks, dim=1) * (~delimiter_token_masks).to(torch.int32) + + label_maps = () + + # Iterate over batch dimension as we can have different number of labels + for label_group in label_groups: + # `label_group` is a tensor of shape `(seq_len,)` with zeros for non-label tokens and integers for label tokens + # label tokens with same integer value are part of the same label group + + # Get unique labels and exclude 0 (i.e. non-label tokens) + unique_labels = torch.unique(label_group)[1:, None] + num_labels = unique_labels.shape[0] + + # Create one-hot encoding for each label group + label_map = label_group.unsqueeze(0).repeat(num_labels, 1) + label_map = torch.where(label_map == unique_labels, 1, 0) + + # Pad label_map to match `max_seq_len` + label_map = F.pad(label_map, (0, max_seq_len - label_map.shape[1]), value=0) + + label_maps += (label_map,) + + return label_maps + + +def build_text_mask(logits, attention_mask): + """ + Create text_mask based on the matching indices + """ + seq_len = attention_mask.shape[1] + text_mask = torch.zeros_like(logits, device=logits.device, dtype=attention_mask.dtype) + text_mask[:, :, :seq_len] = attention_mask[:, None, :] + + return text_mask.bool() + + +@auto_docstring( + custom_intro=""" + Grounding DINO Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, + for tasks such as COCO detection. + """ +) +class MMGroundingDinoForObjectDetection(MMGroundingDinoPreTrainedModel): + _tied_weights_keys = [ + r"bbox_embed\.[1-9]\d*", + r"model\.decoder\.bbox_embed\.[0-9]\d*", + r"class_embed\.[1-9]\d*", + r"model\.decoder\.class_embed\.[0-9]\d*", + ] + + def __init__(self, config: MMGroundingDinoConfig): + super().__init__(config) + + self.model = MMGroundingDinoModel(config) + + self.class_embed = nn.ModuleList( + [MMGroundingDinoContrastiveEmbedding(config) for _ in range(config.decoder_layers)] + ) + + self.bbox_embed = nn.ModuleList( + [ + MMGroundingDinoMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + for _ in range(config.decoder_layers) + ] + ) + + # hack for box-refinement + self.model.decoder.bbox_embed = self.bbox_embed + # hack implementation for two-stage + self.model.decoder.class_embed = self.class_embed + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + pixel_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Union[MMGroundingDinoEncoderOutput, tuple]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[list[dict[str, Union[torch.LongTensor, torch.FloatTensor]]]] = None, + ): + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`BertTokenizer.__call__`] for details. 
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: 0 corresponds to a `sentence A` token, 1 corresponds to a `sentence B` token + + [What are token type IDs?](../glossary#token-type-ids) + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> import requests + + >>> import torch + >>> from PIL import Image + >>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection + + >>> model_id = "IDEA-Research/grounding-dino-tiny" + >>> device = "cuda" + + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device) + + >>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(image_url, stream=True).raw) + >>> # Check for cats and remote controls + >>> text_labels = [["a cat", "a remote control"]] + + >>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device) + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> results = processor.post_process_grounded_object_detection( + ... outputs, + ... threshold=0.4, + ... text_threshold=0.3, + ... target_sizes=[(image.height, image.width)] + ... ) + >>> # Retrieve the first image result + >>> result = results[0] + >>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]): + ... box = [round(x, 2) for x in box.tolist()] + ... 
print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28] + Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44] + Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # First, sent images through Grounding DINO base model to obtain encoder + decoder outputs + outputs = self.model( + pixel_values=pixel_values, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0) + enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx] + hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2] + init_reference_points = outputs.init_reference_points if return_dict else outputs[1] + inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3] + + # class logits + predicted bounding boxes + outputs_classes = [] + outputs_coords = [] + + # hidden_states are of shape (batch_size, num_stages, height, width) + # predict class and bounding box deltas for each stage + num_levels = hidden_states.shape[1] + for level in range(num_levels): + if level == 0: + reference = init_reference_points + else: + reference = inter_references_points[:, level - 1] + reference = torch.special.logit(reference, eps=1e-5) + outputs_class = self.class_embed[level]( + vision_hidden_state=hidden_states[:, level], + text_hidden_state=enc_text_hidden_state, + text_token_mask=attention_mask.bool(), + ) + delta_bbox = self.bbox_embed[level](hidden_states[:, level]) + + reference_coordinates = reference.shape[-1] + if reference_coordinates == 4: + outputs_coord_logits = delta_bbox + reference + elif reference_coordinates == 2: + delta_bbox[..., :2] += reference + outputs_coord_logits = delta_bbox + else: + raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}") + outputs_coord = outputs_coord_logits.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + logits = outputs_class[-1] + pred_boxes = outputs_coord[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + label_maps = build_label_maps(logits, input_ids) + text_mask = build_text_mask(logits, attention_mask) + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + label_maps, + text_mask, + outputs_class=outputs_class, + outputs_coord=outputs_coord, + encoder_logits=outputs[-2], + encoder_pred_boxes=outputs[-1], + ) + + if not return_dict: + auxiliary_outputs = auxiliary_outputs if auxiliary_outputs is not None else [] + output = [loss, loss_dict, logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids] + output = tuple(out for out in output if out is not None) + return output + + dict_outputs = MMGroundingDinoObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + 
pred_boxes=pred_boxes, + last_hidden_state=outputs.last_hidden_state, + auxiliary_outputs=auxiliary_outputs, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision, + encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text, + encoder_vision_hidden_states=outputs.encoder_vision_hidden_states, + encoder_text_hidden_states=outputs.encoder_text_hidden_states, + encoder_attentions=outputs.encoder_attentions, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + encoder_logits=outputs.encoder_logits, + encoder_pred_boxes=outputs.encoder_pred_boxes, + input_ids=input_ids, + ) + + return dict_outputs + + +__all__ = ["MMGroundingDinoForObjectDetection", "MMGroundingDinoModel", "MMGroundingDinoPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..a05045a68cb57f7a5fab69e7ccd31bfae518d2bc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py @@ -0,0 +1,434 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +from torch import nn + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING +from ..auto.modeling_auto import AutoModel +from ..grounding_dino.configuration_grounding_dino import GroundingDinoConfig +from ..grounding_dino.modeling_grounding_dino import ( + GroundingDinoContrastiveEmbedding, + GroundingDinoConvEncoder, + GroundingDinoConvModel, + GroundingDinoDecoder, + GroundingDinoEncoder, + GroundingDinoForObjectDetection, + GroundingDinoMLPPredictionHead, + GroundingDinoModel, + GroundingDinoPreTrainedModel, + build_position_encoding, +) + + +logger = logging.get_logger(__name__) + + +class MMGroundingDinoConfig(GroundingDinoConfig, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MMGroundingDinoModel`]. It is used to instantiate a + MM Grounding DINO model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MM Grounding DINO tiny architecture + [openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det). 
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `BertConfig`):
+            The config object or dictionary of the text backbone.
+        num_queries (`int`, *optional*, defaults to 900):
+            Number of object queries, i.e. detection slots. This is the maximal number of objects
+            [`MMGroundingDinoModel`] can detect in a single image.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of encoder layers.
+        encoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        auxiliary_loss (`bool`, *optional*, defaults to `False`):
+            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+        position_embedding_type (`str`, *optional*, defaults to `"sine"`):
+            Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
+        num_feature_levels (`int`, *optional*, defaults to 4):
+            The number of input feature levels.
+        encoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the encoder.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        two_stage (`bool`, *optional*, defaults to `True`):
+            Whether to apply a two-stage deformable DETR, where the region proposals are generated by a variant of
+            Grounding DINO and then fed into the decoder for iterative bounding box refinement.
+        class_cost (`float`, *optional*, defaults to 1.0):
+            Relative weight of the classification error in the Hungarian matching cost.
+        bbox_cost (`float`, *optional*, defaults to 5.0):
+            Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+        giou_cost (`float`, *optional*, defaults to 2.0):
+            Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+        bbox_loss_coefficient (`float`, *optional*, defaults to 5.0):
+            Relative weight of the L1 bounding box loss in the object detection loss.
+        giou_loss_coefficient (`float`, *optional*, defaults to 2.0):
+            Relative weight of the generalized IoU loss in the object detection loss.
+        focal_alpha (`float`, *optional*, defaults to 0.25):
+            Alpha parameter in the focal loss.
+        disable_custom_kernels (`bool`, *optional*, defaults to `False`):
+            Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+            kernels are not supported by PyTorch ONNX export.
+        max_text_len (`int`, *optional*, defaults to 256):
+            The maximum length of the text input.
+        text_enhancer_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the text enhancer.
+        fusion_droppath (`float`, *optional*, defaults to 0.1):
+            The droppath ratio for the fusion module.
+        fusion_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the fusion module.
+        embedding_init_target (`bool`, *optional*, defaults to `True`):
+            Whether to initialize the target with Embedding weights.
+        query_dim (`int`, *optional*, defaults to 4):
+            The dimension of the query vector.
+        positional_embedding_temperature (`float`, *optional*, defaults to 20):
+            The temperature for the sine positional embedding used together with the vision backbone.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+
+    Examples:
+
+    ```python
+    >>> from transformers import MMGroundingDinoConfig, MMGroundingDinoModel
+
+    >>> # Initializing a MM Grounding DINO configuration
+    >>> configuration = MMGroundingDinoConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = MMGroundingDinoModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "mm-grounding-dino"
+
+    def __init__(
+        self,
+        backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
+        use_timm_backbone=False,
+        backbone_kwargs=None,
+        text_config=None,
+        num_queries=900,
+        encoder_layers=6,
+        encoder_ffn_dim=2048,
+        encoder_attention_heads=8,
+        decoder_layers=6,
+        decoder_ffn_dim=2048,
+        decoder_attention_heads=8,
+        is_encoder_decoder=True,
+        activation_function="relu",
+        d_model=256,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        auxiliary_loss=False,
+        position_embedding_type="sine",
+        num_feature_levels=4,
+        encoder_n_points=4,
+        decoder_n_points=4,
+        two_stage=True,
+        class_cost=1.0,
+        bbox_cost=5.0,
+        giou_cost=2.0,
+        bbox_loss_coefficient=5.0,
+        giou_loss_coefficient=2.0,
+        focal_alpha=0.25,
+        disable_custom_kernels=False,
+        # other parameters
+        max_text_len=256,
+        text_enhancer_dropout=0.0,
+        fusion_droppath=0.1,
+        fusion_dropout=0.0,
+        embedding_init_target=True,
+        query_dim=4,
+        positional_embedding_temperature=20,
+        init_std=0.02,
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        PretrainedConfig.__init__(self, is_encoder_decoder=is_encoder_decoder, **kwargs)
+        if backbone_config is None and backbone is None:
+            logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
+            backbone_config = CONFIG_MAPPING["swin"](
+                window_size=7,
+                image_size=224,
+                embed_dim=96,
+                depths=[2, 2, 6, 2],
+                num_heads=[3, 6, 12, 24],
+                out_indices=[2, 3, 4],
+            )
+        elif isinstance(backbone_config, dict):
+            backbone_model_type = backbone_config.pop("model_type")
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            backbone_config = config_class.from_dict(backbone_config)
+
+        verify_backbone_config_arguments(
+            use_timm_backbone=use_timm_backbone,
+            use_pretrained_backbone=use_pretrained_backbone,
+            backbone=backbone,
+            backbone_config=backbone_config,
+            backbone_kwargs=backbone_kwargs,
+        )
+
+        if text_config is None:
+            text_config = {}
+            logger.info("text_config is None. 
Initializing the text config with default values (`BertConfig`).") + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.num_queries = num_queries + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + # deformable attributes + self.num_feature_levels = num_feature_levels + self.encoder_n_points = encoder_n_points + self.decoder_n_points = decoder_n_points + self.two_stage = two_stage + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.focal_alpha = focal_alpha + self.disable_custom_kernels = disable_custom_kernels + # Text backbone + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "bert") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["bert"]() + + self.text_config = text_config + self.max_text_len = max_text_len + + # Text Enhancer + self.text_enhancer_dropout = text_enhancer_dropout + # Fusion + self.fusion_droppath = fusion_droppath + self.fusion_dropout = fusion_dropout + # Others + self.embedding_init_target = embedding_init_target + self.query_dim = query_dim + self.positional_embedding_temperature = positional_embedding_temperature + self.init_std = init_std + self.layer_norm_eps = layer_norm_eps + + +class MMGroundingDinoContrastiveEmbedding(GroundingDinoContrastiveEmbedding): + def __init__(self, config): + super().__init__(config) + self.bias = nn.Parameter(torch.tensor(0.0)) + + def forward( + self, + vision_hidden_state: torch.FloatTensor, + text_hidden_state: torch.FloatTensor, + text_token_mask: torch.BoolTensor, + ) -> torch.FloatTensor: + res = vision_hidden_state @ text_hidden_state.transpose(-1, -2) + res = res / math.sqrt(vision_hidden_state.shape[-1]) + res = res + self.bias + res.masked_fill_(~text_token_mask[:, None, :], float("-inf")) + + # padding to max_text_len + new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device) + new_res[..., : res.shape[-1]] = res + + return new_res + + +class MMGroundingDinoPreTrainedModel(GroundingDinoPreTrainedModel): + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, MMGroundingDinoContrastiveEmbedding): + nn.init.constant_(module.bias, -math.log((1 - 0.01) / 0.01)) + + +class MMGroundingDinoConvEncoder(GroundingDinoConvEncoder): + pass + + +class MMGroundingDinoConvModel(GroundingDinoConvModel): + pass + + +class MMGroundingDinoEncoder(GroundingDinoEncoder): + pass + + +class MMGroundingDinoDecoder(GroundingDinoDecoder): + pass + + +class MMGroundingDinoModel(GroundingDinoModel, MMGroundingDinoPreTrainedModel): + def __init__(self, config: MMGroundingDinoConfig): + 
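+        # High-level layout of this constructor (mirrors `GroundingDinoModel`): vision backbone with position
+        # embeddings, multi-scale input projections, text backbone with projection, fusion encoder, decoder,
+        # and the two-stage (encoder output) class/bbox heads.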
MMGroundingDinoPreTrainedModel.__init__(self, config) + + # Create backbone + positional encoding + backbone = MMGroundingDinoConvEncoder(config) + position_embeddings = build_position_encoding(config) + self.backbone = MMGroundingDinoConvModel(backbone, position_embeddings) + + # Create input projection layers + num_backbone_outs = len(backbone.intermediate_channel_sizes) + input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = backbone.intermediate_channel_sizes[i] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1), + nn.GroupNorm(32, config.d_model), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, config.d_model), + ) + ) + in_channels = config.d_model + self.input_proj_vision = nn.ModuleList(input_proj_list) + + # Create text backbone + self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False) + self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model) + + if config.embedding_init_target or not config.two_stage: + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = MMGroundingDinoEncoder(config) + self.decoder = MMGroundingDinoDecoder(config) + + self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model)) + + self.enc_output = nn.Linear(config.d_model, config.d_model) + self.enc_output_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps) + self.encoder_output_bbox_embed = MMGroundingDinoMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + self.encoder_output_class_embed = MMGroundingDinoContrastiveEmbedding(config) + + self.post_init() + + +class MMGroundingDinoMLPPredictionHead(GroundingDinoMLPPredictionHead): + pass + + +class MMGroundingDinoForObjectDetection(GroundingDinoForObjectDetection, MMGroundingDinoPreTrainedModel): + _tied_weights_keys = [ + r"bbox_embed\.[1-9]\d*", + r"model\.decoder\.bbox_embed\.[0-9]\d*", + r"class_embed\.[1-9]\d*", + r"model\.decoder\.class_embed\.[0-9]\d*", + ] + + def __init__(self, config: MMGroundingDinoConfig): + MMGroundingDinoPreTrainedModel.__init__(self, config) + + self.model = MMGroundingDinoModel(config) + + self.class_embed = nn.ModuleList( + [MMGroundingDinoContrastiveEmbedding(config) for _ in range(config.decoder_layers)] + ) + + self.bbox_embed = nn.ModuleList( + [ + MMGroundingDinoMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + for _ in range(config.decoder_layers) + ] + ) + + # hack for box-refinement + self.model.decoder.bbox_embed = self.bbox_embed + # hack implementation for two-stage + self.model.decoder.class_embed = self.class_embed + + # Initialize weights and apply final processing + self.post_init() + + +__all__ = [ + "MMGroundingDinoConfig", + "MMGroundingDinoForObjectDetection", + "MMGroundingDinoModel", + "MMGroundingDinoPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bce83216c35546fe01cf96c738cbe9c2f3582486 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 
The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mobilenet_v1 import * + from .feature_extraction_mobilenet_v1 import * + from .image_processing_mobilenet_v1 import * + from .image_processing_mobilenet_v1_fast import * + from .modeling_mobilenet_v1 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..c18d72d55003dc88e8fbcfe2ee80a18975d23dbc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MobileNetV1 model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MobileNetV1Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MobileNetV1Model`]. It is used to instantiate a + MobileNetV1 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MobileNetV1 + [google/mobilenet_v1_1.0_224](https://huggingface.co/google/mobilenet_v1_1.0_224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + depth_multiplier (`float`, *optional*, defaults to 1.0): + Shrinks or expands the number of channels in each layer. Default is 1.0, which starts the network with 32 + channels. 
This is sometimes also called "alpha" or "width multiplier". + min_depth (`int`, *optional*, defaults to 8): + All layers will have at least this many channels. + hidden_act (`str` or `function`, *optional*, defaults to `"relu6"`): + The non-linear activation function (function or string) in the Transformer encoder and convolution layers. + tf_padding (`bool`, *optional*, defaults to `True`): + Whether to use TensorFlow padding rules on the convolution layers. + classifier_dropout_prob (`float`, *optional*, defaults to 0.999): + The dropout ratio for attached classifiers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 0.001): + The epsilon used by the layer normalization layers. + + Example: + + ```python + >>> from transformers import MobileNetV1Config, MobileNetV1Model + + >>> # Initializing a "mobilenet_v1_1.0_224" style configuration + >>> configuration = MobileNetV1Config() + + >>> # Initializing a model from the "mobilenet_v1_1.0_224" style configuration + >>> model = MobileNetV1Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mobilenet_v1" + + def __init__( + self, + num_channels=3, + image_size=224, + depth_multiplier=1.0, + min_depth=8, + hidden_act="relu6", + tf_padding=True, + classifier_dropout_prob=0.999, + initializer_range=0.02, + layer_norm_eps=0.001, + **kwargs, + ): + super().__init__(**kwargs) + + if depth_multiplier <= 0: + raise ValueError("depth_multiplier must be greater than zero.") + + self.num_channels = num_channels + self.image_size = image_size + self.depth_multiplier = depth_multiplier + self.min_depth = min_depth + self.hidden_act = hidden_act + self.tf_padding = tf_padding + self.classifier_dropout_prob = classifier_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + +class MobileNetV1OnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("pixel_values", {0: "batch"})]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "image-classification": + return OrderedDict([("logits", {0: "batch"})]) + else: + return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})]) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["MobileNetV1Config", "MobileNetV1OnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..02a5401bc145996d1126641ee656180a48c92e20 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for MobileNetV1.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_mobilenet_v1 import MobileNetV1ImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class MobileNetV1FeatureExtractor(MobileNetV1ImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class MobileNetV1FeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use MobileNetV1ImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["MobileNetV1FeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa3f443c53b1a315a917a8167b100128b45046f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -0,0 +1,307 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for MobileNetV1.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class MobileNetV1ImageProcessor(BaseImageProcessor): + r""" + Constructs a MobileNetV1 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. 
Can be overridden by `size` in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+            is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
+            `preprocess` method.
+        crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+            Can be overridden by the `crop_size` parameter in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+            `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
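+
+    Example (a minimal sketch of the default 256-resize / 224-crop pipeline):
+
+    ```python
+    >>> import numpy as np
+    >>> from PIL import Image
+    >>> from transformers import MobileNetV1ImageProcessor
+
+    >>> image_processor = MobileNetV1ImageProcessor()
+    >>> image = Image.fromarray(np.zeros((300, 400, 3), dtype=np.uint8))
+    >>> inputs = image_processor(images=image, return_tensors="pt")
+    >>> inputs["pixel_values"].shape
+    torch.Size([1, 3, 224, 224])
+    ```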
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 256} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has + an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
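+            # Callers can pass `input_data_format` explicitly to bypass this inference.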
+ input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["MobileNetV1ImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..e716553a6d102fcd5d32af50f9e28c5397812287 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for MobileNetV1.""" + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + Unpack, +) +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling +from ...utils import auto_docstring + + +@auto_docstring +class MobileNetV1ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"shortest_edge": 256} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + + def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + +__all__ = ["MobileNetV1ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..25997a46790c009bc62e5b4323d378386e3c1718 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py @@ -0,0 +1,414 @@ +# coding=utf-8 +# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MobileNetV1 model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_mobilenet_v1 import MobileNetV1Config
+
+
+logger = logging.get_logger(__name__)
+
+
+def _build_tf_to_pytorch_map(model, config, tf_weights=None):
+    """
+    A map of modules from TF to PyTorch.
+    """
+
+    tf_to_pt_map = {}
+
+    if isinstance(model, MobileNetV1ForImageClassification):
+        backbone = model.mobilenet_v1
+    else:
+        backbone = model
+
+    prefix = "MobilenetV1/Conv2d_0/"
+    tf_to_pt_map[prefix + "weights"] = backbone.conv_stem.convolution.weight
+    tf_to_pt_map[prefix + "BatchNorm/beta"] = backbone.conv_stem.normalization.bias
+    tf_to_pt_map[prefix + "BatchNorm/gamma"] = backbone.conv_stem.normalization.weight
+    tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.normalization.running_mean
+    tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.normalization.running_var
+
+    for i in range(13):
+        tf_index = i + 1
+        pt_index = i * 2
+
+        pointer = backbone.layer[pt_index]
+        prefix = f"MobilenetV1/Conv2d_{tf_index}_depthwise/"
+        tf_to_pt_map[prefix + "depthwise_weights"] = pointer.convolution.weight
+        tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias
+        tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight
+        tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean
+        tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var
+
+        pointer = backbone.layer[pt_index + 1]
+        prefix = f"MobilenetV1/Conv2d_{tf_index}_pointwise/"
+        tf_to_pt_map[prefix + "weights"] = pointer.convolution.weight
+        tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias
+        tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight
+        tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean
+        tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var
+
+    if isinstance(model, MobileNetV1ForImageClassification):
+        prefix = "MobilenetV1/Logits/Conv2d_1c_1x1/"
+        tf_to_pt_map[prefix + "weights"] = model.classifier.weight
+        tf_to_pt_map[prefix + "biases"] = model.classifier.bias
+
+    return tf_to_pt_map
+
+
+def load_tf_weights_in_mobilenet_v1(model, config, tf_checkpoint_path):
+    """Load TensorFlow checkpoints in a PyTorch model."""
+    try:
+        import numpy as np
+        import tensorflow as tf
+    except ImportError:
+        logger.error(
+            "Loading TensorFlow models in PyTorch requires TensorFlow to be installed. Please see "
+            "https://www.tensorflow.org/install/ for installation instructions."
+ ) + raise + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_checkpoint_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_checkpoint_path, name) + tf_weights[name] = array + + # Build TF to PyTorch weights loading map + tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights) + + for name, pointer in tf_to_pt_map.items(): + logger.info(f"Importing {name}") + if name not in tf_weights: + logger.info(f"{name} not in tf pre-trained weights, skipping") + continue + + array = tf_weights[name] + + if "depthwise_weights" in name: + logger.info("Transposing depthwise") + array = np.transpose(array, (2, 3, 0, 1)) + elif "weights" in name: + logger.info("Transposing") + if len(pointer.shape) == 2: # copying into linear layer + array = array.squeeze().transpose() + else: + array = np.transpose(array, (3, 2, 0, 1)) + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name} {array.shape}") + pointer.data = torch.from_numpy(array) + + tf_weights.pop(name, None) + tf_weights.pop(name + "/RMSProp", None) + tf_weights.pop(name + "/RMSProp_1", None) + tf_weights.pop(name + "/ExponentialMovingAverage", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model + + +def apply_tf_padding(features: torch.Tensor, conv_layer: nn.Conv2d) -> torch.Tensor: + """ + Apply TensorFlow-style "SAME" padding to a convolution layer. See the notes at: + https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 + """ + in_height, in_width = features.shape[-2:] + stride_height, stride_width = conv_layer.stride + kernel_height, kernel_width = conv_layer.kernel_size + + if in_height % stride_height == 0: + pad_along_height = max(kernel_height - stride_height, 0) + else: + pad_along_height = max(kernel_height - (in_height % stride_height), 0) + + if in_width % stride_width == 0: + pad_along_width = max(kernel_width - stride_width, 0) + else: + pad_along_width = max(kernel_width - (in_width % stride_width), 0) + + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + + padding = (pad_left, pad_right, pad_top, pad_bottom) + return nn.functional.pad(features, padding, "constant", 0.0) + + +class MobileNetV1ConvLayer(nn.Module): + def __init__( + self, + config: MobileNetV1Config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Optional[int] = 1, + groups: Optional[int] = 1, + bias: bool = False, + use_normalization: Optional[bool] = True, + use_activation: Optional[Union[bool, str]] = True, + ) -> None: + super().__init__() + self.config = config + + if in_channels % groups != 0: + raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.") + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + padding = 0 if config.tf_padding else int((kernel_size - 1) / 2) + + self.convolution = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=bias, + padding_mode="zeros", + ) + + if use_normalization: + self.normalization = nn.BatchNorm2d( + num_features=out_channels, + eps=config.layer_norm_eps, + 
momentum=0.9997, + affine=True, + track_running_stats=True, + ) + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = ACT2FN[use_activation] + elif isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + else: + self.activation = None + + def forward(self, features: torch.Tensor) -> torch.Tensor: + if self.config.tf_padding: + features = apply_tf_padding(features, self.convolution) + features = self.convolution(features) + if self.normalization is not None: + features = self.normalization(features) + if self.activation is not None: + features = self.activation(features) + return features + + +@auto_docstring +class MobileNetV1PreTrainedModel(PreTrainedModel): + config: MobileNetV1Config + load_tf_weights = load_tf_weights_in_mobilenet_v1 + base_model_prefix = "mobilenet_v1" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _no_split_modules = [] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class MobileNetV1Model(MobileNetV1PreTrainedModel): + def __init__(self, config: MobileNetV1Config, add_pooling_layer: bool = True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + depth = 32 + out_channels = max(int(depth * config.depth_multiplier), config.min_depth) + + self.conv_stem = MobileNetV1ConvLayer( + config, + in_channels=config.num_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + ) + + strides = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1] + + self.layer = nn.ModuleList() + for i in range(13): + in_channels = out_channels + + if strides[i] == 2 or i == 0: + depth *= 2 + out_channels = max(int(depth * config.depth_multiplier), config.min_depth) + + self.layer.append( + MobileNetV1ConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=strides[i], + groups=in_channels, + ) + ) + + self.layer.append( + MobileNetV1ConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + ) + + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.conv_stem(pixel_values) + + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + 
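+                # collect the feature map produced after every depthwise/pointwise conv block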
all_hidden_states = all_hidden_states + (hidden_states,) + + last_hidden_state = hidden_states + + if self.pooler is not None: + pooled_output = torch.flatten(self.pooler(last_hidden_state), start_dim=1) + else: + pooled_output = None + + if not return_dict: + return tuple(v for v in [last_hidden_state, pooled_output, all_hidden_states] if v is not None) + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=all_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + MobileNetV1 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class MobileNetV1ForImageClassification(MobileNetV1PreTrainedModel): + def __init__(self, config: MobileNetV1Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilenet_v1 = MobileNetV1Model(config) + + last_hidden_size = self.mobilenet_v1.layer[-1].convolution.out_channels + + # Classifier head + self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True) + self.classifier = nn.Linear(last_hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilenet_v1(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(self.dropout(pooled_output)) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = [ + "MobileNetV1ForImageClassification", + "MobileNetV1Model", + "MobileNetV1PreTrainedModel", + "load_tf_weights_in_mobilenet_v1", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6750449a3eae900890eeaffca5c766ba9cba3339 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mobilevit import * + from .feature_extraction_mobilevit import * + from .image_processing_mobilevit import * + from .image_processing_mobilevit_fast import * + from .modeling_mobilevit import * + from .modeling_tf_mobilevit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/configuration_mobilevit.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/configuration_mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..1bb22e1590b2d0de857c1e5c801311900fa58b7f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/configuration_mobilevit.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MobileViT model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MobileViTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MobileViTModel`]. It is used to instantiate a + MobileViT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MobileViT + [apple/mobilevit-small](https://huggingface.co/apple/mobilevit-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 256): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 2): + The size (resolution) of each patch. + hidden_sizes (`list[int]`, *optional*, defaults to `[144, 192, 240]`): + Dimensionality (hidden size) of the Transformer encoders at each stage. 
+ neck_hidden_sizes (`list[int]`, *optional*, defaults to `[16, 32, 64, 96, 128, 160, 640]`): + The number of channels for the feature maps of the backbone. + num_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`float`, *optional*, defaults to 2.0): + The ratio of the number of channels in the output of the MLP to the number of channels in the input. + expand_ratio (`float`, *optional*, defaults to 4.0): + Expansion factor for the MobileNetv2 layers. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the Transformer encoder and convolution layers. + conv_kernel_size (`int`, *optional*, defaults to 3): + The size of the convolutional kernel in the MobileViT layer. + output_stride (`int`, *optional*, defaults to 32): + The ratio of the spatial resolution of the output to the resolution of the input image. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the Transformer encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for attached classifiers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + aspp_out_channels (`int`, *optional*, defaults to 256): + Number of output channels used in the ASPP layer for semantic segmentation. + atrous_rates (`list[int]`, *optional*, defaults to `[6, 12, 18]`): + Dilation (atrous) factors used in the ASPP layer for semantic segmentation. + aspp_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the ASPP layer for semantic segmentation. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. 
+ + Example: + + ```python + >>> from transformers import MobileViTConfig, MobileViTModel + + >>> # Initializing a mobilevit-small style configuration + >>> configuration = MobileViTConfig() + + >>> # Initializing a model from the mobilevit-small style configuration + >>> model = MobileViTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mobilevit" + + def __init__( + self, + num_channels=3, + image_size=256, + patch_size=2, + hidden_sizes=[144, 192, 240], + neck_hidden_sizes=[16, 32, 64, 96, 128, 160, 640], + num_attention_heads=4, + mlp_ratio=2.0, + expand_ratio=4.0, + hidden_act="silu", + conv_kernel_size=3, + output_stride=32, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.0, + classifier_dropout_prob=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + qkv_bias=True, + aspp_out_channels=256, + atrous_rates=[6, 12, 18], + aspp_dropout_prob=0.1, + semantic_loss_ignore_index=255, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_sizes = hidden_sizes + self.neck_hidden_sizes = neck_hidden_sizes + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.expand_ratio = expand_ratio + self.hidden_act = hidden_act + self.conv_kernel_size = conv_kernel_size + self.output_stride = output_stride + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.classifier_dropout_prob = classifier_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + + # decode head attributes for semantic segmentation + self.aspp_out_channels = aspp_out_channels + self.atrous_rates = atrous_rates + self.aspp_dropout_prob = aspp_dropout_prob + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +class MobileViTOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"})]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "image-classification": + return OrderedDict([("logits", {0: "batch"})]) + else: + return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})]) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["MobileViTConfig", "MobileViTOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..6c220df918647341a289e39ef0f885b80dcd9df3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for MobileViT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_mobilevit import MobileViTImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class MobileViTFeatureExtractor(MobileViTImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class MobileViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use MobileViTImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["MobileViTFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..5411023c31047251b94a97b429a4529ae4241672 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py @@ -0,0 +1,518 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for MobileViT.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class MobileViTImageProcessor(BaseImageProcessor): + r""" + Constructs a MobileViT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Controls the size of the output image after resizing. 
Can be overridden by the `size` parameter in the + `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter + in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the + image is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in + the `preprocess` method. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`): + Desired output size `(size["height"], size["width"])` when applying center-cropping. Can be overridden by + the `crop_size` parameter in the `preprocess` method. + do_flip_channel_order (`bool`, *optional*, defaults to `True`): + Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order` + parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_flip_channel_order: bool = True, + do_reduce_labels: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_flip_channel_order = do_flip_channel_order + self.do_reduce_labels = do_reduce_labels + + # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. 
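+        If `size` instead contains explicit `"height"` and `"width"` keys, the image is resized to exactly that size.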
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            image,
+            size=size,
+            default_to_square=default_to_square,
+            input_data_format=input_data_format,
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def flip_channel_order(
+        self,
+        image: np.ndarray,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Flip the color channels from RGB to BGR or vice versa.
+
+        Args:
+            image (`np.ndarray`):
+                The image, represented as a numpy array.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
+
+    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
+    def reduce_label(self, label: ImageInput) -> np.ndarray:
+        label = to_numpy_array(label)
+        # Avoid underflow: map the background label 0 to 255 before subtracting 1 from all labels
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+        return label
+
+    def __call__(self, images, segmentation_maps=None, **kwargs):
+        """
+        Preprocesses a batch of images and optionally segmentation maps.
+
+        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
+        passed in as positional arguments.
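+
+        A minimal usage sketch (assuming `image` is a `PIL.Image.Image`; the output shape follows the default
+        `crop_size` of 256x256):
+
+        ```python
+        >>> from transformers import MobileViTImageProcessor
+        >>> image_processor = MobileViTImageProcessor()
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> list(inputs["pixel_values"].shape)
+        [1, 3, 256, 256]
+        ```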
+ """ + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: bool, + do_resize: bool, + do_rescale: bool, + do_center_crop: bool, + do_flip_channel_order: bool, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + rescale_factor: Optional[float] = None, + crop_size: Optional[dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_flip_channel_order: + image = self.flip_channel_order(image, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_flip_channel_order: Optional[bool] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + image = self._preprocess( + image=image, + do_reduce_labels=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_flip_channel_order=do_flip_channel_order, + input_data_format=input_data_format, + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_reduce_labels: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] 
+            input_data_format = ChannelDimension.FIRST
+        else:
+            added_channel_dim = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+
+        segmentation_map = self._preprocess(
+            image=segmentation_map,
+            do_reduce_labels=do_reduce_labels,
+            do_resize=do_resize,
+            size=size,
+            resample=PILImageResampling.NEAREST,
+            do_rescale=False,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_flip_channel_order=False,
+            input_data_format=input_data_format,
+        )
+        # Remove extra channel dimension if added for processing
+        if added_channel_dim:
+            segmentation_map = segmentation_map.squeeze(0)
+        segmentation_map = segmentation_map.astype(np.int64)
+        return segmentation_map
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Optional[dict[str, int]] = None,
+        do_flip_channel_order: Optional[bool] = None,
+        do_reduce_labels: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            segmentation_maps (`ImageInput`, *optional*):
+                Segmentation map to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image by the rescale factor.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the center crop if `do_center_crop` is set to `True`.
+            do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
+                Whether to flip the channel order of the image.
+            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+                is used for background, and background itself is not included in all classes of a dataset (e.g.
+                ADE20k). The background label will be replaced by 255.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+        do_flip_channel_order = (
+            do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order
+        )
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size, default_to_square=False)
+        crop_size = crop_size if crop_size is not None else self.crop_size
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+        images = make_flat_list_of_images(images)
+
+        if segmentation_maps is not None:
+            segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if segmentation_maps is not None and not valid_images(segmentation_maps):
+            raise ValueError(
+                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                size=size,
+                resample=resample,
+                do_rescale=do_rescale,
+                rescale_factor=rescale_factor,
+                do_center_crop=do_center_crop,
+                crop_size=crop_size,
+                do_flip_channel_order=do_flip_channel_order,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for img in images
+        ]
+
+        data = {"pixel_values": images}
+
+        if segmentation_maps is not None:
+            segmentation_maps = [
+                self._preprocess_mask(
+                    segmentation_map=segmentation_map,
+                    do_reduce_labels=do_reduce_labels,
+                    do_resize=do_resize,
+                    size=size,
+                    do_center_crop=do_center_crop,
+                    crop_size=crop_size,
+                    input_data_format=input_data_format,
+                )
+                for segmentation_map in segmentation_maps
+            ]
+
+            data["labels"] = segmentation_maps
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
+    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+        """
+        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`MobileViTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+                predictions will not be resized.
+
+        Returns:
+            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+        """
+        # TODO: add support for other frameworks
+        logits = outputs.logits
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.numpy()
+
+            semantic_segmentation = []
+
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
+
+
+__all__ = ["MobileViTImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab16ecfdc878f3b6b9f3ecdae97eb474d8792d6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py
@@ -0,0 +1,246 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for MobileViT."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    SizeDict,
+    is_torch_tensor,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    auto_docstring,
+)
+
+
+class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    """
+    do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
+        Whether to flip the color channels from RGB to BGR or vice versa.
+    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+        is used for background, and background itself is not included in all classes of a dataset (e.g.
+        ADE20k). The background label will be replaced by 255.
+    """
+
+    do_flip_channel_order: Optional[bool]
+    do_reduce_labels: Optional[bool]
+
+
+@auto_docstring
+class MobileViTImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BILINEAR
+    size = {"shortest_edge": 224}
+    default_to_square = False
+    crop_size = {"height": 256, "width": 256}
+    do_resize = True
+    do_center_crop = True
+    do_rescale = True
+    do_normalize = None
+    do_convert_rgb = None
+    do_flip_channel_order = True
+    do_reduce_labels = False
+    valid_kwargs = MobileVitFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
+    def reduce_label(self, labels: list["torch.Tensor"]):
+        for idx in range(len(labels)):
+            label = labels[idx]
+            label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
+            label = label - 1
+            label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
+            labels[idx] = label
+
+        return labels
+
+    @auto_docstring
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        r"""
+        segmentation_maps (`ImageInput`, *optional*):
+            The segmentation maps to preprocess.
+        """
+        return super().preprocess(images, segmentation_maps, **kwargs)
+
+    def _preprocess_image_like_inputs(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput],
+        do_convert_rgb: bool,
+        input_data_format: ChannelDimension,
+        device: Optional[Union[str, "torch.device"]] = None,
+        **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Preprocess image-like inputs.
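+
+        Images and segmentation maps share the same resize/crop settings; for the maps, rescaling and channel
+        flipping are disabled and nearest-neighbor interpolation is used so class ids are preserved.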
+        """
+        images = self._prepare_image_like_inputs(
+            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+        )
+        images_kwargs = kwargs.copy()
+        images_kwargs["do_reduce_labels"] = False
+        batch_feature = self._preprocess(images, **images_kwargs)
+
+        if segmentation_maps is not None:
+            processed_segmentation_maps = self._prepare_image_like_inputs(
+                images=segmentation_maps,
+                expected_ndims=2,
+                do_convert_rgb=False,
+                input_data_format=ChannelDimension.FIRST,
+            )
+
+            segmentation_maps_kwargs = kwargs.copy()
+            segmentation_maps_kwargs.update(
+                {
+                    "do_rescale": False,
+                    "do_flip_channel_order": False,
+                    # Nearest interpolation is used for segmentation maps instead of BILINEAR.
+                    "interpolation": F.InterpolationMode.NEAREST_EXACT,
+                }
+            )
+
+            processed_segmentation_maps = self._preprocess(
+                images=processed_segmentation_maps, **segmentation_maps_kwargs
+            ).pixel_values
+            batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+        return batch_feature
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_reduce_labels: bool,
+        do_resize: bool,
+        size: Optional[SizeDict],
+        interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: Optional[float],
+        do_center_crop: bool,
+        crop_size: Optional[SizeDict],
+        do_flip_channel_order: bool,
+        disable_grouping: bool,
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
+        processed_images = []
+
+        if do_reduce_labels:
+            images = self.reduce_label(images)
+
+        # Group images by shape for more efficient batch processing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+
+        # Process each group of images with the same shape
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+            resized_images_grouped[shape] = stacked_images
+
+        # Reorder images to original sequence
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group again after resizing (in case resize produced different sizes)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(image=stacked_images, size=crop_size)
+            if do_rescale:
+                stacked_images = self.rescale(image=stacked_images, scale=rescale_factor)
+            if do_flip_channel_order:
+                # For batched images, we need to handle them all at once
+                if stacked_images.ndim > 3 and stacked_images.shape[1] >= 3:
+                    # Flip RGB → BGR for batched images
+                    flipped = stacked_images.clone()
+                    flipped[:, 0:3] = stacked_images[:, [2, 1, 0], ...]
+                    stacked_images = flipped
+
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+        # Stack all processed images if return_tensors is specified
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+        """
+        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`MobileViTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+                predictions will not be resized.
+
+        Returns:
+            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+        """
+        logits = outputs.logits
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.numpy()
+
+            semantic_segmentation = []
+
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
+
+
+__all__ = ["MobileViTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_mobilevit.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..415c33a7cb8527a9df0439a62fc5026cd5d92e00
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_mobilevit.py
@@ -0,0 +1,998 @@
+# coding=utf-8
+# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
+"""PyTorch MobileViT model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutputWithNoAttention,
+    BaseModelOutputWithPoolingAndNoAttention,
+    ImageClassifierOutputWithNoAttention,
+    SemanticSegmenterOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging, torch_int
+from .configuration_mobilevit import MobileViTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
+    """
+    Ensure that all layers have a channel count that is divisible by `divisor`.
This function is taken from the + original TensorFlow repo. It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_value < 0.9 * value: + new_value += divisor + return int(new_value) + + +class MobileViTConvLayer(nn.Module): + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = False, + dilation: int = 1, + use_normalization: bool = True, + use_activation: Union[bool, str] = True, + ) -> None: + super().__init__() + padding = int((kernel_size - 1) / 2) * dilation + + if in_channels % groups != 0: + raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.") + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + self.convolution = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode="zeros", + ) + + if use_normalization: + self.normalization = nn.BatchNorm2d( + num_features=out_channels, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ) + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = ACT2FN[use_activation] + elif isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + else: + self.activation = None + + def forward(self, features: torch.Tensor) -> torch.Tensor: + features = self.convolution(features) + if self.normalization is not None: + features = self.normalization(features) + if self.activation is not None: + features = self.activation(features) + return features + + +class MobileViTInvertedResidual(nn.Module): + """ + Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381 + """ + + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1 + ) -> None: + super().__init__() + expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) + + if stride not in [1, 2]: + raise ValueError(f"Invalid stride {stride}.") + + self.use_residual = (stride == 1) and (in_channels == out_channels) + + self.expand_1x1 = MobileViTConvLayer( + config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1 + ) + + self.conv_3x3 = MobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=expanded_channels, + kernel_size=3, + stride=stride, + groups=expanded_channels, + dilation=dilation, + ) + + self.reduce_1x1 = MobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=out_channels, + kernel_size=1, + use_activation=False, + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + + features = self.expand_1x1(features) + features = self.conv_3x3(features) + features = self.reduce_1x1(features) + + return residual + features if self.use_residual else features + + +class MobileViTMobileNetLayer(nn.Module): + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1 + ) -> None: + 
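+        # a stack of MobileNetV2 inverted-residual blocks; only the first block applies the stride (downsampling)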
super().__init__() + + self.layer = nn.ModuleList() + for i in range(num_stages): + layer = MobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + ) + self.layer.append(layer) + in_channels = out_channels + + def forward(self, features: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + features = layer_module(features) + return features + + +class MobileViTSelfAttention(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + + if hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
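+        # ("Attention Is All You Need", https://arxiv.org/abs/1706.03762)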
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class MobileViTSelfOutput(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class MobileViTAttention(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + self.attention = MobileViTSelfAttention(config, hidden_size) + self.output = MobileViTSelfOutput(config, hidden_size) + self.pruned_heads = set() + + def prune_heads(self, heads: set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + self_outputs = self.attention(hidden_states) + attention_output = self.output(self_outputs) + return attention_output + + +class MobileViTIntermediate(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class MobileViTOutput(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class MobileViTTransformerLayer(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.attention = MobileViTAttention(config, hidden_size) + self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size) + self.output = MobileViTOutput(config, hidden_size, 
intermediate_size) + self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states)) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + return layer_output + + +class MobileViTTransformer(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None: + super().__init__() + + self.layer = nn.ModuleList() + for _ in range(num_stages): + transformer_layer = MobileViTTransformerLayer( + config, + hidden_size=hidden_size, + intermediate_size=int(hidden_size * config.mlp_ratio), + ) + self.layer.append(transformer_layer) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class MobileViTLayer(GradientCheckpointingLayer): + """ + MobileViT block: https://huggingface.co/papers/2110.02178 + """ + + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int, + hidden_size: int, + num_stages: int, + dilation: int = 1, + ) -> None: + super().__init__() + self.patch_width = config.patch_size + self.patch_height = config.patch_size + + if stride == 2: + self.downsampling_layer = MobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if dilation == 1 else 1, + dilation=dilation // 2 if dilation > 1 else 1, + ) + in_channels = out_channels + else: + self.downsampling_layer = None + + self.conv_kxk = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + ) + + self.conv_1x1 = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=1, + use_normalization=False, + use_activation=False, + ) + + self.transformer = MobileViTTransformer( + config, + hidden_size=hidden_size, + num_stages=num_stages, + ) + + self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + self.conv_projection = MobileViTConvLayer( + config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1 + ) + + self.fusion = MobileViTConvLayer( + config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size + ) + + def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size, channels, orig_height, orig_width = features.shape + + new_height = ( + torch_int(torch.ceil(orig_height / patch_height) * patch_height) + if torch.jit.is_tracing() + else int(math.ceil(orig_height / patch_height) * patch_height) + ) + new_width = ( + torch_int(torch.ceil(orig_width / patch_width) * patch_width) + if torch.jit.is_tracing() + else int(math.ceil(orig_width / patch_width) * patch_width) + ) + + interpolate = False + if new_width != orig_width or new_height != orig_height: + # Note: Padding can be done, but then it needs to be handled in attention function. 
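+            # instead, bilinearly resize so height and width become exact multiples of the patch size before unfolding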
+ features = nn.functional.interpolate( + features, size=(new_height, new_width), mode="bilinear", align_corners=False + ) + interpolate = True + + # number of patches along width and height + num_patch_width = new_width // patch_width + num_patch_height = new_height // patch_height + num_patches = num_patch_height * num_patch_width + + # convert from shape (batch_size, channels, orig_height, orig_width) + # to the shape (batch_size * patch_area, num_patches, channels) + patches = features.reshape( + batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width + ) + patches = patches.transpose(1, 2) + patches = patches.reshape(batch_size, channels, num_patches, patch_area) + patches = patches.transpose(1, 3) + patches = patches.reshape(batch_size * patch_area, num_patches, -1) + + info_dict = { + "orig_size": (orig_height, orig_width), + "batch_size": batch_size, + "channels": channels, + "interpolate": interpolate, + "num_patches": num_patches, + "num_patches_width": num_patch_width, + "num_patches_height": num_patch_height, + } + return patches, info_dict + + def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size = info_dict["batch_size"] + channels = info_dict["channels"] + num_patches = info_dict["num_patches"] + num_patch_height = info_dict["num_patches_height"] + num_patch_width = info_dict["num_patches_width"] + + # convert from shape (batch_size * patch_area, num_patches, channels) + # back to shape (batch_size, channels, orig_height, orig_width) + features = patches.contiguous().view(batch_size, patch_area, num_patches, -1) + features = features.transpose(1, 3) + features = features.reshape( + batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width + ) + features = features.transpose(1, 2) + features = features.reshape( + batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width + ) + + if info_dict["interpolate"]: + features = nn.functional.interpolate( + features, size=info_dict["orig_size"], mode="bilinear", align_corners=False + ) + + return features + + def forward(self, features: torch.Tensor) -> torch.Tensor: + # reduce spatial dimensions if needed + if self.downsampling_layer: + features = self.downsampling_layer(features) + + residual = features + + # local representation + features = self.conv_kxk(features) + features = self.conv_1x1(features) + + # convert feature map to patches + patches, info_dict = self.unfolding(features) + + # learn global representations + patches = self.transformer(patches) + patches = self.layernorm(patches) + + # convert patches back to feature maps + features = self.folding(patches, info_dict) + + features = self.conv_projection(features) + features = self.fusion(torch.cat((residual, features), dim=1)) + return features + + +class MobileViTEncoder(nn.Module): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + self.config = config + + self.layer = nn.ModuleList() + self.gradient_checkpointing = False + + # segmentation architectures like DeepLab and PSPNet modify the strides + # of the classification backbones + dilate_layer_4 = dilate_layer_5 = False + if config.output_stride == 8: + dilate_layer_4 = True + dilate_layer_5 = True + elif config.output_stride == 16: + dilate_layer_5 = True + + dilation = 1 + + layer_1 = MobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[0], 
+ out_channels=config.neck_hidden_sizes[1], + stride=1, + num_stages=1, + ) + self.layer.append(layer_1) + + layer_2 = MobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[1], + out_channels=config.neck_hidden_sizes[2], + stride=2, + num_stages=3, + ) + self.layer.append(layer_2) + + layer_3 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[2], + out_channels=config.neck_hidden_sizes[3], + stride=2, + hidden_size=config.hidden_sizes[0], + num_stages=2, + ) + self.layer.append(layer_3) + + if dilate_layer_4: + dilation *= 2 + + layer_4 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[3], + out_channels=config.neck_hidden_sizes[4], + stride=2, + hidden_size=config.hidden_sizes[1], + num_stages=4, + dilation=dilation, + ) + self.layer.append(layer_4) + + if dilate_layer_5: + dilation *= 2 + + layer_5 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[4], + out_channels=config.neck_hidden_sizes[5], + stride=2, + hidden_size=config.hidden_sizes[2], + num_stages=3, + dilation=dilation, + ) + self.layer.append(layer_5) + + def forward( + self, + hidden_states: torch.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutputWithNoAttention]: + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +@auto_docstring +class MobileViTPreTrainedModel(PreTrainedModel): + config: MobileViTConfig + base_model_prefix = "mobilevit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MobileViTLayer"] + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class MobileViTModel(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, expand_output: bool = True): + r""" + expand_output (`bool`, *optional*, defaults to `True`): + Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional + 1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`. 
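        Example (an editor's sketch with randomly initialized weights, not part of the original patch; shapes follow the default `MobileViTConfig`):

        ```python
        >>> import torch
        >>> from transformers import MobileViTConfig, MobileViTModel

        >>> config = MobileViTConfig()
        >>> model = MobileViTModel(config, expand_output=True)

        >>> pixel_values = torch.randn(1, config.num_channels, 256, 256)
        >>> with torch.no_grad():
        ...     outputs = model(pixel_values)

        >>> # the expanded feature map has neck_hidden_sizes[6] channels
        >>> outputs.last_hidden_state.shape[1] == config.neck_hidden_sizes[6]
        True
        ```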
+ """ + super().__init__(config) + self.config = config + self.expand_output = expand_output + + self.conv_stem = MobileViTConvLayer( + config, + in_channels=config.num_channels, + out_channels=config.neck_hidden_sizes[0], + kernel_size=3, + stride=2, + ) + + self.encoder = MobileViTEncoder(config) + + if self.expand_output: + self.conv_1x1_exp = MobileViTConvLayer( + config, + in_channels=config.neck_hidden_sizes[5], + out_channels=config.neck_hidden_sizes[6], + kernel_size=1, + ) + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel + """ + for layer_index, heads in heads_to_prune.items(): + mobilevit_layer = self.encoder.layer[layer_index] + if isinstance(mobilevit_layer, MobileViTLayer): + for transformer_layer in mobilevit_layer.transformer.layer: + transformer_layer.attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.conv_stem(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.expand_output: + last_hidden_state = self.conv_1x1_exp(encoder_outputs[0]) + + # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) + pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False) + else: + last_hidden_state = encoder_outputs[0] + pooled_output = None + + if not return_dict: + output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) + return output + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class MobileViTForImageClassification(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevit = MobileViTModel(config) + + # Classifier head + self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True) + self.classifier = ( + nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(self.dropout(pooled_output)) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +class MobileViTASPPPooling(nn.Module): + def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None: + super().__init__() + + self.global_pool = nn.AdaptiveAvgPool2d(output_size=1) + + self.conv_1x1 = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_normalization=True, + use_activation="relu", + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + spatial_size = features.shape[-2:] + features = self.global_pool(features) + features = self.conv_1x1(features) + features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False) + return features + + +class MobileViTASPP(nn.Module): + """ + ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + + in_channels = config.neck_hidden_sizes[-2] + out_channels = config.aspp_out_channels + + if len(config.atrous_rates) != 3: + raise ValueError("Expected 3 values for atrous_rates") + + self.convs = nn.ModuleList() + + in_projection = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + ) + self.convs.append(in_projection) + + self.convs.extend( + [ + MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=rate, + use_activation="relu", + ) + for rate in config.atrous_rates + ] + ) + + pool_layer = MobileViTASPPPooling(config, in_channels, out_channels) + self.convs.append(pool_layer) + + self.project = MobileViTConvLayer( + config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu" + ) + + self.dropout = nn.Dropout(p=config.aspp_dropout_prob) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + pyramid = [] + for conv in self.convs: + pyramid.append(conv(features)) + pyramid = torch.cat(pyramid, dim=1) + + pooled_features = self.project(pyramid) + pooled_features = self.dropout(pooled_features) + return pooled_features + + +class MobileViTDeepLabV3(nn.Module): + """ + DeepLabv3 architecture: https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + self.aspp = MobileViTASPP(config) + + self.dropout = nn.Dropout2d(config.classifier_dropout_prob) + + self.classifier = MobileViTConvLayer( + config, + in_channels=config.aspp_out_channels, + out_channels=config.num_labels, + kernel_size=1, + use_normalization=False, + 
use_activation=False, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + features = self.aspp(hidden_states[-1]) + features = self.dropout(features) + features = self.classifier(features) + return features + + +@auto_docstring( + custom_intro=""" + MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC. + """ +) +class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevit = MobileViTModel(config, expand_output=False) + self.segmentation_head = MobileViTDeepLabV3(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small") + >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**inputs) + + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.mobilevit( + pixel_values, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.segmentation_head(encoder_hidden_states) + + loss = None + if labels is not None: + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + loss = loss_fct(upsampled_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = [ + "MobileViTForImageClassification", + "MobileViTForSemanticSegmentation", + "MobileViTModel", + "MobileViTPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..dcad0f302a8ebfdc0679f140eb6dd7e139f215c2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -0,0 +1,1376 @@ +# coding=utf-8 +# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
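# Editor's sketch (not part of the patch): the upsample-then-cross-entropy pattern
# used by `MobileViTForSemanticSegmentation.forward` above. The 21 classes, the
# 256x256 label maps, and ignore index 255 are illustrative assumptions.
import torch
from torch import nn

logits = torch.randn(2, 21, 32, 32)           # (batch, num_labels, h, w) from the segmentation head
labels = torch.randint(0, 21, (2, 256, 256))  # ground-truth maps at the original image size
labels[:, :8, :8] = 255                       # e.g. unlabeled border pixels

# upsample logits to the labels' resolution before computing the loss
upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
loss = nn.CrossEntropyLoss(ignore_index=255)(upsampled_logits, labels)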
+# +# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE +"""TensorFlow 2.0 MobileViT model.""" + +from __future__ import annotations + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFImageClassifierOutputWithNoAttention, + TFSemanticSegmenterOutputWithNoAttention, +) +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import logging +from .configuration_mobilevit import MobileViTConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "MobileViTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "apple/mobilevit-small" +_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int: + """ + Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the + original TensorFlow repo. It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_value < 0.9 * value: + new_value += divisor + return int(new_value) + + +class TFMobileViTConvLayer(keras.layers.Layer): + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = False, + dilation: int = 1, + use_normalization: bool = True, + use_activation: bool | str = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + logger.warning( + f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. 
If you wish " + "to train/fine-tune this model, you need a GPU or a TPU" + ) + + padding = int((kernel_size - 1) / 2) * dilation + self.padding = keras.layers.ZeroPadding2D(padding) + + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + self.convolution = keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + padding="VALID", + dilation_rate=dilation, + groups=groups, + use_bias=bias, + name="convolution", + ) + + if use_normalization: + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = get_tf_activation(use_activation) + elif isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + else: + self.activation = None + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + padded_features = self.padding(features) + features = self.convolution(padded_features) + if self.normalization is not None: + features = self.normalization(features, training=training) + if self.activation is not None: + features = self.activation(features) + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + if hasattr(self.normalization, "name"): + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFMobileViTInvertedResidual(keras.layers.Layer): + """ + Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381 + """ + + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs + ) -> None: + super().__init__(**kwargs) + expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) + + if stride not in [1, 2]: + raise ValueError(f"Invalid stride {stride}.") + + self.use_residual = (stride == 1) and (in_channels == out_channels) + + self.expand_1x1 = TFMobileViTConvLayer( + config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1" + ) + + self.conv_3x3 = TFMobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=expanded_channels, + kernel_size=3, + stride=stride, + groups=expanded_channels, + dilation=dilation, + name="conv_3x3", + ) + + self.reduce_1x1 = TFMobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=out_channels, + kernel_size=1, + use_activation=False, + name="reduce_1x1", + ) + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = features + + features = self.expand_1x1(features, training=training) + features = self.conv_3x3(features, training=training) + features = self.reduce_1x1(features, training=training) + + return residual + features if self.use_residual else features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "expand_1x1", None) is not None: + with tf.name_scope(self.expand_1x1.name): + self.expand_1x1.build(None) + if 
getattr(self, "conv_3x3", None) is not None: + with tf.name_scope(self.conv_3x3.name): + self.conv_3x3.build(None) + if getattr(self, "reduce_1x1", None) is not None: + with tf.name_scope(self.reduce_1x1.name): + self.reduce_1x1.build(None) + + +class TFMobileViTMobileNetLayer(keras.layers.Layer): + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int = 1, + num_stages: int = 1, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.layers = [] + for i in range(num_stages): + layer = TFMobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + name=f"layer.{i}", + ) + self.layers.append(layer) + in_channels = out_channels + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + for layer_module in self.layers: + features = layer_module(features, training=training) + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer_module in self.layers: + with tf.name_scope(layer_module.name): + layer_module.build(None) + + +class TFMobileViTSelfAttention(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + + if hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + scale = tf.cast(self.attention_head_size, dtype=tf.float32) + self.scale = tf.math.sqrt(scale) + + self.query = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query") + self.key = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key") + self.value = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value") + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.hidden_size = hidden_size + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + batch_size = tf.shape(x)[0] + x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + batch_size = tf.shape(hidden_states)[0] + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + attention_scores = attention_scores / self.scale + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs, training=training) + + context_layer = tf.matmul(attention_probs, value_layer) + + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size)) + return context_layer + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.hidden_size]) + + +class TFMobileViTSelfOutput(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(hidden_size, name="dense") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.hidden_size = hidden_size + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.hidden_size]) + + +class TFMobileViTAttention(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention") + self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + self_outputs = self.attention(hidden_states, training=training) + attention_output = self.dense_output(self_outputs, training=training) + return attention_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFMobileViTIntermediate(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(intermediate_size, name="dense") + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.hidden_size = hidden_size + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.hidden_size]) + + +class TFMobileViTOutput(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: 
+ super().__init__(**kwargs) + self.dense = keras.layers.Dense(hidden_size, name="dense") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.intermediate_size = intermediate_size + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = hidden_states + input_tensor + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.intermediate_size]) + + +class TFMobileViTTransformerLayer(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.attention = TFMobileViTAttention(config, hidden_size, name="attention") + self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate") + self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output") + self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + self.hidden_size = hidden_size + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states), training=training) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.mobilevit_output(layer_output, hidden_states, training=training) + return layer_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "mobilevit_output", None) is not None: + with tf.name_scope(self.mobilevit_output.name): + self.mobilevit_output.build(None) + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.hidden_size]) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.hidden_size]) + + +class TFMobileViTTransformer(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None: + super().__init__(**kwargs) + + self.layers = [] + for i in range(num_stages): + transformer_layer = TFMobileViTTransformerLayer( + config, + hidden_size=hidden_size, + intermediate_size=int(hidden_size * config.mlp_ratio), + name=f"layer.{i}", + ) + self.layers.append(transformer_layer) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + for layer_module in self.layers: + hidden_states = layer_module(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer_module in self.layers: + with 
tf.name_scope(layer_module.name): + layer_module.build(None) + + +class TFMobileViTLayer(keras.layers.Layer): + """ + MobileViT block: https://huggingface.co/papers/2110.02178 + """ + + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int, + hidden_size: int, + num_stages: int, + dilation: int = 1, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.patch_width = config.patch_size + self.patch_height = config.patch_size + + if stride == 2: + self.downsampling_layer = TFMobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if dilation == 1 else 1, + dilation=dilation // 2 if dilation > 1 else 1, + name="downsampling_layer", + ) + in_channels = out_channels + else: + self.downsampling_layer = None + + self.conv_kxk = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + name="conv_kxk", + ) + + self.conv_1x1 = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=1, + use_normalization=False, + use_activation=False, + name="conv_1x1", + ) + + self.transformer = TFMobileViTTransformer( + config, hidden_size=hidden_size, num_stages=num_stages, name="transformer" + ) + + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + self.conv_projection = TFMobileViTConvLayer( + config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection" + ) + + self.fusion = TFMobileViTConvLayer( + config, + in_channels=2 * in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + name="fusion", + ) + self.hidden_size = hidden_size + + def unfolding(self, features: tf.Tensor) -> tuple[tf.Tensor, dict]: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = tf.cast(patch_width * patch_height, "int32") + + batch_size = tf.shape(features)[0] + orig_height = tf.shape(features)[1] + orig_width = tf.shape(features)[2] + channels = tf.shape(features)[3] + + new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") + new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") + + interpolate = new_width != orig_width or new_height != orig_height + if interpolate: + # Note: Padding can be done, but then it needs to be handled in attention function. 
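# Editor's sketch (not part of the patch): rounding dynamic NHWC spatial dims up to
# a multiple of the patch size, as the TF `unfolding` above is about to do. The
# 2x2 patch and 5x7 input are illustrative assumptions.
import tensorflow as tf

features = tf.random.normal((1, 5, 7, 4))  # NHWC, as the Keras Conv2D layers produce
patch_height = patch_width = 2

orig_height = tf.shape(features)[1]
orig_width = tf.shape(features)[2]
new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32")
new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32")

# 5x7 -> 6x8, so every pixel falls into exactly one 2x2 patch
features = tf.image.resize(features, size=(new_height, new_width), method="bilinear")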
+ features = tf.image.resize(features, size=(new_height, new_width), method="bilinear") + + # number of patches along width and height + num_patch_width = new_width // patch_width + num_patch_height = new_height // patch_height + num_patches = num_patch_height * num_patch_width + + # convert from shape (batch_size, orig_height, orig_width, channels) + # to the shape (batch_size * patch_area, num_patches, channels) + features = tf.transpose(features, [0, 3, 1, 2]) + patches = tf.reshape( + features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width) + ) + patches = tf.transpose(patches, [0, 2, 1, 3]) + patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area)) + patches = tf.transpose(patches, [0, 3, 2, 1]) + patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels)) + + info_dict = { + "orig_size": (orig_height, orig_width), + "batch_size": batch_size, + "channels": channels, + "interpolate": interpolate, + "num_patches": num_patches, + "num_patches_width": num_patch_width, + "num_patches_height": num_patch_height, + } + return patches, info_dict + + def folding(self, patches: tf.Tensor, info_dict: dict) -> tf.Tensor: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size = info_dict["batch_size"] + channels = info_dict["channels"] + num_patches = info_dict["num_patches"] + num_patch_height = info_dict["num_patches_height"] + num_patch_width = info_dict["num_patches_width"] + + # convert from shape (batch_size * patch_area, num_patches, channels) + # back to shape (batch_size, channels, orig_height, orig_width) + features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1)) + features = tf.transpose(features, perm=(0, 3, 2, 1)) + features = tf.reshape( + features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width) + ) + features = tf.transpose(features, perm=(0, 2, 1, 3)) + features = tf.reshape( + features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width) + ) + features = tf.transpose(features, perm=(0, 2, 3, 1)) + + if info_dict["interpolate"]: + features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear") + + return features + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + # reduce spatial dimensions if needed + if self.downsampling_layer: + features = self.downsampling_layer(features, training=training) + + residual = features + + # local representation + features = self.conv_kxk(features, training=training) + features = self.conv_1x1(features, training=training) + + # convert feature map to patches + patches, info_dict = self.unfolding(features) + + # learn global representations + patches = self.transformer(patches, training=training) + patches = self.layernorm(patches) + + # convert patches back to feature maps + features = self.folding(patches, info_dict) + + features = self.conv_projection(features, training=training) + features = self.fusion(tf.concat([residual, features], axis=-1), training=training) + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv_kxk", None) is not None: + with tf.name_scope(self.conv_kxk.name): + self.conv_kxk.build(None) + if getattr(self, "conv_1x1", None) is not None: + with tf.name_scope(self.conv_1x1.name): + self.conv_1x1.build(None) + if getattr(self, "transformer", None) is not None: + 
with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.hidden_size]) + if getattr(self, "conv_projection", None) is not None: + with tf.name_scope(self.conv_projection.name): + self.conv_projection.build(None) + if getattr(self, "fusion", None) is not None: + with tf.name_scope(self.fusion.name): + self.fusion.build(None) + if getattr(self, "downsampling_layer", None) is not None: + with tf.name_scope(self.downsampling_layer.name): + self.downsampling_layer.build(None) + + +class TFMobileViTEncoder(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + self.layers = [] + + # segmentation architectures like DeepLab and PSPNet modify the strides + # of the classification backbones + dilate_layer_4 = dilate_layer_5 = False + if config.output_stride == 8: + dilate_layer_4 = True + dilate_layer_5 = True + elif config.output_stride == 16: + dilate_layer_5 = True + + dilation = 1 + + layer_1 = TFMobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[0], + out_channels=config.neck_hidden_sizes[1], + stride=1, + num_stages=1, + name="layer.0", + ) + self.layers.append(layer_1) + + layer_2 = TFMobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[1], + out_channels=config.neck_hidden_sizes[2], + stride=2, + num_stages=3, + name="layer.1", + ) + self.layers.append(layer_2) + + layer_3 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[2], + out_channels=config.neck_hidden_sizes[3], + stride=2, + hidden_size=config.hidden_sizes[0], + num_stages=2, + name="layer.2", + ) + self.layers.append(layer_3) + + if dilate_layer_4: + dilation *= 2 + + layer_4 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[3], + out_channels=config.neck_hidden_sizes[4], + stride=2, + hidden_size=config.hidden_sizes[1], + num_stages=4, + dilation=dilation, + name="layer.3", + ) + self.layers.append(layer_4) + + if dilate_layer_5: + dilation *= 2 + + layer_5 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[4], + out_channels=config.neck_hidden_sizes[5], + stride=2, + hidden_size=config.hidden_sizes[2], + num_stages=3, + dilation=dilation, + name="layer.4", + ) + self.layers.append(layer_5) + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> tuple | TFBaseModelOutput: + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layers): + hidden_states = layer_module(hidden_states, training=training) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer_module in self.layers: + with tf.name_scope(layer_module.name): + layer_module.build(None) + + +@keras_serializable +class TFMobileViTMainLayer(keras.layers.Layer): + config_class = MobileViTConfig + + def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs): + super().__init__(**kwargs) + self.config = config + 
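# Editor's sketch (not part of the patch): the `output_stride` -> dilation mapping
# implemented by the encoder above. For segmentation backbones the last stages keep
# their resolution (stride 1) and dilate their convolutions instead.
def dilation_plan(output_stride: int) -> dict[str, int]:
    dilate_layer_4 = dilate_layer_5 = False
    if output_stride == 8:
        dilate_layer_4 = dilate_layer_5 = True
    elif output_stride == 16:
        dilate_layer_5 = True
    dilation, plan = 1, {}
    if dilate_layer_4:
        dilation *= 2
    plan["layer_4"] = dilation
    if dilate_layer_5:
        dilation *= 2
    plan["layer_5"] = dilation
    return plan

assert dilation_plan(32) == {"layer_4": 1, "layer_5": 1}  # plain classification backbone
assert dilation_plan(16) == {"layer_4": 1, "layer_5": 2}
assert dilation_plan(8) == {"layer_4": 2, "layer_5": 4}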
self.expand_output = expand_output + + self.conv_stem = TFMobileViTConvLayer( + config, + in_channels=config.num_channels, + out_channels=config.neck_hidden_sizes[0], + kernel_size=3, + stride=2, + name="conv_stem", + ) + + self.encoder = TFMobileViTEncoder(config, name="encoder") + + if self.expand_output: + self.conv_1x1_exp = TFMobileViTConvLayer( + config, + in_channels=config.neck_hidden_sizes[5], + out_channels=config.neck_hidden_sizes[6], + kernel_size=1, + name="conv_1x1_exp", + ) + + self.pooler = keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler") + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + embedding_output = self.conv_stem(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + if self.expand_output: + last_hidden_state = self.conv_1x1_exp(encoder_outputs[0]) + + # Change to NCHW output format to have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) + + # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) + pooled_output = self.pooler(last_hidden_state) + else: + last_hidden_state = encoder_outputs[0] + # Change to NCHW output format to have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) + pooled_output = None + + if not return_dict: + output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) + + # Change to NCHW output format to have uniformity in the modules + if not self.expand_output: + remaining_encoder_outputs = encoder_outputs[1:] + remaining_encoder_outputs = tuple( + tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0] + ) + remaining_encoder_outputs = (remaining_encoder_outputs,) + return output + remaining_encoder_outputs + else: + return output + encoder_outputs[1:] + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv_stem", None) is not None: + with tf.name_scope(self.conv_stem.name): + self.conv_stem.build(None) + if getattr(self, "encoder", None) is not None: + with 
tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build([None, None, None, None]) + if getattr(self, "conv_1x1_exp", None) is not None: + with tf.name_scope(self.conv_1x1_exp.name): + self.conv_1x1_exp.build(None) + + +class TFMobileViTPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MobileViTConfig + base_model_prefix = "mobilevit" + main_input_name = "pixel_values" + + +MOBILEVIT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MOBILEVIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`MobileViTImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. +""" + + +@add_start_docstrings( + "The bare MobileViT model outputting raw hidden-states without any specific head on top.", + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTModel(TFMobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.config = config + self.expand_output = expand_output + + self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit") + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling: + output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training) + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mobilevit", None) is not None: + with tf.name_scope(self.mobilevit.name): + self.mobilevit.build(None) + + +@add_start_docstrings( + """ + MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit") + + # Classifier head + self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) + self.classifier = ( + keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + labels: tf.Tensor | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFImageClassifierOutputWithNoAttention: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
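        Example (an editor's sketch, not part of the original patch; the checkpoint matches `_IMAGE_CLASS_CHECKPOINT` above):

        ```python
        >>> import tensorflow as tf
        >>> from transformers import TFMobileViTForImageClassification

        >>> model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-small")
        >>> pixel_values = tf.random.uniform((1, 3, 256, 256))  # NCHW; the main layer transposes to NHWC
        >>> logits = model(pixel_values).logits
        >>> logits.shape
        TensorShape([1, 1000])
        ```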
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilevit( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(self.dropout(pooled_output, training=training)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mobilevit", None) is not None: + with tf.name_scope(self.mobilevit.name): + self.mobilevit.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]]) + + +class TFMobileViTASPPPooling(keras.layers.Layer): + def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None: + super().__init__(**kwargs) + + self.global_pool = keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool") + + self.conv_1x1 = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_normalization=True, + use_activation="relu", + name="conv_1x1", + ) + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + spatial_size = shape_list(features)[1:-1] + features = self.global_pool(features) + features = self.conv_1x1(features, training=training) + features = tf.image.resize(features, size=spatial_size, method="bilinear") + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "global_pool", None) is not None: + with tf.name_scope(self.global_pool.name): + self.global_pool.build([None, None, None, None]) + if getattr(self, "conv_1x1", None) is not None: + with tf.name_scope(self.conv_1x1.name): + self.conv_1x1.build(None) + + +class TFMobileViTASPP(keras.layers.Layer): + """ + ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + + in_channels = config.neck_hidden_sizes[-2] + out_channels = config.aspp_out_channels + + if len(config.atrous_rates) != 3: + raise ValueError("Expected 3 values for atrous_rates") + + self.convs = [] + + in_projection = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + name="convs.0", + ) + self.convs.append(in_projection) + + self.convs.extend( + [ + TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=rate, + use_activation="relu", + name=f"convs.{i + 1}", + ) + for i, rate in enumerate(config.atrous_rates) + ] + ) + + pool_layer = TFMobileViTASPPPooling( + config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}" + ) + self.convs.append(pool_layer) + + self.project = TFMobileViTConvLayer( + config, + in_channels=5 * out_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + 
name="project", + ) + + self.dropout = keras.layers.Dropout(config.aspp_dropout_prob) + + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: + # since the hidden states were transposed to have `(batch_size, channels, height, width)` + # layout we transpose them back to have `(batch_size, height, width, channels)` layout. + features = tf.transpose(features, perm=[0, 2, 3, 1]) + pyramid = [] + for conv in self.convs: + pyramid.append(conv(features, training=training)) + pyramid = tf.concat(pyramid, axis=-1) + + pooled_features = self.project(pyramid, training=training) + pooled_features = self.dropout(pooled_features, training=training) + return pooled_features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "project", None) is not None: + with tf.name_scope(self.project.name): + self.project.build(None) + if getattr(self, "convs", None) is not None: + for conv in self.convs: + with tf.name_scope(conv.name): + conv.build(None) + + +class TFMobileViTDeepLabV3(keras.layers.Layer): + """ + DeepLabv3 architecture: https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.aspp = TFMobileViTASPP(config, name="aspp") + + self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) + + self.classifier = TFMobileViTConvLayer( + config, + in_channels=config.aspp_out_channels, + out_channels=config.num_labels, + kernel_size=1, + use_normalization=False, + use_activation=False, + bias=True, + name="classifier", + ) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + features = self.aspp(hidden_states[-1], training=training) + features = self.dropout(features, training=training) + features = self.classifier(features, training=training) + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "aspp", None) is not None: + with tf.name_scope(self.aspp.name): + self.aspp.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC. 
+ """, + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.num_labels = config.num_labels + self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit") + self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head") + + def hf_compute_loss(self, logits, labels): + # upsample logits to the images' original size + # `labels` is of shape (batch_size, height, width) + label_interp_shape = shape_list(labels)[1:] + + upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") + # compute weighted loss + loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") + + def masked_loss(real, pred): + unmasked_loss = loss_fct(real, pred) + mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * mask + # Reduction strategy in the similar spirit with + # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210 + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask) + return tf.reshape(reduced_masked_loss, (1,)) + + return masked_loss(labels, upsampled_logits) + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple | TFSemanticSegmenterOutputWithNoAttention: + r""" + labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFMobileViTForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small") + >>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small") + + >>> inputs = image_processor(images=image, return_tensors="tf") + + >>> outputs = model(**inputs) + + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and not self.config.num_labels > 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.mobilevit( + pixel_values, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + training=training, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.segmentation_head(encoder_hidden_states, training=training) + + loss = None + if labels is not None: + loss = self.hf_compute_loss(logits=logits, labels=labels) + + # make logits of shape (batch_size, num_labels, height, width) to + # keep them consistent across APIs + logits = tf.transpose(logits, perm=[0, 3, 1, 2]) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSemanticSegmenterOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mobilevit", None) is not None: + with tf.name_scope(self.mobilevit.name): + self.mobilevit.build(None) + if getattr(self, "segmentation_head", None) is not None: + with tf.name_scope(self.segmentation_head.name): + self.segmentation_head.build(None) + + +__all__ = [ + "TFMobileViTForImageClassification", + "TFMobileViTForSemanticSegmentation", + "TFMobileViTModel", + "TFMobileViTPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15e10a1d9815935629fadf9405effdae5a16fbb2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
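`TFMobileViTForSemanticSegmentation` above returns logits in channels-first layout, `(batch_size, num_labels, height, width)`, at a reduced spatial resolution, so callers typically resize them back to the input resolution before taking a per-pixel argmax. The following is a minimal post-processing sketch, not part of the library code; it reuses the `apple/deeplabv3-mobilevit-small` checkpoint and the COCO test image from the docstring example above.

```python
import requests
import tensorflow as tf
from PIL import Image
from transformers import AutoImageProcessor, TFMobileViTForSemanticSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

inputs = image_processor(images=image, return_tensors="tf")
# logits: (batch_size, num_labels, height, width) at reduced resolution
logits = model(**inputs).logits

# move channels last so tf.image.resize can operate on the spatial dims
logits_hwc = tf.transpose(logits, perm=[0, 2, 3, 1])
# image.size is (width, height); tf.image.resize expects (height, width)
upsampled = tf.image.resize(logits_hwc, size=image.size[::-1], method="bilinear")

# per-pixel class indices, shape (batch_size, height, width)
segmentation_map = tf.argmax(upsampled, axis=-1)
```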
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mobilevitv2 import * + from .modeling_mobilevitv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/configuration_mobilevitv2.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/configuration_mobilevitv2.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9db247602c77daa35c3c3ad97dbda10e881cc8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/configuration_mobilevitv2.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MobileViTV2 model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MobileViTV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MobileViTV2Model`]. It is used to instantiate a + MobileViTV2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MobileViTV2 + [apple/mobilevitv2-1.0](https://huggingface.co/apple/mobilevitv2-1.0) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 256): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 2): + The size (resolution) of each patch. + expand_ratio (`float`, *optional*, defaults to 2.0): + Expansion factor for the MobileNetv2 layers. + hidden_act (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in the Transformer encoder and convolution layers. + conv_kernel_size (`int`, *optional*, defaults to 3): + The size of the convolutional kernel in the MobileViTV2 layer. + output_stride (`int`, *optional*, defaults to 32): + The ratio of the spatial resolution of the output to the resolution of the input image. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for attached classifiers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + aspp_out_channels (`int`, *optional*, defaults to 512): + Number of output channels used in the ASPP layer for semantic segmentation. + atrous_rates (`list[int]`, *optional*, defaults to `[6, 12, 18]`): + Dilation (atrous) factors used in the ASPP layer for semantic segmentation. + aspp_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the ASPP layer for semantic segmentation. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + n_attn_blocks (`list[int]`, *optional*, defaults to `[2, 4, 3]`): + The number of attention blocks in each MobileViTV2Layer + base_attn_unit_dims (`list[int]`, *optional*, defaults to `[128, 192, 256]`): + The base multiplier for dimensions of attention blocks in each MobileViTV2Layer + width_multiplier (`float`, *optional*, defaults to 1.0): + The width multiplier for MobileViTV2. + ffn_multiplier (`int`, *optional*, defaults to 2): + The FFN multiplier for MobileViTV2. + attn_dropout (`float`, *optional*, defaults to 0.0): + The dropout in the attention layer. + ffn_dropout (`float`, *optional*, defaults to 0.0): + The dropout between FFN layers. + + Example: + + ```python + >>> from transformers import MobileViTV2Config, MobileViTV2Model + + >>> # Initializing a mobilevitv2-small style configuration + >>> configuration = MobileViTV2Config() + + >>> # Initializing a model from the mobilevitv2-small style configuration + >>> model = MobileViTV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mobilevitv2" + + def __init__( + self, + num_channels=3, + image_size=256, + patch_size=2, + expand_ratio=2.0, + hidden_act="swish", + conv_kernel_size=3, + output_stride=32, + classifier_dropout_prob=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + aspp_out_channels=512, + atrous_rates=[6, 12, 18], + aspp_dropout_prob=0.1, + semantic_loss_ignore_index=255, + n_attn_blocks=[2, 4, 3], + base_attn_unit_dims=[128, 192, 256], + width_multiplier=1.0, + ffn_multiplier=2, + attn_dropout=0.0, + ffn_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.expand_ratio = expand_ratio + self.hidden_act = hidden_act + self.conv_kernel_size = conv_kernel_size + self.output_stride = output_stride + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.n_attn_blocks = n_attn_blocks + self.base_attn_unit_dims = base_attn_unit_dims + self.width_multiplier = width_multiplier + self.ffn_multiplier = ffn_multiplier + self.ffn_dropout = ffn_dropout + self.attn_dropout = attn_dropout + self.classifier_dropout_prob = classifier_dropout_prob + + # decode head attributes for semantic segmentation + self.aspp_out_channels = aspp_out_channels + self.atrous_rates = atrous_rates + self.aspp_dropout_prob = aspp_dropout_prob + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +class MobileViTV2OnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"})]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == 
"image-classification": + return OrderedDict([("logits", {0: "batch"})]) + else: + return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})]) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["MobileViTV2Config", "MobileViTV2OnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/modeling_mobilevitv2.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0e972a648ad028f37d4a5928e3c4f178fa9b27 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -0,0 +1,947 @@ +# coding=utf-8 +# Copyright 2023 Apple Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE +"""PyTorch MobileViTV2 model.""" + +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, + SemanticSegmenterOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from .configuration_mobilevitv2 import MobileViTV2Config + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.make_divisible +def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: + """ + Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the + original TensorFlow repo. It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_value < 0.9 * value: + new_value += divisor + return int(new_value) + + +def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float: + return max(min_val, min(max_val, value)) + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTConvLayer with MobileViT->MobileViTV2 +class MobileViTV2ConvLayer(nn.Module): + def __init__( + self, + config: MobileViTV2Config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = False, + dilation: int = 1, + use_normalization: bool = True, + use_activation: Union[bool, str] = True, + ) -> None: + super().__init__() + padding = int((kernel_size - 1) / 2) * dilation + + if in_channels % groups != 0: + raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.") + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + self.convolution = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode="zeros", + ) + + if use_normalization: + self.normalization = nn.BatchNorm2d( + num_features=out_channels, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ) + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = ACT2FN[use_activation] + elif isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + else: + self.activation = None + + def forward(self, features: torch.Tensor) -> torch.Tensor: + features = self.convolution(features) + if self.normalization is not None: + features = self.normalization(features) + if self.activation is not None: + features = self.activation(features) + return features + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTInvertedResidual with MobileViT->MobileViTV2 +class MobileViTV2InvertedResidual(nn.Module): + """ + Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381 + """ + + def __init__( + self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1 + ) -> None: + super().__init__() + expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) + + if stride not in [1, 2]: + raise ValueError(f"Invalid stride {stride}.") + + self.use_residual = (stride == 1) and (in_channels == out_channels) + + self.expand_1x1 = MobileViTV2ConvLayer( + config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1 + ) + + self.conv_3x3 = MobileViTV2ConvLayer( + config, + in_channels=expanded_channels, + out_channels=expanded_channels, + kernel_size=3, + stride=stride, + groups=expanded_channels, + dilation=dilation, + ) + + self.reduce_1x1 = MobileViTV2ConvLayer( + config, + in_channels=expanded_channels, + out_channels=out_channels, + kernel_size=1, + use_activation=False, + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + + features = self.expand_1x1(features) + features = self.conv_3x3(features) + features = self.reduce_1x1(features) + + return residual + features if self.use_residual else features + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTMobileNetLayer with MobileViT->MobileViTV2 +class MobileViTV2MobileNetLayer(nn.Module): + 
def __init__( + self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1 + ) -> None: + super().__init__() + + self.layer = nn.ModuleList() + for i in range(num_stages): + layer = MobileViTV2InvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + ) + self.layer.append(layer) + in_channels = out_channels + + def forward(self, features: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + features = layer_module(features) + return features + + +class MobileViTV2LinearSelfAttention(nn.Module): + """ + This layer applies a self-attention with linear complexity, as described in MobileViTV2 paper: + https://huggingface.co/papers/2206.02680 + + Args: + config (`MobileVitv2Config`): + Model configuration object + embed_dim (`int`): + `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)` + """ + + def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None: + super().__init__() + + self.qkv_proj = MobileViTV2ConvLayer( + config=config, + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=True, + kernel_size=1, + use_normalization=False, + use_activation=False, + ) + + self.attn_dropout = nn.Dropout(p=config.attn_dropout) + self.out_proj = MobileViTV2ConvLayer( + config=config, + in_channels=embed_dim, + out_channels=embed_dim, + bias=True, + kernel_size=1, + use_normalization=False, + use_activation=False, + ) + self.embed_dim = embed_dim + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # (batch_size, embed_dim, num_pixels_in_patch, num_patches) --> (batch_size, 1+2*embed_dim, num_pixels_in_patch, num_patches) + qkv = self.qkv_proj(hidden_states) + + # Project hidden_states into query, key and value + # Query --> [batch_size, 1, num_pixels_in_patch, num_patches] + # value, key --> [batch_size, embed_dim, num_pixels_in_patch, num_patches] + query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1) + + # apply softmax along num_patches dimension + context_scores = torch.nn.functional.softmax(query, dim=-1) + context_scores = self.attn_dropout(context_scores) + + # Compute context vector + # [batch_size, embed_dim, num_pixels_in_patch, num_patches] x [batch_size, 1, num_pixels_in_patch, num_patches] -> [batch_size, embed_dim, num_pixels_in_patch, num_patches] + context_vector = key * context_scores + # [batch_size, embed_dim, num_pixels_in_patch, num_patches] --> [batch_size, embed_dim, num_pixels_in_patch, 1] + context_vector = torch.sum(context_vector, dim=-1, keepdim=True) + + # combine context vector with values + # [batch_size, embed_dim, num_pixels_in_patch, num_patches] * [batch_size, embed_dim, num_pixels_in_patch, 1] --> [batch_size, embed_dim, num_pixels_in_patch, num_patches] + out = torch.nn.functional.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + return out + + +class MobileViTV2FFN(nn.Module): + def __init__( + self, + config: MobileViTV2Config, + embed_dim: int, + ffn_latent_dim: int, + ffn_dropout: float = 0.0, + ) -> None: + super().__init__() + self.conv1 = MobileViTV2ConvLayer( + config=config, + in_channels=embed_dim, + out_channels=ffn_latent_dim, + kernel_size=1, + stride=1, + bias=True, + use_normalization=False, + use_activation=True, + ) + self.dropout1 = nn.Dropout(ffn_dropout) + + self.conv2 = MobileViTV2ConvLayer( + config=config, + in_channels=ffn_latent_dim, + 
out_channels=embed_dim, + kernel_size=1, + stride=1, + bias=True, + use_normalization=False, + use_activation=False, + ) + self.dropout2 = nn.Dropout(ffn_dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv1(hidden_states) + hidden_states = self.dropout1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.dropout2(hidden_states) + return hidden_states + + +class MobileViTV2TransformerLayer(nn.Module): + def __init__( + self, + config: MobileViTV2Config, + embed_dim: int, + ffn_latent_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps) + self.attention = MobileViTV2LinearSelfAttention(config, embed_dim) + self.dropout1 = nn.Dropout(p=dropout) + self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps) + self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + layernorm_1_out = self.layernorm_before(hidden_states) + attention_output = self.attention(layernorm_1_out) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.ffn(layer_output) + + layer_output = layer_output + hidden_states + return layer_output + + +class MobileViTV2Transformer(nn.Module): + def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None: + super().__init__() + + ffn_multiplier = config.ffn_multiplier + + ffn_dims = [ffn_multiplier * d_model] * n_layers + + # ensure that dims are multiple of 16 + ffn_dims = [int((d // 16) * 16) for d in ffn_dims] + + self.layer = nn.ModuleList() + for block_idx in range(n_layers): + transformer_layer = MobileViTV2TransformerLayer( + config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx] + ) + self.layer.append(transformer_layer) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class MobileViTV2Layer(GradientCheckpointingLayer): + """ + MobileViTV2 layer: https://huggingface.co/papers/2206.02680 + """ + + def __init__( + self, + config: MobileViTV2Config, + in_channels: int, + out_channels: int, + attn_unit_dim: int, + n_attn_blocks: int = 2, + dilation: int = 1, + stride: int = 2, + ) -> None: + super().__init__() + self.patch_width = config.patch_size + self.patch_height = config.patch_size + + cnn_out_dim = attn_unit_dim + + if stride == 2: + self.downsampling_layer = MobileViTV2InvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if dilation == 1 else 1, + dilation=dilation // 2 if dilation > 1 else 1, + ) + in_channels = out_channels + else: + self.downsampling_layer = None + + # Local representations + self.conv_kxk = MobileViTV2ConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + groups=in_channels, + ) + self.conv_1x1 = MobileViTV2ConvLayer( + config, + in_channels=in_channels, + out_channels=cnn_out_dim, + kernel_size=1, + use_normalization=False, + use_activation=False, + ) + + # Global representations + self.transformer = MobileViTV2Transformer(config, d_model=attn_unit_dim, n_layers=n_attn_blocks) + + # self.layernorm = MobileViTV2LayerNorm2D(attn_unit_dim, eps=config.layer_norm_eps) + self.layernorm = 
nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps) + + # Fusion + self.conv_projection = MobileViTV2ConvLayer( + config, + in_channels=cnn_out_dim, + out_channels=in_channels, + kernel_size=1, + use_normalization=True, + use_activation=False, + ) + + def unfolding(self, feature_map: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]: + batch_size, in_channels, img_height, img_width = feature_map.shape + patches = nn.functional.unfold( + feature_map, + kernel_size=(self.patch_height, self.patch_width), + stride=(self.patch_height, self.patch_width), + ) + patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1) + + return patches, (img_height, img_width) + + def folding(self, patches: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor: + batch_size, in_dim, patch_size, n_patches = patches.shape + patches = patches.reshape(batch_size, in_dim * patch_size, n_patches) + + feature_map = nn.functional.fold( + patches, + output_size=output_size, + kernel_size=(self.patch_height, self.patch_width), + stride=(self.patch_height, self.patch_width), + ) + + return feature_map + + def forward(self, features: torch.Tensor) -> torch.Tensor: + # reduce spatial dimensions if needed + if self.downsampling_layer: + features = self.downsampling_layer(features) + + # local representation + features = self.conv_kxk(features) + features = self.conv_1x1(features) + + # convert feature map to patches + patches, output_size = self.unfolding(features) + + # learn global representations + patches = self.transformer(patches) + patches = self.layernorm(patches) + + # convert patches back to feature maps + # [batch_size, patch_height, patch_width, input_dim] --> [batch_size, input_dim, patch_height, patch_width] + features = self.folding(patches, output_size) + + features = self.conv_projection(features) + return features + + +class MobileViTV2Encoder(nn.Module): + def __init__(self, config: MobileViTV2Config) -> None: + super().__init__() + self.config = config + + self.layer = nn.ModuleList() + self.gradient_checkpointing = False + + # segmentation architectures like DeepLab and PSPNet modify the strides + # of the classification backbones + dilate_layer_4 = dilate_layer_5 = False + if config.output_stride == 8: + dilate_layer_4 = True + dilate_layer_5 = True + elif config.output_stride == 16: + dilate_layer_5 = True + + dilation = 1 + + layer_0_dim = make_divisible( + clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16 + ) + + layer_1_dim = make_divisible(64 * config.width_multiplier, divisor=16) + layer_2_dim = make_divisible(128 * config.width_multiplier, divisor=8) + layer_3_dim = make_divisible(256 * config.width_multiplier, divisor=8) + layer_4_dim = make_divisible(384 * config.width_multiplier, divisor=8) + layer_5_dim = make_divisible(512 * config.width_multiplier, divisor=8) + + layer_1 = MobileViTV2MobileNetLayer( + config, + in_channels=layer_0_dim, + out_channels=layer_1_dim, + stride=1, + num_stages=1, + ) + self.layer.append(layer_1) + + layer_2 = MobileViTV2MobileNetLayer( + config, + in_channels=layer_1_dim, + out_channels=layer_2_dim, + stride=2, + num_stages=2, + ) + self.layer.append(layer_2) + + layer_3 = MobileViTV2Layer( + config, + in_channels=layer_2_dim, + out_channels=layer_3_dim, + attn_unit_dim=make_divisible(config.base_attn_unit_dims[0] * config.width_multiplier, divisor=8), + n_attn_blocks=config.n_attn_blocks[0], + ) + self.layer.append(layer_3) + + if dilate_layer_4: + 
dilation *= 2 + + layer_4 = MobileViTV2Layer( + config, + in_channels=layer_3_dim, + out_channels=layer_4_dim, + attn_unit_dim=make_divisible(config.base_attn_unit_dims[1] * config.width_multiplier, divisor=8), + n_attn_blocks=config.n_attn_blocks[1], + dilation=dilation, + ) + self.layer.append(layer_4) + + if dilate_layer_5: + dilation *= 2 + + layer_5 = MobileViTV2Layer( + config, + in_channels=layer_4_dim, + out_channels=layer_5_dim, + attn_unit_dim=make_divisible(config.base_attn_unit_dims[2] * config.width_multiplier, divisor=8), + n_attn_blocks=config.n_attn_blocks[2], + dilation=dilation, + ) + self.layer.append(layer_5) + + def forward( + self, + hidden_states: torch.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutputWithNoAttention]: + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +@auto_docstring +class MobileViTV2PreTrainedModel(PreTrainedModel): + config: MobileViTV2Config + base_model_prefix = "mobilevitv2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MobileViTV2Layer"] + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.GroupNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class MobileViTV2Model(MobileViTV2PreTrainedModel): + def __init__(self, config: MobileViTV2Config, expand_output: bool = True): + r""" + expand_output (`bool`, *optional*, defaults to `True`): + Whether to expand the output of the model. If `True`, the model will output pooled features in addition to + hidden states. If `False`, only the hidden states will be returned. + """ + super().__init__(config) + self.config = config + self.expand_output = expand_output + + layer_0_dim = make_divisible( + clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16 + ) + + self.conv_stem = MobileViTV2ConvLayer( + config, + in_channels=config.num_channels, + out_channels=layer_0_dim, + kernel_size=3, + stride=2, + use_normalization=True, + use_activation=True, + ) + self.encoder = MobileViTV2Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel + """ + for layer_index, heads in heads_to_prune.items(): + mobilevitv2_layer = self.encoder.layer[layer_index] + if isinstance(mobilevitv2_layer, MobileViTV2Layer): + for transformer_layer in mobilevitv2_layer.transformer.layer: + transformer_layer.attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.conv_stem(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.expand_output: + last_hidden_state = encoder_outputs[0] + + # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) + pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False) + else: + last_hidden_state = encoder_outputs[0] + pooled_output = None + + if not return_dict: + output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) + return output + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel): + def __init__(self, config: MobileViTV2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevitv2 = MobileViTV2Model(config) + + out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension + # Classifier head + self.classifier = ( + nn.Linear(in_features=out_channels, out_features=config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTASPPPooling with MobileViT->MobileViTV2 +class MobileViTV2ASPPPooling(nn.Module): + def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None: + super().__init__() + + self.global_pool = nn.AdaptiveAvgPool2d(output_size=1) + + self.conv_1x1 = MobileViTV2ConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_normalization=True, + use_activation="relu", + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + spatial_size = features.shape[-2:] + features = self.global_pool(features) + features = self.conv_1x1(features) + features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False) + return features + + +class MobileViTV2ASPP(nn.Module): + """ + ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTV2Config) -> None: + super().__init__() + + encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension + in_channels = encoder_out_channels + out_channels = config.aspp_out_channels + + if len(config.atrous_rates) != 3: + raise ValueError("Expected 3 values for atrous_rates") + + self.convs = nn.ModuleList() + + in_projection = MobileViTV2ConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + ) + self.convs.append(in_projection) + + self.convs.extend( + [ + MobileViTV2ConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=rate, + use_activation="relu", + ) + for rate in config.atrous_rates + ] + ) + + pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels) + self.convs.append(pool_layer) + + self.project = MobileViTV2ConvLayer( + config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu" + ) + + self.dropout = nn.Dropout(p=config.aspp_dropout_prob) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + pyramid = [] + for conv in self.convs: + pyramid.append(conv(features)) + pyramid = torch.cat(pyramid, dim=1) + + pooled_features = self.project(pyramid) + pooled_features = self.dropout(pooled_features) + return pooled_features + + +# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTDeepLabV3 with MobileViT->MobileViTV2 +class MobileViTV2DeepLabV3(nn.Module): + """ + DeepLabv3 architecture: https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTV2Config) -> None: + super().__init__() + self.aspp = MobileViTV2ASPP(config) + + self.dropout = nn.Dropout2d(config.classifier_dropout_prob) + + self.classifier = MobileViTV2ConvLayer( + config, + 
in_channels=config.aspp_out_channels, + out_channels=config.num_labels, + kernel_size=1, + use_normalization=False, + use_activation=False, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + features = self.aspp(hidden_states[-1]) + features = self.dropout(features) + features = self.classifier(features) + return features + + +@auto_docstring( + custom_intro=""" + MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC. + """ +) +class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel): + def __init__(self, config: MobileViTV2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevitv2 = MobileViTV2Model(config, expand_output=False) + self.segmentation_head = MobileViTV2DeepLabV3(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256") + >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**inputs) + + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.mobilevitv2( + pixel_values, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.segmentation_head(encoder_hidden_states) + + loss = None + if labels is not None: + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + loss = loss_fct(upsampled_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = [ + "MobileViTV2ForImageClassification", + "MobileViTV2ForSemanticSegmentation", + "MobileViTV2Model", + "MobileViTV2PreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/modernbert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/modernbert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18317742981909460973138da806584ddfc4a390 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/modernbert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
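Mirroring the loss path in `MobileViTV2ForSemanticSegmentation.forward`, which bilinearly upsamples logits to the label resolution, downstream code usually interpolates the low-resolution logits back to the original image size before taking a per-pixel argmax. A minimal sketch continuing the docstring example above (it assumes the `outputs` object and the PIL `image` from that example are in scope):

```python
import torch

# outputs.logits: (batch_size, num_labels, height, width) at reduced resolution
logits = outputs.logits

# image.size is (width, height); interpolate expects (height, width)
upsampled_logits = torch.nn.functional.interpolate(
    logits, size=image.size[::-1], mode="bilinear", align_corners=False
)

# per-pixel class indices, shape (batch_size, height, width)
segmentation_map = upsampled_logits.argmax(dim=1)
```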
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_modernbert import * + from .modeling_modernbert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/modernbert/configuration_modernbert.py b/venv/lib/python3.13/site-packages/transformers/models/modernbert/configuration_modernbert.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0da20ad203cb02375f40e14ea818b642c26c91 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/modernbert/configuration_modernbert.py @@ -0,0 +1,224 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +from ...configuration_utils import PretrainedConfig + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. 
Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. 
If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. + repad_logits_with_grad (`bool`, *optional*, defaults to `False`): + When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only + applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient. + + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + attribute_map = {"rope_theta": "global_rope_theta"} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + repad_logits_with_grad=False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + self.repad_logits_with_grad = 
repad_logits_with_grad + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + def to_dict(self): + output = super().to_dict() + output.pop("reference_compile", None) + return output + + +__all__ = ["ModernBertConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/modernbert/modeling_modernbert.py b/venv/lib/python3.13/site-packages/transformers/models/modernbert/modeling_modernbert.py new file mode 100644 index 0000000000000000000000000000000000000000..00fbe19c3a63c7932d8710b38466d6720a4d46a3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/modernbert/modeling_modernbert.py @@ -0,0 +1,1572 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
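Two behaviours of `ModernBertConfig` above are easy to miss: the constructor rejects any `classifier_pooling` value other than `"cls"` or `"mean"`, and `to_dict` drops the runtime-only `reference_compile` flag so it never ends up in a serialized `config.json`. A minimal sketch exercising both, assuming a `transformers` build that ships ModernBERT:

```python
from transformers import ModernBertConfig

# choose mean pooling for the classification head instead of the default CLS token
config = ModernBertConfig(classifier_pooling="mean")

# reference_compile is popped from the serialized dict
assert "reference_compile" not in config.to_dict()

# any other pooling value is rejected at construction time
try:
    ModernBertConfig(classifier_pooling="max")
except ValueError as err:
    print(err)
```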
+ +import copy +import math +from contextlib import nullcontext +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_flash_attn_2_available, logging +from ...utils.import_utils import is_triton_available +from .configuration_modernbert import ModernBertConfig + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + + +logger = logging.get_logger(__name__) + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. 
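+    It expects a packed `qkv` tensor of shape `(total_nnz, 3, nheads, headdim)` together with the
+    `cu_seqlens` sequence boundaries (as produced by `_unpad_modernbert_input`), and rotates the
+    query and key slices in-place without ever materializing a padded `(batch, seqlen)` layout.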
+ """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache will be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward( + self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = self.drop(self.norm(inputs_embeds)) + else: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
+ """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: ModernBertConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
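+
+    Layers whose `layer_id` is a multiple of `config.global_attn_every_n_layers` attend globally;
+    every other layer restricts attention to a sliding window of `config.local_attention` tokens.
+    With the default config this means layers 0, 3, 6, ... are global and the rest are local.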
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta + max_position_embeddings = config.local_attention + else: + self.local_attention = (-1, -1) + max_position_embeddings = config.max_position_embeddings + rope_theta = config.global_rope_theta + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + config_copy = copy.deepcopy(config) + config_copy.rope_theta = rope_theta + self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] 
= None,
+        output_attentions: Optional[bool] = False,
+    ) -> torch.Tensor:
+        attn_outputs = self.attn(
+            self.attn_norm(hidden_states),
+            attention_mask=attention_mask,
+            sliding_window_mask=sliding_window_mask,
+            position_ids=position_ids,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            output_attentions=output_attentions,
+        )
+        hidden_states = hidden_states + attn_outputs[0]
+        mlp_output = (
+            self.compiled_mlp(hidden_states)
+            if self.config.reference_compile
+            else self.mlp(self.mlp_norm(hidden_states))
+        )
+        hidden_states = hidden_states + mlp_output
+
+        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
+
+
+@auto_docstring
+class ModernBertPreTrainedModel(PreTrainedModel):
+    config: ModernBertConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = False
+
+    def _init_weights(self, module: nn.Module):
+        cutoff_factor = self.config.initializer_cutoff_factor
+        if cutoff_factor is None:
+            cutoff_factor = 3
+
+        def init_weight(module: nn.Module, std: float):
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-cutoff_factor * std,
+                b=cutoff_factor * std,
+            )
+
+            if isinstance(module, nn.Linear):
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+        stds = {
+            "in": self.config.initializer_range,
+            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
+            "embedding": self.config.initializer_range,
+            "final_out": self.config.hidden_size**-0.5,
+        }
+
+        if isinstance(module, ModernBertEmbeddings):
+            init_weight(module.tok_embeddings, stds["embedding"])
+        elif isinstance(module, ModernBertMLP):
+            init_weight(module.Wi, stds["in"])
+            init_weight(module.Wo, stds["out"])
+        elif isinstance(module, ModernBertAttention):
+            init_weight(module.Wqkv, stds["in"])
+            init_weight(module.Wo, stds["out"])
+        elif isinstance(module, ModernBertPredictionHead):
+            init_weight(module.dense, stds["out"])
+        elif isinstance(module, ModernBertForMaskedLM):
+            init_weight(module.decoder, stds["out"])
+        elif isinstance(
+            module,
+            (
+                ModernBertForSequenceClassification,
+                ModernBertForMultipleChoice,
+                ModernBertForTokenClassification,
+                ModernBertForQuestionAnswering,
+            ),
+        ):
+            init_weight(module.classifier, stds["final_out"])
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
+        """
+        Checks and dispatches to the requested attention implementation.
+        """
+        # If the user didn't specify anything, try to use flash_attention_2 if available.
+        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
+ + try: + attn_implementation = ( + "flash_attention_2" + if attn_implementation is None and self._flash_attn_2_can_dispatch() + else attn_implementation + ) + except (ValueError, ImportError): + pass + return super()._check_and_adjust_attn_implementation( + attn_implementation=attn_implementation, is_init_check=is_init_check + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "cpu": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. 
+ indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +@auto_docstring +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. 
This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not 
None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + for hs in all_hidden_states + ) + # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch + # dimension missing + elif ( + self.config._attn_implementation == "flash_attention_2" + and all_hidden_states is not None + and all_hidden_states[-1].dim() == 2 + ): + hidden_states = hidden_states.unsqueeze(0) + all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a decoder head on top that is used for masked language modeling. 
+ """ +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs) + + if self.config._attn_implementation == "flash_attention_2": + # Logits padding + with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + # Hidden states padding + if getattr(outputs, "hidden_states", None) is not None: + padded_hidden_states = [] + for hs in outputs.hidden_states: + if hs.dim() == 3 and hs.shape[0] == 1: + hs = hs.squeeze(0) + padded_hidden_states.append( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + ) + outputs.hidden_states = tuple(padded_hidden_states) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a sequence classification head on top that performs pooling. 
+ """ +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks. 
+ """ +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + """ +) +class ModernBertForMultipleChoice(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. 
+        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+        max_seqlen (`int`, *optional*):
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+        batch_size (`int`, *optional*):
+            Batch size of the input sequences. Used to pad the output tensors.
+        seq_len (`int`, *optional*):
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        self._maybe_set_compile()
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            sliding_window_mask=sliding_window_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            indices=indices,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            batch_size=batch_size,
+            seq_len=seq_len,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = outputs[0]  # shape (num_choices, seq_len, hidden_size)
+
+        # If classifier_pooling is "cls", isolate the CLS token
+        if self.config.classifier_pooling == "cls":
+            indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
+            # for left or right padding, the CLS token is the first non-pad token
+            if attention_mask is not None:
+                cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
+            # if no pad, the CLS token is the first token
+            else:
+                cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
+            # extract the CLS token for the logits
+            last_hidden_state = last_hidden_state[indices_0, cls_mask]
+
+        # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
+        elif self.config.classifier_pooling == "mean":
+            num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "ModernBertModel",
+    "ModernBertPreTrainedModel",
+    "ModernBertForMaskedLM",
+    "ModernBertForSequenceClassification",
+    "ModernBertForTokenClassification",
+    "ModernBertForQuestionAnswering",
+    "ModernBertForMultipleChoice",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/modernbert/modular_modernbert.py b/venv/lib/python3.13/site-packages/transformers/models/modernbert/modular_modernbert.py
b/venv/lib/python3.13/site-packages/transformers/models/modernbert/modular_modernbert.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac298f0959669c9f4201216d796df5291464e29 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/modernbert/modular_modernbert.py @@ -0,0 +1,1698 @@ +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from contextlib import nullcontext +from typing import Literal, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_flash_attn_2_available, logging +from ...utils.import_utils import is_triton_available +from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + + +logger = logging.get_logger(__name__) + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate a ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining.
If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. + repad_logits_with_grad (`bool`, *optional*, defaults to `False`): + When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only + applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient. + + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + attribute_map = {"rope_theta": "global_rope_theta"} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + repad_logits_with_grad=False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + self.repad_logits_with_grad = 
repad_logits_with_grad + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + def to_dict(self): + output = super().to_dict() + output.pop("reference_compile", None) + return output + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) 
or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + qkv: (total_nnz, 3, nheads, headdim), with the rotary embedding applied in-place to the + first rotary_dim of the query and key head dimensions (non-interleaved, GPT-NeoX style). + rotary_dim must be <= headdim. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache will be recomputed during the forward pass.
+ """ + super().__init__(dim=dim, base=base, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward( + self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = self.drop(self.norm(inputs_embeds)) + else: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
+ """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): + pass + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
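+ + Depending on the attention implementation, the projected `qkv` tensor is laid out as + `(total_nnz, 3, nheads, headdim)` on the unpadded Flash Attention 2 path, or as + `(batch_size, seqlen, 3, nheads, headdim)` on the padded eager/SDPA paths (see the `view` + calls in `forward`).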
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta + max_position_embeddings = config.local_attention + else: + self.local_attention = (-1, -1) + max_position_embeddings = config.max_position_embeddings + rope_theta = config.global_rope_theta + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + config_copy = copy.deepcopy(config) + config_copy.rope_theta = rope_theta + self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] 
= None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs = self.attn( + self.attn_norm(hidden_states), + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_outputs[0] + mlp_output = ( + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +@auto_docstring +class ModernBertPreTrainedModel(PreTrainedModel): + config: ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance( + module, + ( + ModernBertForSequenceClassification, + ModernBertForMultipleChoice, + ModernBertForTokenClassification, + ModernBertForQuestionAnswering, + ), + ): + init_weight(module.classifier, stds["final_out"]) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: + """ + Checks and dispatches to the requested attention implementation. + """ + # If the user didn't specify anything, try to use flash_attention_2 if available. + # Otherwise we fall back to the default SDPA -> Eager from the super() method. + # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, so we don't + # need the FA2 warning for those dtypes and set fp16 for the FA2 check.
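+ # Note: any ValueError/ImportError raised by the FA2 availability check below is swallowed; + # it simply means FA2 cannot be dispatched, and the super() call then falls back to SDPA or eager.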
+ + try: + attn_implementation = ( + "flash_attention_2" + if attn_implementation is None and self._flash_attn_2_can_dispatch() + else attn_implementation + ) + except (ValueError, ImportError): + pass + return super()._check_and_adjust_attn_implementation( + attn_implementation=attn_implementation, is_init_check=is_init_check + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "cpu": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +@auto_docstring +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. 
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, 
seqlen=seq_len) + for hs in all_hidden_states + ) + # If the attention implementation is FA2 and there is no need for repadding, the batch + # dimension might still be missing + elif ( + self.config._attn_implementation == "flash_attention_2" + and all_hidden_states is not None + and all_hidden_states[-1].dim() == 2 + ): + hidden_states = hidden_states.unsqueeze(0) + all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask( + self, attention_mask: torch.Tensor, output_attentions: bool + ) -> tuple[torch.Tensor, torch.Tensor]: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a decoder head on top that is used for masked language modeling.
+ """ +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs) + + if self.config._attn_implementation == "flash_attention_2": + # Logits padding + with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + # Hidden states padding + if getattr(outputs, "hidden_states", None) is not None: + padded_hidden_states = [] + for hs in outputs.hidden_states: + if hs.dim() == 3 and hs.shape[0] == 1: + hs = hs.squeeze(0) + padded_hidden_states.append( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + ) + outputs.hidden_states = tuple(padded_hidden_states) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a sequence classification head on top that performs pooling. 
+ """ +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks. 
+ """ +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + """ +) +class ModernBertForMultipleChoice(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. 
+        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+        max_seqlen (`int`, *optional*):
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+        batch_size (`int`, *optional*):
+            Batch size of the input sequences. Used to pad the output tensors.
+        seq_len (`int`, *optional*):
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        self._maybe_set_compile()
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            sliding_window_mask=sliding_window_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            indices=indices,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            batch_size=batch_size,
+            seq_len=seq_len,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = outputs[0]  # shape (batch_size * num_choices, seq_len, hidden_size)
+
+        # If classifier_pooling is "cls", isolate the [CLS] token
+        if self.config.classifier_pooling == "cls":
+            indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
+            # with either left or right padding, the [CLS] token is the first non-pad token
+            if attention_mask is not None:
+                cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
+            # without padding, the [CLS] token is simply the first token
+            else:
+                cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
+            # gather the [CLS] token hidden state used for the logits
+            last_hidden_state = last_hidden_state[indices_0, cls_mask]
+
+        # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
+        elif self.config.classifier_pooling == "mean":
+            num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "ModernBertConfig",
+    "ModernBertModel",
+    "ModernBertPreTrainedModel",
+    "ModernBertForMaskedLM",
+    "ModernBertForSequenceClassification",
+    "ModernBertForTokenClassification",
+    "ModernBertForQuestionAnswering",
+    "ModernBertForMultipleChoice",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/moonshine/__init__.py
b/venv/lib/python3.13/site-packages/transformers/models/moonshine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3c1870c2a1d6f5d089df5a396d4b049e4a9d4e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/moonshine/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_moonshine import * + from .modeling_moonshine import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/moonshine/configuration_moonshine.py b/venv/lib/python3.13/site-packages/transformers/models/moonshine/configuration_moonshine.py new file mode 100644 index 0000000000000000000000000000000000000000..270a2e3e484023a5b7422e222905e9f08d01d94f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/moonshine/configuration_moonshine.py @@ -0,0 +1,230 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_moonshine.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class MoonshineConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Moonshine + [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32768):
+            Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`MoonshineModel`].
+        hidden_size (`int`, *optional*, defaults to 288):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 1152):
+            Dimension of the MLP representations.
+        encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer encoder.
+        decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer decoder.
+        encoder_num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        encoder_num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `encoder_num_attention_heads`.
+        decoder_num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+            optimized attention implementations.
+        encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder.
+        decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        decoder_start_token_id (`int`, *optional*, defaults to 1):
+            Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
+            are provided to the `generate` function. It is used to guide the model's generation process depending on
+            the task.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly. A short example follows the argument list below.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.9):
+            Percentage of the query and keys which will have rotary embedding.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Denotes beginning of sequences token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            Denotes end of sequences token id.
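+
+    For example, `rope_scaling={"rope_type": "linear", "factor": 2.0}` requests linear
+    position interpolation so that the model can handle sequences up to twice the
+    original maximum pre-trained length, using only the fields described above.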
+ + Example: + + ```python + >>> from transformers import MoonshineModel, MoonshineConfig + + >>> # Initializing a Moonshine style configuration + >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny") + + >>> # Initializing a model from the configuration + >>> model = MoonshineModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "moonshine" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_key_value_heads": "encoder_num_key_value_heads", + "num_attention_heads": "encoder_num_attention_heads", + "num_hidden_layers": "encoder_num_hidden_layers", + } + + def __init__( + self, + vocab_size=32768, + hidden_size=288, + intermediate_size=1152, + encoder_num_hidden_layers=6, + decoder_num_hidden_layers=6, + encoder_num_attention_heads=8, + decoder_num_attention_heads=8, + encoder_num_key_value_heads=None, + decoder_num_key_value_heads=None, + pad_head_dim_to_multiple_of=None, + encoder_hidden_act="gelu", + decoder_hidden_act="silu", + max_position_embeddings=512, + initializer_range=0.02, + decoder_start_token_id=1, + use_cache=True, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.9, + is_encoder_decoder=True, + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.encoder_num_hidden_layers = encoder_num_hidden_layers + self.decoder_num_hidden_layers = decoder_num_hidden_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.decoder_num_attention_heads = decoder_num_attention_heads + + if encoder_num_key_value_heads is None: + encoder_num_key_value_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + + if decoder_num_key_value_heads is None: + decoder_num_key_value_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + + self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of + + self.encoder_hidden_act = encoder_hidden_act + self.decoder_hidden_act = decoder_hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.is_encoder_decoder = is_encoder_decoder + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["MoonshineConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/moonshine/modeling_moonshine.py b/venv/lib/python3.13/site-packages/transformers/models/moonshine/modeling_moonshine.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0f879aa102f59d5541afe01c897cb6ed8b2de1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/moonshine/modeling_moonshine.py @@ -0,0 +1,1097 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from 
src/transformers/models/moonshine/modular_moonshine.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_moonshine.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from transformers.utils.generic import OutputRecorder, check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from .configuration_moonshine import MoonshineConfig + + +class MoonshineEncoderMLP(nn.Module): + def __init__(self, config, hidden_act): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MoonshineDecoderMLP(nn.Module): + def __init__(self, config, hidden_act): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation_fn(gate) * hidden_states + hidden_states = self.fc2(hidden_states) + return hidden_states + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class MoonshineAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: MoonshineConfig, + layer_idx: int, + is_causal: bool, + num_attention_heads: int, + num_key_value_heads: int, + ): + super().__init__() + config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads}) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = is_causal + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + # Pad head dimension to the next specified multiple. 
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        key_value_states: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, q_len = hidden_states.shape[:-1]
+
+        # queries use all attention heads; keys/values may use fewer heads (GQA)
+        query_states = (
+            self.q_proj(hidden_states).view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+        )
+
+        is_cross_attention = key_value_states is not None
+        if past_key_values is not None:
+            is_updated = past_key_values.is_updated.get(self.layer_idx)
+            if is_cross_attention:
+                # after the first generated id, we can subsequently re-use all key/value_states from cache
+                past_key_values.is_updated[self.layer_idx] = True
+                past_key_values = past_key_values.cross_attention_cache
+            else:
+                past_key_values = past_key_values.self_attention_cache
+
+        # use key_value_states if cross attention
+        current_states = key_value_states if key_value_states is not None else hidden_states
+        if is_cross_attention and past_key_values and is_updated:
+            key_states = past_key_values.layers[self.layer_idx].keys
+            value_states = past_key_values.layers[self.layer_idx].values
+        else:
+            key_states = (
+                self.k_proj(current_states)
+                .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(current_states)
+                .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            if is_cross_attention and past_key_values is not None:
+                key_states, value_states = past_key_values.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+
+        if not is_cross_attention:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+            if past_key_values is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_values.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        is_causal = self.is_causal and attention_mask is None and q_len > 1
+
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            is_causal=is_causal,
+            **kwargs,
+        )
+
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
+
+        attn_output = attn_output.reshape(bsz,
q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MoonshineRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: MoonshineConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MoonshineEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MoonshineConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=False, + num_attention_heads=config.encoder_num_attention_heads, + num_key_value_heads=config.encoder_num_key_value_heads, + ) + + self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class 
MoonshineDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=True, + num_attention_heads=config.decoder_num_attention_heads, + num_key_value_heads=config.decoder_num_key_value_heads, + ) + self.encoder_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=False, + num_attention_heads=config.decoder_num_attention_heads, + num_key_value_heads=config.decoder_num_key_value_heads, + ) + + self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + encoder_position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class MoonshinePreTrainedModel(PreTrainedModel): + config: MoonshineConfig + base_model_prefix = "model" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + # TODO arthur, how do we separate when it cross / self coming from different layer? 
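+
+    # Illustrative example for `_get_feat_extract_output_lengths` below, assuming
+    # the expected 16 kHz input: 1 second of audio (16000 samples) goes through
+    # conv1 (kernel 127, stride 64), conv2 (kernel 7, stride 3) and conv3
+    # (kernel 3, stride 2) of `MoonshineEncoder`, i.e.
+    # int((16000 - 127) / 64 + 1) = 249 -> int((249 - 7) / 3 + 1) = 81
+    # -> int((81 - 3) / 2 + 1) = 40 output frames (~40 frames per second).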
+ + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + output_conv1_length = int((input_lengths - 127) / 64 + 1) + output_conv2_length = int((output_conv1_length - 7) / 3 + 1) + output_conv3_length = int((output_conv2_length - 3) / 2 + 1) + + return output_conv3_length + + +class MoonshineEncoder(MoonshinePreTrainedModel): + """ + Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`] + + Args: + config: MoonshineConfig + """ + + main_input_name = "input_values" + _can_record_outputs = { + "attentions": MoonshineAttention, + "hidden_states": MoonshineEncoderLayer, + } + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False) + self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3) + self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2) + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5) + self.rotary_emb = MoonshineRotaryEmbedding(config=config) + + self.layers = nn.ModuleList( + [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)] + ) + self.layer_norm = nn.LayerNorm(embed_dim, bias=False) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + @check_model_inputs() + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ [What are attention masks?](../glossary#attention-mask) + """ + input_values = input_values.unsqueeze(1) + hidden_states = nn.functional.tanh(self.conv1(input_values)) + hidden_states = self.groupnorm(hidden_states) + hidden_states = nn.functional.gelu(self.conv2(hidden_states)) + hidden_states = nn.functional.gelu(self.conv3(hidden_states)) + hidden_states = hidden_states.permute(0, 2, 1) + + # attention mask downsampling + if attention_mask is not None: + mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1]) + downsample_stride = 64 * 3 * 2 # conv strides + attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len] + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask == 0.0).any() else None + elif self.config._attn_implementation == "sdpa": + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype) + else: + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.layer_norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + ) + + +@auto_docstring +class MoonshineDecoder(MoonshinePreTrainedModel): + main_input_name = "input_ids" + _can_record_outputs = { + "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"), + "hidden_states": MoonshineDecoderLayer, + "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"), + } + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, bias=False) + self.rotary_emb = MoonshineRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `encoder_hidden_states`. 
Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + if encoder_attention_mask is not None: + mask_len = encoder_hidden_states.shape[-2] + downsample_stride = 64 * 3 * 2 # conv strides + encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len] + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None + elif self.config._attn_implementation == "sdpa": + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2] + ) + else: + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2] + ) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +def _compute_mask_indices( + shape: tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. 
+ mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
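+        # (Illustrative numbers: with mask_prob=0.05, mask_length=10 and
+        # input_length=100, compute_num_masked_span returns int(0.5 + epsilon),
+        # i.e. 0 or 1 depending on the single epsilon drawn above; rows with
+        # different unpadded lengths can thus get different span counts, which is
+        # why shorter rows are padded with the dummy index up to max_num_masked_span.)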
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+@auto_docstring
+class MoonshineModel(MoonshinePreTrainedModel):
+    def __init__(self, config: MoonshineConfig):
+        super().__init__(config)
+
+        self.encoder = MoonshineEncoder(config)
+        self.decoder = MoonshineDecoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.decoder.embed_tokens = value
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will
+        not be updated during training.
+        """
+        self.encoder._freeze_parameters()
+
+    def _mask_input_features(
+        self,
+        input_features: torch.FloatTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return input_features + + # generate indices & apply SpecAugment along time axis + batch_size, hidden_size, sequence_length = input_features.size() + + if self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool) + mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1) + input_features[mask_time_indices] = 0 + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool) + input_features[mask_feature_indices] = 0 + + return input_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[tuple[torch.LongTensor]] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. 
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings` + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, MoonshineModel + >>> from datasets import load_dataset + + >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_values = inputs.input_values + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 288] + ``` + """ + if encoder_outputs is None: + encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs) + + decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_attention_mask=attention_mask, + encoder_hidden_states=encoder_outputs.last_hidden_state, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@auto_docstring( + custom_intro=""" + The Moonshine Model with a language modeling head. Can be used for automatic speech recognition. 
+ """ +) +class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["proj_out.weight"] + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.model = MoonshineModel(config) + self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[tuple[torch.LongTensor]] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqLMOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings` + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny") + >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_values = inputs.input_values + + >>> generated_ids = model.generate(input_values, max_new_tokens=100) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' 
+ ```""" + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs: Seq2SeqModelOutput = self.model( + input_values, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + logits = self.proj_out(outputs.last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/moonshine/modular_moonshine.py b/venv/lib/python3.13/site-packages/transformers/models/moonshine/modular_moonshine.py new file mode 100644 index 0000000000000000000000000000000000000000..15b3de1b7d5642b2b2025973e5d8c83ca22bf3d1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/moonshine/modular_moonshine.py @@ -0,0 +1,921 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
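+#
+# NOTE: this is a "modular transformers" definition file: the Moonshine classes
+# below are written by composing and overriding building blocks imported from
+# the Glm, Llama and Whisper models, and the flat `modeling_moonshine.py` file
+# in this directory is generated from it.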
+ +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from transformers.utils.generic import OutputRecorder, check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward +from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right + + +logger = logging.get_logger(__name__) + + +class MoonshineConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Moonshine + [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MoonshineModel`]. + hidden_size (`int`, *optional*, defaults to 288): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + encoder_num_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + decoder_num_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer decoder. + encoder_num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + encoder_num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if + `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If not specified, it will default to
+            `encoder_num_attention_heads`.
+        decoder_num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If not specified, it will default to
+            `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+            optimized attention implementations.
+        encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder.
+        decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        decoder_start_token_id (`int`, *optional*, defaults to 1):
+            Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
+            are provided to the `generate` function. It is used to guide the model's generation process depending on
+            the task.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to the low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to the high frequency components of the RoPE.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.9):
+            Percentage of the query and keys which will have rotary embedding.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Denotes the beginning-of-sequence token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            Denotes the end-of-sequence token id.
+
+    Example:
+
+    ```python
+    >>> from transformers import MoonshineModel, MoonshineConfig
+
+    >>> # Initializing a Moonshine style configuration
+    >>> configuration = MoonshineConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = MoonshineModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "moonshine"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_key_value_heads": "encoder_num_key_value_heads",
+        "num_attention_heads": "encoder_num_attention_heads",
+        "num_hidden_layers": "encoder_num_hidden_layers",
+    }
+
+    def __init__(
+        self,
+        vocab_size=32768,
+        hidden_size=288,
+        intermediate_size=1152,
+        encoder_num_hidden_layers=6,
+        decoder_num_hidden_layers=6,
+        encoder_num_attention_heads=8,
+        decoder_num_attention_heads=8,
+        encoder_num_key_value_heads=None,
+        decoder_num_key_value_heads=None,
+        pad_head_dim_to_multiple_of=None,
+        encoder_hidden_act="gelu",
+        decoder_hidden_act="silu",
+        max_position_embeddings=512,
+        initializer_range=0.02,
+        decoder_start_token_id=1,
+        use_cache=True,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        partial_rotary_factor=0.9,
+        is_encoder_decoder=True,
+        attention_bias=False,
+        attention_dropout=0.0,
+        bos_token_id=1,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.encoder_num_hidden_layers = encoder_num_hidden_layers
+        self.decoder_num_hidden_layers = decoder_num_hidden_layers
+        self.encoder_num_attention_heads = encoder_num_attention_heads
+        self.decoder_num_attention_heads = decoder_num_attention_heads
+
+        if encoder_num_key_value_heads is None:
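+            # fall back to standard multi-head attention: one key/value head per
+            # query head (a value of 1 here would give multi-query attention instead)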
+ encoder_num_key_value_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + + if decoder_num_key_value_heads is None: + decoder_num_key_value_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + + self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of + + self.encoder_hidden_act = encoder_hidden_act + self.decoder_hidden_act = decoder_hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.is_encoder_decoder = is_encoder_decoder + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +class MoonshineEncoderMLP(nn.Module): + def __init__(self, config, hidden_act): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MoonshineDecoderMLP(nn.Module): + def __init__(self, config, hidden_act): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation_fn(gate) * hidden_states + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MoonshineAttention(GlmAttention): + def __init__( + self, + config: MoonshineConfig, + layer_idx: int, + is_causal: bool, + num_attention_heads: int, + num_key_value_heads: int, + ): + config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads}) + super().__init__(config, layer_idx) + self.is_causal = is_causal + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + # Pad head dimension to the next specified multiple. 
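+        # For example, with the moonshine-tiny defaults (hidden_size=288, 8 heads,
+        # so head_dim=36) and pad_head_dim_to_multiple_of=8, q/k/v are zero-padded
+        # in `forward` to 8 * ceil(36 / 8) = 40 columns, and the extra columns are
+        # sliced off the attention output again.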
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+            self.head_dim_padding = target_head_dim - self.head_dim
+        else:
+            self.head_dim_padding = 0
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        key_value_states: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, q_len = hidden_states.shape[:-1]
+
+        # queries always use the full number of attention heads; the key/value
+        # projections below may use fewer heads when grouped-query attention is
+        # configured (num_key_value_heads < num_attention_heads)
+        query_states = (
+            self.q_proj(hidden_states).view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+        )
+
+        is_cross_attention = key_value_states is not None
+        if past_key_values is not None:
+            is_updated = past_key_values.is_updated.get(self.layer_idx)
+            if is_cross_attention:
+                # after the first generated id, we can subsequently re-use all key/value_states from cache
+                past_key_values.is_updated[self.layer_idx] = True
+                past_key_values = past_key_values.cross_attention_cache
+            else:
+                past_key_values = past_key_values.self_attention_cache
+
+        # use key_value_states if cross attention
+        current_states = key_value_states if key_value_states is not None else hidden_states
+        if is_cross_attention and past_key_values and is_updated:
+            key_states = past_key_values.layers[self.layer_idx].keys
+            value_states = past_key_values.layers[self.layer_idx].values
+        else:
+            key_states = (
+                self.k_proj(current_states)
+                .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(current_states)
+                .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            if is_cross_attention and past_key_values is not None:
+                key_states, value_states = past_key_values.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+
+        if not is_cross_attention:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+            if past_key_values is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_values.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        is_causal = self.is_causal and attention_mask is None and q_len > 1
+
+        if self.head_dim_padding > 0:
+            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            is_causal=is_causal,
+            **kwargs,
+        )
+
+        if self.head_dim_padding > 0:
+            attn_output = attn_output[..., : -self.head_dim_padding]
+
+        attn_output = attn_output.reshape(bsz,
q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MoonshineRotaryEmbedding(GlmRotaryEmbedding): + pass + + +class MoonshineEncoderLayer(LlamaDecoderLayer): + def __init__(self, config: MoonshineConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.self_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=False, + num_attention_heads=config.encoder_num_attention_heads, + num_key_value_heads=config.encoder_num_key_value_heads, + ) + + self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + + +class MoonshineDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=True, + num_attention_heads=config.decoder_num_attention_heads, + num_key_value_heads=config.decoder_num_key_value_heads, + ) + self.encoder_attn = MoonshineAttention( + config=config, + layer_idx=layer_idx, + is_causal=False, + num_attention_heads=config.decoder_num_attention_heads, + num_key_value_heads=config.decoder_num_key_value_heads, + ) + + self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + encoder_position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class 
MoonshinePreTrainedModel(PreTrainedModel): + config: MoonshineConfig + base_model_prefix = "model" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + # TODO arthur, how do we separate when it cross / self coming from different layer? + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + output_conv1_length = int((input_lengths - 127) / 64 + 1) + output_conv2_length = int((output_conv1_length - 7) / 3 + 1) + output_conv3_length = int((output_conv2_length - 3) / 2 + 1) + + return output_conv3_length + + +class MoonshineEncoder(MoonshinePreTrainedModel): + """ + Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`] + + Args: + config: MoonshineConfig + """ + + main_input_name = "input_values" + _can_record_outputs = { + "attentions": MoonshineAttention, + "hidden_states": MoonshineEncoderLayer, + } + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False) + self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3) + self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2) + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5) + self.rotary_emb = MoonshineRotaryEmbedding(config=config) + + self.layers = nn.ModuleList( + [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)] + ) + self.layer_norm = nn.LayerNorm(embed_dim, bias=False) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + @check_model_inputs() + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ [What are attention masks?](../glossary#attention-mask) + """ + input_values = input_values.unsqueeze(1) + hidden_states = nn.functional.tanh(self.conv1(input_values)) + hidden_states = self.groupnorm(hidden_states) + hidden_states = nn.functional.gelu(self.conv2(hidden_states)) + hidden_states = nn.functional.gelu(self.conv3(hidden_states)) + hidden_states = hidden_states.permute(0, 2, 1) + + # attention mask downsampling + if attention_mask is not None: + mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1]) + downsample_stride = 64 * 3 * 2 # conv strides + attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len] + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask == 0.0).any() else None + elif self.config._attn_implementation == "sdpa": + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype) + else: + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.layer_norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + ) + + +class MoonshineDecoder(LlamaModel): + main_input_name = "input_ids" + _can_record_outputs = { + "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"), + "hidden_states": MoonshineDecoderLayer, + "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"), + } + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.norm = nn.LayerNorm(config.hidden_size, bias=False) + self.layers = nn.ModuleList( + [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)] + ) + + @check_model_inputs() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ [What are attention masks?](../glossary#attention-mask) + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + if encoder_attention_mask is not None: + mask_len = encoder_hidden_states.shape[-2] + downsample_stride = 64 * 3 * 2 # conv strides + encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len] + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None + elif self.config._attn_implementation == "sdpa": + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2] + ) + else: + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2] + ) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class MoonshineModel(WhisperModel): + @can_return_tuple + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[tuple[torch.LongTensor]] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. 
Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings` + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, MoonshineModel + >>> from datasets import load_dataset + + >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_values = inputs.input_values + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 288] + ``` + """ + if encoder_outputs is None: + encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs) + + decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_attention_mask=attention_mask, + encoder_hidden_states=encoder_outputs.last_hidden_state, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Moonshine Model with a language modeling head. Can be used for automatic speech recognition. 
+ """ +) +class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["proj_out.weight"] + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.model = MoonshineModel(config) + self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[tuple[torch.LongTensor]] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqLMOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`): + Float values of the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_values`, the [`AutoFeatureExtractor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings` + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny") + >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_values = inputs.input_values + + >>> generated_ids = model.generate(input_values, max_new_tokens=100) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' 
+ ```""" + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs: Seq2SeqModelOutput = self.model( + input_values, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + logits = self.proj_out(outputs.last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = [ + "MoonshineConfig", + "MoonshineModel", + "MoonshinePreTrainedModel", + "MoonshineForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7abc8357cc18e9d0dcc75fe79faa98e294dd77 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mpnet import * + from .modeling_mpnet import * + from .modeling_tf_mpnet import * + from .tokenization_mpnet import * + from .tokenization_mpnet_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/configuration_mpnet.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/configuration_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e80d6a0c30301f72bde911026489f9c89c0759cd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/configuration_mpnet.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MPNet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MPNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MPNetModel`] or a [`TFMPNetModel`]. It is used to + instantiate a MPNet model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MPNet + [microsoft/mpnet-base](https://huggingface.co/microsoft/mpnet-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30527): + Vocabulary size of the MPNet model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MPNetModel`] or [`TFMPNetModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. 
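+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.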
+ + Examples: + + ```python + >>> from transformers import MPNetModel, MPNetConfig + + >>> # Initializing a MPNet mpnet-base style configuration + >>> configuration = MPNetConfig() + + >>> # Initializing a model from the mpnet-base style configuration + >>> model = MPNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mpnet" + + def __init__( + self, + vocab_size=30527, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + relative_attention_num_buckets=32, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.relative_attention_num_buckets = relative_attention_num_buckets + + +__all__ = ["MPNetConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_mpnet.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b25e5491738b7b5e9270fdfa96f27f7e0f3ce712 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_mpnet.py @@ -0,0 +1,967 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch MPNet model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from .configuration_mpnet import MPNetConfig + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class MPNetPreTrainedModel(PreTrainedModel): + config: MPNetConfig + base_model_prefix = "mpnet" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, MPNetLMHead): + module.bias.data.zero_() + + +class MPNetEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.padding_idx = 1 + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, **kwargs): + if position_ids is None: + if input_ids is not None: + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = inputs_embeds + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +class MPNetSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.q = nn.Linear(config.hidden_size, self.all_head_size) + self.k = nn.Linear(config.hidden_size, self.all_head_size) + self.v = nn.Linear(config.hidden_size, self.all_head_size) + self.o = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + position_bias=None, + output_attentions=False, + **kwargs, + ): + batch_size, seq_length, _ = hidden_states.shape + q = ( + self.q(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + k = ( + self.k(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + v = ( + self.v(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(q, k.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Apply relative position embedding (precomputed in MPNetEncoder) if provided. + if position_bias is not None: + attention_scores += position_bias + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
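+        # (the relative position bias added to the scores above shifts the logits
+        # before this softmax, which is how relative position information reaches
+        # each attention layer)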
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + attention_probs = self.dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + c = torch.matmul(attention_probs, v) + + c = c.permute(0, 2, 1, 3).contiguous() + new_c_shape = c.size()[:-2] + (self.all_head_size,) + c = c.view(*new_c_shape) + + o = self.o(c) + + outputs = (o, attention_probs) if output_attentions else (o,) + return outputs + + +class MPNetAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.attn = MPNetSelfAttention(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attn.num_attention_heads, self.attn.attention_head_size, self.pruned_heads + ) + + self.attn.q = prune_linear_layer(self.attn.q, index) + self.attn.k = prune_linear_layer(self.attn.k, index) + self.attn.v = prune_linear_layer(self.attn.v, index) + self.attn.o = prune_linear_layer(self.attn.o, index, dim=1) + + self.attn.num_attention_heads = self.attn.num_attention_heads - len(heads) + self.attn.all_head_size = self.attn.attention_head_size * self.attn.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + position_bias=None, + output_attentions=False, + **kwargs, + ): + self_outputs = self.attn( + hidden_states, + attention_mask, + head_mask, + position_bias, + output_attentions=output_attentions, + ) + attention_output = self.LayerNorm(self.dropout(self_outputs[0]) + hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class MPNetIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class MPNetOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class MPNetLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = MPNetAttention(config) + self.intermediate = MPNetIntermediate(config) + self.output = MPNetOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + position_bias=None, + output_attentions=False, + **kwargs, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + 
position_bias=position_bias, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class MPNetEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_heads = config.num_attention_heads + self.layer = nn.ModuleList([MPNetLayer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention_bias = nn.Embedding(config.relative_attention_num_buckets, self.n_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + **kwargs, + ): + position_bias = self.compute_position_bias(hidden_states) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + head_mask[i], + position_bias, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + def compute_position_bias(self, x, position_ids=None, num_buckets=32): + bsz, qlen, klen = x.size(0), x.size(1), x.size(1) + if position_ids is not None: + context_position = position_ids[:, :, None] + memory_position = position_ids[:, None, :] + else: + context_position = torch.arange(qlen, dtype=torch.long)[:, None] + memory_position = torch.arange(klen, dtype=torch.long)[None, :] + + relative_position = memory_position - context_position + + rp_bucket = self.relative_position_bucket(relative_position, num_buckets=num_buckets) + rp_bucket = rp_bucket.to(x.device) + values = self.relative_attention_bias(rp_bucket) + values = values.permute([2, 0, 1]).unsqueeze(0) + values = values.expand((bsz, -1, qlen, klen)).contiguous() + return values + + @staticmethod + def relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).to(torch.long) * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + ret += torch.where(is_small, n, val_if_large) + return ret + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class MPNetPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class MPNetModel(MPNetPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = MPNetEmbeddings(config) + self.encoder = MPNetEncoder(config) + self.pooler = MPNetPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + 
hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class MPNetForMaskedLM(MPNetPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder"] + + def __init__(self, config): + super().__init__(config) + + self.mpnet = MPNetModel(config, add_pooling_layer=False) + self.lm_head = MPNetLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mpnet( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class MPNetLMHead(nn.Module): + """MPNet Head for masked and permuted language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + +@auto_docstring( + custom_intro=""" + MPNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """ +) +class MPNetForSequenceClassification(MPNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.mpnet = MPNetModel(config, add_pooling_layer=False) + self.classifier = MPNetClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mpnet( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MPNetForMultipleChoice(MPNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.mpnet = MPNetModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.mpnet( + flat_input_ids, + position_ids=flat_position_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MPNetForTokenClassification(MPNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mpnet = MPNetModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: 
Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.mpnet(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class MPNetClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to BERT's [CLS] token)
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+@auto_docstring
+class MPNetForQuestionAnswering(MPNetPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.mpnet = MPNetModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.mpnet(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+        padding_idx: int
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
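+    # Worked example (ours, not part of the original file), with padding_idx == 1 and
+    # hypothetical input_ids [[0, 5, 6, 1, 1]]:
+    #   mask                 = [[1, 1, 1, 0, 0]]
+    #   cumsum(mask) * mask  = [[1, 2, 3, 0, 0]]
+    #   + padding_idx        = [[2, 3, 4, 1, 1]]  (pad positions keep id == padding_idx)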
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "MPNetForMaskedLM", + "MPNetForMultipleChoice", + "MPNetForQuestionAnswering", + "MPNetForSequenceClassification", + "MPNetForTokenClassification", + "MPNetLayer", + "MPNetModel", + "MPNetPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_tf_mpnet.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_tf_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1afea867df35413e13616b25309fa62b336b2712 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/modeling_tf_mpnet.py @@ -0,0 +1,1353 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 MPNet model.""" + +from __future__ import annotations + +import math +import warnings + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_mpnet import MPNetConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/mpnet-base" +_CONFIG_FOR_DOC = "MPNetConfig" + + +class TFMPNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = MPNetConfig + base_model_prefix = "mpnet" + + +class TFMPNetEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position embeddings.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(initializer_range=self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(initializer_range=self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = tf.math.cumsum(mask, axis=1) * mask + + return incremental_indices + self.padding_idx + + def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+                position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
+            else:
+                position_ids = tf.expand_dims(
+                    tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
+                )
+
+        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+        final_embeddings = inputs_embeds + position_embeds
+        final_embeddings = self.LayerNorm(inputs=final_embeddings)
+        final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+        return final_embeddings
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->MPNet
+class TFMPNetPooler(keras.layers.Layer):
+    def __init__(self, config: MPNetConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(inputs=first_token_tensor)
+
+        return pooled_output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFMPNetSelfAttention(keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.q = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q"
+        )
+        self.k = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k"
+        )
+        self.v = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v"
+        )
+        self.o = keras.layers.Dense(
+            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
+        )
+        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+        self.config = config
+
+    def transpose_for_scores(self, x, batch_size):
+        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+        return tf.transpose(x, perm=[0, 2, 1, 3])
+
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False):
+        batch_size = shape_list(hidden_states)[0]
+
+        q = self.q(hidden_states)
+        k = self.k(hidden_states)
+        v = self.v(hidden_states)
+
+        q = self.transpose_for_scores(q, batch_size)
+        k = self.transpose_for_scores(k, batch_size)
+        v = self.transpose_for_scores(v, batch_size)
+
+        attention_scores = tf.matmul(q, k, transpose_b=True)
+        dk = tf.cast(shape_list(k)[-1], attention_scores.dtype)
+        attention_scores = attention_scores / tf.math.sqrt(dk)
+
+        # Apply
relative position embedding (precomputed in MPNetEncoder) if provided. + if position_bias is not None: + attention_scores += position_bias + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = stable_softmax(attention_scores, axis=-1) + + attention_probs = self.dropout(attention_probs, training=training) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + c = tf.matmul(attention_probs, v) + c = tf.transpose(c, perm=[0, 2, 1, 3]) + c = tf.reshape(c, (batch_size, -1, self.all_head_size)) + o = self.o(c) + + outputs = (o, attention_probs) if output_attentions else (o,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "q", None) is not None: + with tf.name_scope(self.q.name): + self.q.build([None, None, self.config.hidden_size]) + if getattr(self, "k", None) is not None: + with tf.name_scope(self.k.name): + self.k.build([None, None, self.config.hidden_size]) + if getattr(self, "v", None) is not None: + with tf.name_scope(self.v.name): + self.v.build([None, None, self.config.hidden_size]) + if getattr(self, "o", None) is not None: + with tf.name_scope(self.o.name): + self.o.build([None, None, self.config.hidden_size]) + + +class TFMPNetAttention(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.attn = TFMPNetSelfAttention(config, name="attn") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.config = config + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, input_tensor, attention_mask, head_mask, output_attentions, position_bias=None, training=False): + self_outputs = self.attn( + input_tensor, attention_mask, head_mask, output_attentions, position_bias=position_bias, training=training + ) + attention_output = self.LayerNorm(self.dropout(self_outputs[0]) + input_tensor) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet +class TFMPNetIntermediate(keras.layers.Layer): + def __init__(self, config: MPNetConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from 
transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet +class TFMPNetOutput(keras.layers.Layer): + def __init__(self, config: MPNetConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFMPNetLayer(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.attention = TFMPNetAttention(config, name="attention") + self.intermediate = TFMPNetIntermediate(config, name="intermediate") + self.out = TFMPNetOutput(config, name="output") + + def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False): + self_attention_outputs = self.attention( + hidden_states, attention_mask, head_mask, output_attentions, position_bias=position_bias, training=training + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.out(intermediate_output, attention_output, training=training) + outputs = (layer_output,) + outputs # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "out", None) is not None: + with tf.name_scope(self.out.name): + self.out.build(None) + + +class TFMPNetEncoder(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.n_heads = config.num_attention_heads + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.initializer_range = config.initializer_range + + self.layer = [TFMPNetLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.relative_attention_num_buckets = config.relative_attention_num_buckets + + def build(self, input_shape=None): + if self.built: + return + self.built = True + with tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + 
initializer=get_initializer(self.initializer_range), + ) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + def call( + self, + hidden_states, + attention_mask, + head_mask, + output_attentions, + output_hidden_states, + return_dict, + training=False, + ): + position_bias = self.compute_position_bias(hidden_states) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + head_mask[i], + output_attentions, + position_bias=position_bias, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets + n = tf.math.abs(n) + + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = tf.math.less(n, max_exact) + + val_if_large = max_exact + tf.cast( + tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact), + dtype=relative_position.dtype, + ) + + val_if_large = tf.math.minimum(val_if_large, num_buckets - 1) + ret += tf.where(is_small, n, val_if_large) + return ret + + def compute_position_bias(self, x, position_ids=None): + """Compute binned relative position bias""" + input_shape = shape_list(x) + qlen, klen = input_shape[1], input_shape[1] + + if position_ids is not None: + context_position = position_ids[:, :, None] + memory_position = position_ids[:, None, :] + else: + context_position = tf.range(qlen)[:, None] + memory_position = tf.range(klen)[None, :] + + relative_position = memory_position - context_position # shape (qlen, klen) + + rp_bucket = self._relative_position_bucket( + relative_position, + num_buckets=self.relative_attention_num_buckets, + ) + values = tf.gather(self.relative_attention_bias, rp_bucket) # shape (qlen, klen, num_heads) + values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen) + return values + + +@keras_serializable +class TFMPNetMainLayer(keras.layers.Layer): + config_class = MPNetConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFMPNetEncoder(config, name="encoder") + self.pooler = TFMPNetPooler(config, name="pooler") + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFMPNetEmbeddings(config, name="embeddings") + + # Copied from 
transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) + + embedding_output = self.embeddings( + input_ids, + position_ids, + inputs_embeds, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
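+        # Numeric sketch (ours): an attention_mask row [1.0, 1.0, 0.0] becomes
+        #   (1.0 - [1.0, 1.0, 0.0]) * -10000.0 = [0.0, 0.0, -10000.0],
+        # leaving attended positions unchanged and pushing masked positions towards -inf.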
+ extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + + encoder_outputs = self.encoder( + embedding_output, + extended_attention_mask, + head_mask, + output_attentions, + output_hidden_states, + return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + + +MPNET_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! 
If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`MPNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MPNET_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare MPNet Model transformer outputting raw hidden-states without any specific head on top.", + MPNET_START_DOCSTRING, +) +class TFMPNetModel(TFMPNetPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.array | tf.Tensor | None = None, + position_ids: np.array | tf.Tensor | None = None, + head_mask: np.array | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + outputs = self.mpnet( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mpnet", None) is not None: + with tf.name_scope(self.mpnet.name): + self.mpnet.build(None) + + +class TFMPNetLMHead(keras.layers.Layer): + """MPNet head for masked and permuted language modeling""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
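+        # Note (ours): `input_embeddings` is the TFMPNetEmbeddings layer; its `weight` matrix
+        # of shape [vocab_size, hidden_size] is reused in call() below via a matmul with
+        # transpose_b=True, so the LM head stays tied to the input embeddings.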
+ self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings("""MPNet Model with a `language modeling` head on top.""", MPNET_START_DOCSTRING) +class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        """
+        outputs = self.mpnet(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFMaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "mpnet", None) is not None:
+            with tf.name_scope(self.mpnet.name):
+                self.mpnet.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
+
+
+class TFMPNetClassificationHead(keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.out_proj = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+        )
+        self.config = config
+
+    def call(self, features, training=False):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x, training=training)
+        x = self.dense(x)
+        x = self.dropout(x, training=training)
+        x = self.out_proj(x)
+        return x
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    MPNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+ """, + MPNET_START_DOCSTRING, +) +class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassificationLoss): + _keys_to_ignore_on_load_missing = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + self.classifier = TFMPNetClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.array | tf.Tensor | None = None, + position_ids: np.array | tf.Tensor | None = None, + head_mask: np.array | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.mpnet( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mpnet", None) is not None: + with tf.name_scope(self.mpnet.name): + self.mpnet.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + MPNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + MPNET_START_DOCSTRING, +) +class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.mpnet( + flat_input_ids, + flat_attention_mask, + flat_position_ids, + head_mask, + flat_inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mpnet", None) is not None: + with tf.name_scope(self.mpnet.name): + self.mpnet.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + MPNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + MPNET_START_DOCSTRING, +) +class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificationLoss): + _keys_to_ignore_on_load_missing = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.mpnet( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mpnet", None) is not None: + with tf.name_scope(self.mpnet.name): + self.mpnet.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + MPNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + MPNET_START_DOCSTRING, +) +class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLoss): + _keys_to_ignore_on_load_missing = [r"pooler"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.mpnet = TFMPNetMainLayer(config, name="mpnet") + self.qa_outputs = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.array | tf.Tensor | None = None, + position_ids: np.array | tf.Tensor | None = None, + head_mask: np.array | tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: tf.Tensor | None = None, + end_positions: tf.Tensor | None = None, + training: bool = False, + **kwargs, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + outputs = self.mpnet( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions, "end_position": end_positions} + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "mpnet", None) is not None: + with tf.name_scope(self.mpnet.name): + self.mpnet.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFMPNetEmbeddings", + "TFMPNetForMaskedLM", + "TFMPNetForMultipleChoice", + "TFMPNetForQuestionAnswering", + "TFMPNetForSequenceClassification", + "TFMPNetForTokenClassification", + "TFMPNetMainLayer", + "TFMPNetModel", + "TFMPNetPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bf035cf8e4bd490023890a0d690bba32a123124b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for MPNet.""" + +import collections +import os +import unicodedata +from typing import Optional + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class MPNetTokenizer(PreTrainedTokenizer): + """ + + This tokenizer inherits from [`BertTokenizer`] which contains most of the methods. Users should refer to the + superclass for more information regarding methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). 
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        do_lower_case=True,
+        do_basic_tokenize=True,
+        never_split=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="[UNK]",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        clean_up_tokenization_spaces=True,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+                " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(
+                do_lower_case=do_lower_case,
+                never_split=never_split,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                strip_accents=strip_accents,
+            )
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+        super().__init__(
+            do_lower_case=do_lower_case,
+            do_basic_tokenize=do_basic_tokenize,
+            never_split=never_split,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            tokenize_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    @property
+    def do_lower_case(self):
+        return self.basic_tokenizer.do_lower_case
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        # "<mask>" is part of the vocab, but was wrongfully added at a wrong index in the fast saved version
+        vocab = self.added_tokens_encoder.copy()
+        vocab.update(self.vocab)
+        return vocab
+
+    def _tokenize(self, text):
+        split_tokens = []
+        if self.do_basic_tokenize:
+            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+                # If the token is part of the never_split set
+                if token in self.basic_tokenizer.never_split:
+                    split_tokens.append(token)
+                else:
+                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
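To make the two-stage control flow of `_tokenize` concrete, here is a small sketch using the `BasicTokenizer` and `WordpieceTokenizer` helpers defined later in this file. The vocabulary is a toy example invented for illustration.

```python
# Hedged sketch of the basic -> WordPiece pipeline behind MPNetTokenizer._tokenize.
# `toy_vocab` is illustrative only, not a real MPNet vocabulary.
toy_vocab = {"un": 0, "##aff": 1, "##able": 2, ",": 3, "[UNK]": 4}

basic = BasicTokenizer(do_lower_case=True)
wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")

pieces = []
for token in basic.tokenize("Unaffable,"):
    # BasicTokenizer lowercases and splits off punctuation; WordpieceTokenizer
    # then greedily matches the longest vocab entry from the left.
    pieces.extend(wordpiece.tokenize(token))

print(pieces)  # ['un', '##aff', '##able', ',']
```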
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        out_string = " ".join(tokens).replace(" ##", "").strip()
+        return out_string
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. An MPNet sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` methods.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Set to True if the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. MPNet does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. 
This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace separated tokens. This should have
+                already been passed through *BasicTokenizer*.
+
+        Returns:
+            A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+__all__ = ["MPNetTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a470565a8452e765a509d0ffffb41dc6d6cd5e9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mpnet/tokenization_mpnet_fast.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for MPNet."""
+
+import json
+from typing import Optional
+
+from tokenizers import normalizers
+
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_mpnet import MPNetTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class MPNetTokenizerFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a "fast" MPNet tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
+            issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = MPNetTokenizer
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        do_lower_case=True,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="[UNK]",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e.
it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            do_lower_case=do_lower_case,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            tokenize_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            **kwargs,
+        )
+
+        pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+        if (
+            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
+            or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
+        ):
+            pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
+            pre_tok_state["lowercase"] = do_lower_case
+            pre_tok_state["strip_accents"] = strip_accents
+            self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
+
+        self.do_lower_case = do_lower_case
+
+    @property
+    def mask_token(self) -> str:
+        """
+        `str`: Mask token, to use when training a model with masked-language modeling. Logs an error if used while not
+        having been set.
+
+        MPNet tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
+        comprise the space before the *<mask>*.
+        """
+        if self._mask_token is None:
+            if self.verbose:
+                logger.error("Using mask_token, but it is not set yet.")
+            return None
+        return str(self._mask_token)
+
+    @mask_token.setter
+    def mask_token(self, value):
+        """
+        Overriding the default behavior of the mask token to have it eat the space before it.
+
+        This is needed to preserve backward compatibility with all the previously used models based on MPNet.
+        """
+        # Mask token behaves like a normal word, i.e. it includes the space before it,
+        # so we set lstrip to True
+        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
+        self._mask_token = value
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return output
+
+        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. MPNet does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["MPNetTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20bf8c5ebaf9871204c64eedc5990498ecff1ddc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpt/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mpt import * + from .modeling_mpt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpt/configuration_mpt.py b/venv/lib/python3.13/site-packages/transformers/models/mpt/configuration_mpt.py new file mode 100644 index 0000000000000000000000000000000000000000..f3468ca8fac4e0eb69cb378f2e762bba9b9e12d8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpt/configuration_mpt.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team and MosaicML NLP team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mpt configuration""" + +from typing import Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MptAttentionConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`MptAttention`] class. It is used to instantiate + attention layers according to the specified arguments, defining the layers architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MPT + [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b) architecture. 
Most of the arguments are kept for backward + compatibility with previous MPT models that are hosted on the Hub (previously with `trust_remote_code=True`). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_type (`str`, *optional*, defaults to `"multihead_attention"`): + type of attention to use. Options: `"multihead_attention"`, `"multiquery_attention"`. + attn_pdrop (`float`, *optional*, defaults to `0.0`): + The dropout probability for the attention layers. + attn_impl (`str`, *optional*, defaults to `"torch"`): + The attention implementation to use. One of `"torch"`, `"flash"`, or `"triton"`. + clip_qkv (`float`, *optional*): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + softmax_scale (`float`, *optional*): + If not `None`, scale the softmax in the attention layer by this value. If `None`, will default to + `1/sqrt(hidden_size)`. + prefix_lm (`bool`, *optional*, defaults to `False`): + Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument + which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another + bi-directionally. Tokens outside the prefix use causal attention. + qk_ln (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization to the queries and keys in the attention layer. + attn_uses_sequence_id (`bool`, *optional*, defaults to `False`): + Whether to restrict attention to tokens that have the same token_type_ids. When the model is in `train` + mode, this requires passing an extra *token_type_ids* argument which indicates which sub-sequence each + token belongs to. Defaults to `False` meaning any provided *token_type_ids* will be ignored. + alibi (`bool`, *optional*, defaults to `True`): + Whether or not to use the alibi bias instead of positional embedding. + alibi_bias_max (`int`, *optional*, defaults to 8): + The maximum value of the alibi bias. + """ + + base_config_key = "attn_config" + + def __init__( + self, + attn_type="multihead_attention", + attn_pdrop=0, + attn_impl="torch", + clip_qkv=None, + softmax_scale=None, + prefix_lm=False, + qk_ln=False, + attn_uses_sequence_id=False, + alibi=True, + alibi_bias_max=8, + **kwargs, + ): + super().__init__() + self.attn_type = attn_type + self.attn_pdrop = attn_pdrop + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.softmax_scale = softmax_scale + self.prefix_lm = prefix_lm + self.attn_uses_sequence_id = attn_uses_sequence_id + self.alibi = alibi + self.qk_ln = qk_ln + self.alibi_bias_max = alibi_bias_max + + if attn_type not in ["multihead_attention", "multiquery_attention"]: + raise ValueError( + f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}" + ) + + +class MptConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`MptModel`]. It is used to instantiate a Mpt model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to the Mpt-7b architecture + [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + d_model (`int`, *optional*, defaults to 2048): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + expansion_ratio (`int`, *optional*, defaults to 4): + The ratio of the up/down scale in the MLP. + max_seq_len (`int`, *optional*, defaults to 2048): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the Mpt model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`MptModel`]. Check [this + discussion](https://huggingface.co/bigscience/mpt/discussions/120#633d28389addb8530b406c2a) on how the + `vocab_size` has been defined. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + learned_pos_emb (`bool`, *optional*, defaults to `True`): + Whether to use learned positional embeddings. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + init_device (`str`, *optional*, defaults to `"cpu"`): + The device to use for parameter initialization. Defined for backward compatibility + logit_scale (`float`, *optional*): + If not None, scale the logits by this value. + no_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in all linear layers. + verbose (`int`, *optional*, defaults to 0): + The verbosity level to use for logging. Used in the previous versions of MPT models for logging. This + argument is deprecated. + embedding_fraction (`float`, *optional*, defaults to 1.0): + The fraction to scale the gradients of the embedding layer by. + norm_type (`str`, *optional*, defaults to `"low_precision_layernorm"`): + Type of layer norm to use. All MPT models uses the same layer norm implementation. Defined for backward + compatibility. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers import MptConfig, MptModel + + >>> # Initializing a Mpt configuration + >>> configuration = MptConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MptModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "mpt" + sub_configs = {"attn_config": MptAttentionConfig} + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + expansion_ratio: int = 4, + max_seq_len: int = 2048, + vocab_size: int = 50368, + resid_pdrop: float = 0.0, + layer_norm_epsilon: float = 1e-5, + emb_pdrop: float = 0.0, + learned_pos_emb: bool = True, + attn_config: MptAttentionConfig = None, + init_device: str = "cpu", + logit_scale: Optional[Union[float, str]] = None, + no_bias: bool = True, + verbose: int = 0, + embedding_fraction: float = 1.0, + norm_type: str = "low_precision_layernorm", + use_cache: bool = False, + initializer_range=0.02, + **kwargs, + ): + if attn_config is None: + self.attn_config = MptAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = MptAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.verbose = verbose + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.layer_norm_epsilon = layer_norm_epsilon + self.use_cache = use_cache + self.initializer_range = initializer_range + super().__init__(**kwargs) + + +__all__ = ["MptConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mpt/modeling_mpt.py b/venv/lib/python3.13/site-packages/transformers/models/mpt/modeling_mpt.py new file mode 100644 index 0000000000000000000000000000000000000000..c7bf0a795d423bfc1b72acf2442212dff061cfa5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mpt/modeling_mpt.py @@ -0,0 +1,824 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team and MosaicML NLP team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
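Before the modeling source below, a short sketch of how the configuration classes just defined compose: a plain `attn_config` dict is promoted to `MptAttentionConfig`, and `attribute_map` lets the generic `hidden_size`/`num_attention_heads` names resolve to `d_model`/`n_heads`. The values are illustrative.

```python
# Hedged sketch: nested attention config and attribute aliases in MptConfig.
from transformers import MptConfig

config = MptConfig(
    d_model=1024,
    n_heads=8,
    n_layers=12,
    attn_config={"attn_pdrop": 0.1, "alibi": True},  # dict becomes an MptAttentionConfig
)

assert config.attn_config.attn_pdrop == 0.1
# attribute_map aliases resolve to the MPT-native names
assert config.num_attention_heads == 8  # -> n_heads
assert config.hidden_size == 1024  # -> d_model
```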
+"""PyTorch MPT model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from torch.nn import functional as F + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_mpt import MptConfig + + +logger = logging.get_logger(__name__) + + +def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None): + r""" + Link to paper: https://huggingface.co/papers/2108.12409 - Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation. This implementation has been copied from + the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi: + https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292 + """ + alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length) + num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) + + base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float() + base = base * (alibi_bias_max / num_heads_power_of_2) + + slopes = 1.0 / torch.pow(2, base) + slopes = slopes.view(1, num_heads_power_of_2, 1, 1) + + if num_heads_power_of_2 != num_heads: + slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...] + + alibi = alibi * slopes + return alibi.squeeze(0) + + +class MptAttention(nn.Module): + """Multi-head self attention. + Using torch or triton attention implementation enables user to also use additive bias. 
+ """ + + def __init__(self, config: MptConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.n_heads = config.n_heads + self.max_seq_length = config.max_seq_len + self.head_dim = self.hidden_size // self.n_heads + self.softmax_scale = config.attn_config.softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads) + + self.attn_dropout_p = config.attn_config.attn_pdrop + self.clip_qkv = config.attn_config.clip_qkv + self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + ): + batch_size, seq_length = hidden_states.shape[:2] + + mixed_qkv = self.Wqkv(hidden_states) + if self.clip_qkv: + mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2) + query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale + query_length = seq_length if past_key_values is None else seq_length + past_key_values.get_seq_length() + + if position_bias is not None: + if len(position_bias.shape) != 3: + raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}") + key_length = key_states.shape[-2] + + position_bias_query_index = max(0, position_bias.size(1) - query_length) + position_bias_key_index = max(0, position_bias.size(2) - key_length) + + position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:] + + attention_scores = attention_scores + position_bias + + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training) + + context_states = torch.matmul(attn_weights, value_states) + context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output = self.out_proj(context_states) + + return attn_output, attn_weights + + +class MptMLP(nn.Module): + def __init__(self, config: MptConfig): + super().__init__() + hidden_size = config.hidden_size + + self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False) + self.act = nn.GELU(approximate="none") + self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False) + self.hidden_dropout = config.attn_config.attn_pdrop + + def forward(self, hidden_states: torch.Tensor, 
residual: torch.Tensor) -> torch.Tensor: + hidden_states = self.act(self.up_proj(hidden_states)) + + intermediate_output = self.down_proj(hidden_states) + + output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training) + output = output + residual + + return output + + +class MptBlock(GradientCheckpointingLayer): + def __init__(self, config: MptConfig, layer_idx: Optional[int] = None): + super().__init__() + hidden_size = config.hidden_size + + self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # backward compatibility with weights on the Hub + self.norm_1.bias = None + + self.num_heads = config.n_heads + self.attn = MptAttention(config, layer_idx) + + self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # backward compatibility with weights on the Hub + self.norm_2.bias = None + + self.ffn = MptMLP(config) + + self.dropout_rate = config.attn_config.attn_pdrop + self.resid_attn_dropout = nn.Dropout(self.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.norm_1(hidden_states) + + residual = hidden_states + + # Self attention. + attn_outputs, attn_weights = self.attn( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_values=layer_past, + cache_position=cache_position, + ) + + hidden_states = self.resid_attn_dropout(attn_outputs) + residual + + layernorm_output = self.norm_2(hidden_states) + + # Get residual + residual = hidden_states + + # MLP. + output = self.ffn(layernorm_output, residual) + return output, attn_weights + + +@auto_docstring +class MptPreTrainedModel(PreTrainedModel): + config: MptConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["MptBlock"] + _keys_to_ignore_on_load_missing = [r"lm_head.*."] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LayerNorm): + if module.bias is not None: + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + @staticmethod + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def _convert_to_mpt_cache( + past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]], + ) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Mpt, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_values[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_values + ) + + +@auto_docstring +class MptModel(MptPreTrainedModel): + def __init__(self, config: MptConfig): + super().__init__(config) + + self.hidden_size = config.hidden_size + self.num_heads = config.n_heads + + # Embedding + LN Embedding + self.wte = nn.Embedding(config.vocab_size, self.hidden_size) + + # Transformer blocks + self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)]) + + # Final Layer Norm + self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon) + # backward compatibility with weights on the Hub + self.norm_f.bias = None + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None): + return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device) + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.wte = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` + (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + hidden_states = inputs_embeds + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device) + + causal_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + causal_mask = causal_mask.bool() + + for block in self.blocks: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + use_cache=use_cache, + output_attentions=output_attentions, + position_bias=alibi, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + + # Add last hidden state + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring( + custom_intro=""" + The MPT Model transformer with a language modeling head on top (linear layer with 
weights tied to the input
+    embeddings).
+    """
+)
+class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: MptConfig):
+        super().__init__(config)
+        self.transformer = MptModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_output_embeddings(self, new_embeddings: torch.Tensor):
+        self.lm_head = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        hidden_states = transformer_outputs[0]
+
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Flatten the tokens
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The MPT Model transformer with a sequence classification head on top (linear layer).
+
+    [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+    each row of the batch).
+    """
+)
+class MptForSequenceClassification(MptPreTrainedModel):
+    def __init__(self, config: MptConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = MptModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
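+
+        Example (a minimal sketch of the last-token pooling rule described above; the ids and `pad_token_id=0` are
+        arbitrary, and the model applies the same arithmetic to its logits):
+
+        ```python
+        >>> import torch
+
+        >>> pad_token_id = 0
+        >>> input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])  # right-padded batch
+        >>> non_pad_mask = (input_ids != pad_token_id).int()
+        >>> token_indices = torch.arange(input_ids.shape[-1])
+        >>> (token_indices * non_pad_mask).argmax(-1)  # rightmost non-padding position per row
+        tensor([2, 1])
+        ```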
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class MptForTokenClassification(MptPreTrainedModel): + def __init__(self, config: MptConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = MptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] 
= None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            batch_size, seq_length = labels.shape
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(
+                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+            )
+
+        if not return_dict:
+            output = (logits,) + transformer_outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+@auto_docstring
+class MptForQuestionAnswering(MptPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = MptModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, QuestionAnsweringModelOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+            (`sequence_length` of input past key value states).
Indices of input sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MptForCausalLM", + "MptModel", + "MptPreTrainedModel", + "MptForSequenceClassification", + "MptForTokenClassification", + "MptForQuestionAnswering", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mra/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2963ad0c97ba8c0c5bdc32ff9adca79ffc56cde3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mra/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mra import * + from .modeling_mra import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mra/configuration_mra.py b/venv/lib/python3.13/site-packages/transformers/models/mra/configuration_mra.py new file mode 100644 index 0000000000000000000000000000000000000000..16b064c98f7e6a9ed4fe72d7149fb867380c904d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mra/configuration_mra.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MRA model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MraConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MraModel`]. It is used to instantiate an MRA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Mra + [uw-madison/mra-base-512-4](https://huggingface.co/uw-madison/mra-base-512-4) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the Mra model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MraModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. 
+ max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 1): + The vocabulary size of the `token_type_ids` passed when calling [`MraModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. + block_per_row (`int`, *optional*, defaults to 4): + Used to set the budget for the high resolution scale. + approx_mode (`str`, *optional*, defaults to `"full"`): + Controls whether both low and high resolution approximations are used. Set to `"full"` for both low and + high resolution and `"sparse"` for only low resolution. + initial_prior_first_n_blocks (`int`, *optional*, defaults to 0): + The initial number of blocks for which high resolution is used. + initial_prior_diagonal_n_blocks (`int`, *optional*, defaults to 0): + The number of diagonal blocks for which high resolution is used. + + Example: + + ```python + >>> from transformers import MraConfig, MraModel + + >>> # Initializing a Mra uw-madison/mra-base-512-4 style configuration + >>> configuration = MraConfig() + + >>> # Initializing a model (with random weights) from the uw-madison/mra-base-512-4 style configuration + >>> model = MraModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mra" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=1, + initializer_range=0.02, + layer_norm_eps=1e-5, + position_embedding_type="absolute", + block_per_row=4, + approx_mode="full", + initial_prior_first_n_blocks=0, + initial_prior_diagonal_n_blocks=0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.block_per_row = block_per_row + self.approx_mode = approx_mode + self.initial_prior_first_n_blocks = initial_prior_first_n_blocks + self.initial_prior_diagonal_n_blocks = initial_prior_diagonal_n_blocks + + +__all__ = ["MraConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mra/modeling_mra.py b/venv/lib/python3.13/site-packages/transformers/models/mra/modeling_mra.py new file mode 100644 index 
0000000000000000000000000000000000000000..86bee4d09b5af6663283ac2e404bc39ea9b15e3d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mra/modeling_mra.py @@ -0,0 +1,1391 @@ +# coding=utf-8 +# Copyright 2023 University of Wisconsin-Madison and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MRA model.""" + +import math +from pathlib import Path +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.utils.cpp_extension import load + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging +from .configuration_mra import MraConfig + + +logger = logging.get_logger(__name__) + +mra_cuda_kernel = None + + +def load_cuda_kernels(): + global mra_cuda_kernel + src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra" + + def append_root(files): + return [src_folder / file for file in files] + + src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"]) + + mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True) + + +def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block): + """ + Computes maximum values for softmax stability. + """ + if len(sparse_qk_prod.size()) != 4: + raise ValueError("sparse_qk_prod must be a 4-dimensional tensor.") + + if len(indices.size()) != 2: + raise ValueError("indices must be a 2-dimensional tensor.") + + if sparse_qk_prod.size(2) != 32: + raise ValueError("The size of the second dimension of sparse_qk_prod must be 32.") + + if sparse_qk_prod.size(3) != 32: + raise ValueError("The size of the third dimension of sparse_qk_prod must be 32.") + + index_vals = sparse_qk_prod.max(dim=-2).values.transpose(-1, -2) + index_vals = index_vals.contiguous() + + indices = indices.int() + indices = indices.contiguous() + + max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block) + max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :] + + return max_vals, max_vals_scatter + + +def sparse_mask(mask, indices, block_size=32): + """ + Converts attention mask to a sparse mask for high resolution logits. 
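+
+    Example (an illustrative sketch using a deliberately tiny `block_size=2` instead of the CUDA-friendly default
+    of 32; the tensors are arbitrary):
+
+    ```python
+    >>> import torch
+
+    >>> mask = torch.tensor([[1, 1, 1, 0]])  # one sequence of length 4, i.e. two key blocks of size 2
+    >>> indices = torch.tensor([[0, 1]])  # keep blocks 0 and 1
+    >>> sparse_mask(mask, indices, block_size=2)
+    tensor([[[1, 1],
+             [1, 0]]])
+    ```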
+ """ + if len(mask.size()) != 2: + raise ValueError("mask must be a 2-dimensional tensor.") + + if len(indices.size()) != 2: + raise ValueError("indices must be a 2-dimensional tensor.") + + if mask.shape[0] != indices.shape[0]: + raise ValueError("mask and indices must have the same size in the zero-th dimension.") + + batch_size, seq_len = mask.shape + num_block = seq_len // block_size + + batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device) + mask = mask.reshape(batch_size, num_block, block_size) + mask = mask[batch_idx[:, None], (indices % num_block).long(), :] + + return mask + + +def mm_to_sparse(dense_query, dense_key, indices, block_size=32): + """ + Performs Sampled Dense Matrix Multiplication. + """ + batch_size, query_size, dim = dense_query.size() + _, key_size, dim = dense_key.size() + + if query_size % block_size != 0: + raise ValueError("query_size (size of first dimension of dense_query) must be divisible by block_size.") + + if key_size % block_size != 0: + raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.") + + dense_query = dense_query.reshape(batch_size, query_size // block_size, block_size, dim).transpose(-1, -2) + dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2) + + if len(dense_query.size()) != 4: + raise ValueError("dense_query must be a 4-dimensional tensor.") + + if len(dense_key.size()) != 4: + raise ValueError("dense_key must be a 4-dimensional tensor.") + + if len(indices.size()) != 2: + raise ValueError("indices must be a 2-dimensional tensor.") + + if dense_query.size(3) != 32: + raise ValueError("The third dimension of dense_query must be 32.") + + if dense_key.size(3) != 32: + raise ValueError("The third dimension of dense_key must be 32.") + + dense_query = dense_query.contiguous() + dense_key = dense_key.contiguous() + + indices = indices.int() + indices = indices.contiguous() + + return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int()) + + +def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32): + """ + Performs matrix multiplication of a sparse matrix with a dense matrix. 
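+
+    Shape contract (following the checks below and the callers above): `sparse_query` is a
+    `(batch_size, num_block, block_size, block_size)` tensor, `indices` is `(batch_size, num_block)`, and `dense_key`
+    is `(batch_size, key_size, dim)`; the result is a dense `(batch_size, query_num_block * block_size, dim)` tensor.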
+ """ + batch_size, key_size, dim = dense_key.size() + + if key_size % block_size != 0: + raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.") + + if sparse_query.size(2) != block_size: + raise ValueError("The size of the second dimension of sparse_query must be equal to the block_size.") + + if sparse_query.size(3) != block_size: + raise ValueError("The size of the third dimension of sparse_query must be equal to the block_size.") + + dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2) + + if len(sparse_query.size()) != 4: + raise ValueError("sparse_query must be a 4-dimensional tensor.") + + if len(dense_key.size()) != 4: + raise ValueError("dense_key must be a 4-dimensional tensor.") + + if len(indices.size()) != 2: + raise ValueError("indices must be a 2-dimensional tensor.") + + if dense_key.size(3) != 32: + raise ValueError("The size of the third dimension of dense_key must be 32.") + + sparse_query = sparse_query.contiguous() + + indices = indices.int() + indices = indices.contiguous() + dense_key = dense_key.contiguous() + + dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block) + dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim) + return dense_qk_prod + + +def transpose_indices(indices, dim_1_block, dim_2_block): + return ((indices % dim_2_block) * dim_1_block + torch.div(indices, dim_2_block, rounding_mode="floor")).long() + + +class MraSampledDenseMatMul(torch.autograd.Function): + @staticmethod + def forward(ctx, dense_query, dense_key, indices, block_size): + sparse_qk_prod = mm_to_sparse(dense_query, dense_key, indices, block_size) + ctx.save_for_backward(dense_query, dense_key, indices) + ctx.block_size = block_size + return sparse_qk_prod + + @staticmethod + def backward(ctx, grad): + dense_query, dense_key, indices = ctx.saved_tensors + block_size = ctx.block_size + query_num_block = dense_query.size(1) // block_size + key_num_block = dense_key.size(1) // block_size + indices_T = transpose_indices(indices, query_num_block, key_num_block) + grad_key = sparse_dense_mm(grad.transpose(-1, -2), indices_T, dense_query, key_num_block) + grad_query = sparse_dense_mm(grad, indices, dense_key, query_num_block) + return grad_query, grad_key, None, None + + @staticmethod + def operator_call(dense_query, dense_key, indices, block_size=32): + return MraSampledDenseMatMul.apply(dense_query, dense_key, indices, block_size) + + +class MraSparseDenseMatMul(torch.autograd.Function): + @staticmethod + def forward(ctx, sparse_query, indices, dense_key, query_num_block): + sparse_qk_prod = sparse_dense_mm(sparse_query, indices, dense_key, query_num_block) + ctx.save_for_backward(sparse_query, indices, dense_key) + ctx.query_num_block = query_num_block + return sparse_qk_prod + + @staticmethod + def backward(ctx, grad): + sparse_query, indices, dense_key = ctx.saved_tensors + query_num_block = ctx.query_num_block + key_num_block = dense_key.size(1) // sparse_query.size(-1) + indices_T = transpose_indices(indices, query_num_block, key_num_block) + grad_key = sparse_dense_mm(sparse_query.transpose(-1, -2), indices_T, grad, key_num_block) + grad_query = mm_to_sparse(grad, dense_key, indices) + return grad_query, None, grad_key, None + + @staticmethod + def operator_call(sparse_query, indices, dense_key, query_num_block): + return MraSparseDenseMatMul.apply(sparse_query, indices, dense_key, query_num_block) + 
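+
+# Note on `transpose_indices` (used in the backward passes above): it maps the flattened index of a block in a
+# row-major (dim_1_block x dim_2_block) grid to the flattened index of the same block in the transposed grid, so the
+# gradient w.r.t. the keys can reuse the sampled block layout. A minimal illustration on a 2 x 3 grid (sizes are
+# arbitrary):
+#
+#     >>> import torch
+#     >>> transpose_indices(torch.arange(6), 2, 3)  # block (r, c) maps to flattened index c * 2 + r
+#     tensor([0, 2, 4, 1, 3, 5])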
+
+class MraReduceSum:
+    @staticmethod
+    def operator_call(sparse_query, indices, query_num_block, key_num_block):
+        if len(sparse_query.size()) != 4:
+            raise ValueError("sparse_query must be a 4-dimensional tensor.")
+
+        if len(indices.size()) != 2:
+            raise ValueError("indices must be a 2-dimensional tensor.")
+
+        _, _, block_size, _ = sparse_query.size()
+        batch_size, num_block = indices.size()
+
+        sparse_query = sparse_query.sum(dim=2).reshape(batch_size * num_block, block_size)
+
+        batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device)
+        global_idxes = (
+            torch.div(indices, key_num_block, rounding_mode="floor").long() + batch_idx[:, None] * query_num_block
+        ).reshape(batch_size * num_block)
+        temp = torch.zeros(
+            (batch_size * query_num_block, block_size), dtype=sparse_query.dtype, device=sparse_query.device
+        )
+        output = temp.index_add(0, global_idxes, sparse_query).reshape(batch_size, query_num_block, block_size)
+
+        output = output.reshape(batch_size, query_num_block * block_size)
+        return output
+
+
+def get_low_resolution_logit(query, key, block_size, mask=None, value=None):
+    """
+    Compute low resolution approximation.
+    """
+    batch_size, seq_len, head_dim = query.size()
+
+    num_block_per_row = seq_len // block_size
+
+    value_hat = None
+    if mask is not None:
+        token_count = mask.reshape(batch_size, num_block_per_row, block_size).sum(dim=-1)
+        query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+            token_count[:, :, None] + 1e-6
+        )
+        key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+            token_count[:, :, None] + 1e-6
+        )
+        if value is not None:
+            value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+                token_count[:, :, None] + 1e-6
+            )
+    else:
+        token_count = block_size * torch.ones(batch_size, num_block_per_row, dtype=torch.float, device=query.device)
+        query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+        key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+        if value is not None:
+            value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+
+    low_resolution_logit = torch.matmul(query_hat, key_hat.transpose(-1, -2)) / math.sqrt(head_dim)
+
+    low_resolution_logit_row_max = low_resolution_logit.max(dim=-1, keepdims=True).values
+
+    if mask is not None:
+        low_resolution_logit = (
+            low_resolution_logit - 1e4 * ((token_count[:, None, :] * token_count[:, :, None]) < 0.5).float()
+        )
+
+    return low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat
+
+
+def get_block_idxes(
+    low_resolution_logit, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks
+):
+    """
+    Compute the indices of the subset of components to be used in the approximation.
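+
+    Example (an illustrative sketch on a tiny 1 x 2 x 2 block-logit grid; real inputs come from
+    `get_low_resolution_logit` above):
+
+    ```python
+    >>> import torch
+
+    >>> logit = torch.tensor([[[1.0, 4.0], [3.0, 2.0]]])
+    >>> indices, mask = get_block_idxes(logit, 2, "full", 0, 0)  # pick the 2 largest blocks
+    >>> sorted(indices[0].tolist())  # flattened positions of the entries 4.0 and 3.0
+    [1, 2]
+    >>> mask  # in "full" mode, blocks at or above the smallest selected logit are flagged
+    tensor([[[0., 1.],
+             [1., 0.]]])
+    ```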
+    """
+    batch_size, total_blocks_per_row, _ = low_resolution_logit.shape
+
+    if initial_prior_diagonal_n_blocks > 0:
+        offset = initial_prior_diagonal_n_blocks // 2
+        temp_mask = torch.ones(total_blocks_per_row, total_blocks_per_row, device=low_resolution_logit.device)
+        diagonal_mask = torch.tril(torch.triu(temp_mask, diagonal=-offset), diagonal=offset)
+        low_resolution_logit = low_resolution_logit + diagonal_mask[None, :, :] * 5e3
+
+    if initial_prior_first_n_blocks > 0:
+        low_resolution_logit[:, :initial_prior_first_n_blocks, :] = (
+            low_resolution_logit[:, :initial_prior_first_n_blocks, :] + 5e3
+        )
+        low_resolution_logit[:, :, :initial_prior_first_n_blocks] = (
+            low_resolution_logit[:, :, :initial_prior_first_n_blocks] + 5e3
+        )
+
+    top_k_vals = torch.topk(
+        low_resolution_logit.reshape(batch_size, -1), num_blocks, dim=-1, largest=True, sorted=False
+    )
+    indices = top_k_vals.indices
+
+    if approx_mode == "full":
+        threshold = top_k_vals.values.min(dim=-1).values
+        high_resolution_mask = (low_resolution_logit >= threshold[:, None, None]).float()
+    elif approx_mode == "sparse":
+        high_resolution_mask = None
+    else:
+        raise ValueError(f"{approx_mode} is not a valid approx_mode value.")
+
+    return indices, high_resolution_mask
+
+
+def mra2_attention(
+    query,
+    key,
+    value,
+    mask,
+    num_blocks,
+    approx_mode,
+    block_size=32,
+    initial_prior_first_n_blocks=0,
+    initial_prior_diagonal_n_blocks=0,
+):
+    """
+    Use Mra to approximate self-attention.
+    """
+    if mra_cuda_kernel is None:
+        return torch.zeros_like(query).requires_grad_()
+
+    batch_size, num_head, seq_len, head_dim = query.size()
+    meta_batch = batch_size * num_head
+
+    if seq_len % block_size != 0:
+        raise ValueError("sequence length must be divisible by the block_size.")
+
+    num_block_per_row = seq_len // block_size
+
+    query = query.reshape(meta_batch, seq_len, head_dim)
+    key = key.reshape(meta_batch, seq_len, head_dim)
+    value = value.reshape(meta_batch, seq_len, head_dim)
+
+    if mask is not None:
+        query = query * mask[:, :, None]
+        key = key * mask[:, :, None]
+        value = value * mask[:, :, None]
+
+    if approx_mode == "full":
+        low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat = get_low_resolution_logit(
+            query, key, block_size, mask, value
+        )
+    elif approx_mode == "sparse":
+        with torch.no_grad():
+            low_resolution_logit, token_count, low_resolution_logit_row_max, _ = get_low_resolution_logit(
+                query, key, block_size, mask
+            )
+    else:
+        raise Exception('approx_mode must be "full" or "sparse"')
+
+    with torch.no_grad():
+        low_resolution_logit_normalized = low_resolution_logit - low_resolution_logit_row_max
+        indices, high_resolution_mask = get_block_idxes(
+            low_resolution_logit_normalized,
+            num_blocks,
+            approx_mode,
+            initial_prior_first_n_blocks,
+            initial_prior_diagonal_n_blocks,
+        )
+
+    high_resolution_logit = MraSampledDenseMatMul.operator_call(
+        query, key, indices, block_size=block_size
+    ) / math.sqrt(head_dim)
+    max_vals, max_vals_scatter = sparse_max(high_resolution_logit, indices, num_block_per_row, num_block_per_row)
+    high_resolution_logit = high_resolution_logit - max_vals_scatter
+    if mask is not None:
+        high_resolution_logit = high_resolution_logit - 1e4 * (1 - sparse_mask(mask, indices)[:, :, :, None])
+    high_resolution_attn = torch.exp(high_resolution_logit)
+    high_resolution_attn_out = MraSparseDenseMatMul.operator_call(
+        high_resolution_attn, indices, value, num_block_per_row
+    )
+    high_resolution_normalizer = MraReduceSum.operator_call(
+        high_resolution_attn,
indices, num_block_per_row, num_block_per_row + ) + + if approx_mode == "full": + low_resolution_attn = ( + torch.exp(low_resolution_logit - low_resolution_logit_row_max - 1e4 * high_resolution_mask) + * token_count[:, None, :] + ) + + low_resolution_attn_out = ( + torch.matmul(low_resolution_attn, value_hat)[:, :, None, :] + .repeat(1, 1, block_size, 1) + .reshape(meta_batch, seq_len, head_dim) + ) + low_resolution_normalizer = ( + low_resolution_attn.sum(dim=-1)[:, :, None].repeat(1, 1, block_size).reshape(meta_batch, seq_len) + ) + + log_correction = low_resolution_logit_row_max.repeat(1, 1, block_size).reshape(meta_batch, seq_len) - max_vals + if mask is not None: + log_correction = log_correction * mask + + low_resolution_corr = torch.exp(log_correction * (log_correction <= 0).float()) + low_resolution_attn_out = low_resolution_attn_out * low_resolution_corr[:, :, None] + low_resolution_normalizer = low_resolution_normalizer * low_resolution_corr + + high_resolution_corr = torch.exp(-log_correction * (log_correction > 0).float()) + high_resolution_attn_out = high_resolution_attn_out * high_resolution_corr[:, :, None] + high_resolution_normalizer = high_resolution_normalizer * high_resolution_corr + + context_layer = (high_resolution_attn_out + low_resolution_attn_out) / ( + high_resolution_normalizer[:, :, None] + low_resolution_normalizer[:, :, None] + 1e-6 + ) + + elif approx_mode == "sparse": + context_layer = high_resolution_attn_out / (high_resolution_normalizer[:, :, None] + 1e-6) + else: + raise Exception('config.approx_mode must be "full" or "sparse"') + + if mask is not None: + context_layer = context_layer * mask[:, :, None] + + context_layer = context_layer.reshape(batch_size, num_head, seq_len, head_dim) + + return context_layer + + +class MraEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), + persistent=False, + ) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if 
token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class MraSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        kernel_loaded = mra_cuda_kernel is not None
+        if is_torch_cuda_available() and is_cuda_platform() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom CUDA kernels for MRA attention: {e}")
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = (
+            position_embedding_type if position_embedding_type is not None else config.position_embedding_type
+        )
+
+        self.num_block = (config.max_position_embeddings // 32) * config.block_per_row
+        self.num_block = min(self.num_block, int((config.max_position_embeddings // 32) ** 2))
+
+        self.approx_mode = config.approx_mode
+        self.initial_prior_first_n_blocks = config.initial_prior_first_n_blocks
+        self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks
+
+    def forward(self, hidden_states, attention_mask=None):
+        batch_size, seq_len, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+
+        # revert changes made by get_extended_attention_mask
+        attention_mask = 1.0 + attention_mask / 10000.0
+        attention_mask = (
+            attention_mask.squeeze()
+            .repeat(1, self.num_attention_heads, 1)
+            .reshape(batch_size * self.num_attention_heads, seq_len)
+            .int()
+        )
+
+        # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs
+        # smaller than this are padded with zeros.
+ gpu_warp_size = 32 + + if self.attention_head_size < gpu_warp_size: + pad_size = batch_size, self.num_attention_heads, seq_len, gpu_warp_size - self.attention_head_size + + query_layer = torch.cat([query_layer, torch.zeros(pad_size, device=query_layer.device)], dim=-1) + key_layer = torch.cat([key_layer, torch.zeros(pad_size, device=key_layer.device)], dim=-1) + value_layer = torch.cat([value_layer, torch.zeros(pad_size, device=value_layer.device)], dim=-1) + + context_layer = mra2_attention( + query_layer.float(), + key_layer.float(), + value_layer.float(), + attention_mask.float(), + self.num_block, + approx_mode=self.approx_mode, + initial_prior_first_n_blocks=self.initial_prior_first_n_blocks, + initial_prior_diagonal_n_blocks=self.initial_prior_diagonal_n_blocks, + ) + + if self.attention_head_size < gpu_warp_size: + context_layer = context_layer[:, :, :, : self.attention_head_size] + + context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class MraSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class MraAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = MraSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = MraSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None): + self_outputs = self.self(hidden_states, attention_mask) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class MraIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + 
else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class MraOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class MraLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = MraAttention(config) + self.add_cross_attention = config.add_cross_attention + self.intermediate = MraIntermediate(config) + self.output = MraOutput(config) + + def forward(self, hidden_states, attention_mask=None): + self_attention_outputs = self.attention(hidden_states, attention_mask) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class MraEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([MraLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, attention_mask) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform +class MraPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = 
self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Mra +class MraLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = MraPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Mra +class MraOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MraLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@auto_docstring +# Copied from transformers.models.yoso.modeling_yoso.YosoPreTrainedModel with Yoso->Mra,yoso->mra +class MraPreTrainedModel(PreTrainedModel): + config: MraConfig + base_model_prefix = "mra" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, MraLMPredictionHead): + module.bias.data.zero_() + + +@auto_docstring +class MraModel(MraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = MraEmbeddings(config) + self.encoder = MraEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
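+        For example, calling `model._prune_heads({0: [0, 2]})` (illustrative layer/head indices) removes heads 0 and 2 of the first layer.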
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+        class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
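+        # (Editor's note, illustrative) `get_extended_attention_mask` broadcasts the [batch, seq]
+        # padding mask to [batch, 1, 1, seq] and maps 1 -> 0.0 and 0 -> a large negative value, so it
+        # can simply be added to the raw attention scores before the softmax, roughly:
+        #     extended = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(dtype).min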
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@auto_docstring
+class MraForMaskedLM(MraPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.mra = MraModel(config)
+        self.cls = MraOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked),
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
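+
+        Example (editor's illustrative sketch; assumes the `uw-madison/mra-base-512-4` checkpoint is available):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, MraForMaskedLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("uw-madison/mra-base-512-4")
+        >>> model = MraForMaskedLM.from_pretrained("uw-madison/mra-base-512-4")
+
+        >>> inputs = tokenizer("Paris is the capital of France.", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs, labels=inputs["input_ids"])
+        >>> logits, loss = outputs.logits, outputs.loss
+        ```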
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.yoso.modeling_yoso.YosoClassificationHead with Yoso->Mra +class MraClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring( + custom_intro=""" + MRA Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. + """ +) +class MraForSequenceClassification(MraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.mra = MraModel(config) + self.classifier = MraClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MraForMultipleChoice(MraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.mra = MraModel(config) + self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.mra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MraForTokenClassification(MraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mra = MraModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. 
Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.mra(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring
+class MraForQuestionAnswering(MraPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        config.num_labels = 2
+        self.num_labels = config.num_labels
+
+        self.mra = MraModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, QuestionAnsweringModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.mra(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MraForMaskedLM", + "MraForMultipleChoice", + "MraForQuestionAnswering", + "MraForSequenceClassification", + "MraForTokenClassification", + "MraLayer", + "MraModel", + "MraPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..444a8f8cc8e0203a4cd0151c6bff0644e4e3d6b6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mt5 import * + from .modeling_flax_mt5 import * + from .modeling_mt5 import * + from .modeling_tf_mt5 import * + from .tokenization_mt5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/configuration_mt5.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/configuration_mt5.py new file mode 100644 index 0000000000000000000000000000000000000000..bad66e63b1260587828ebca4822fc57aa635175a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/configuration_mt5.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""mT5 model configuration""" + +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MT5Model`] or a [`TFMT5Model`]. It is used to + instantiate a mT5 model according to the specified arguments, defining the model architecture. 
Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the mT5
+    [google/mt5-small](https://huggingface.co/google/mt5-small) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        vocab_size (`int`, *optional*, defaults to 250112):
+            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
+        d_model (`int`, *optional*, defaults to 512):
+            Size of the encoder layers and the pooler layer.
+        d_kv (`int`, *optional*, defaults to 64):
+            Size of the key, query, value projections per attention head. Conventionally `d_kv` is expected to equal
+            `d_model // num_heads`, but in the mt5-small architecture it does not. The `inner_dim` of the projection
+            layer is always defined as `num_heads * d_kv`.
+        d_ff (`int`, *optional*, defaults to 1024):
+            Size of the intermediate feed forward layer in each `T5Block`.
+        num_layers (`int`, *optional*, defaults to 8):
+            Number of hidden layers in the Transformer encoder.
+        num_decoder_layers (`int`, *optional*):
+            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+        num_heads (`int`, *optional*, defaults to 6):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum distance of the longer sequences for the bucket separation.
+        dropout_rate (`float`, *optional*, defaults to 0.1):
+            The ratio for all dropout layers.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
+            testing).
+        feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
+            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
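+
+    Example (editor's illustrative sketch):
+
+    ```python
+    >>> from transformers import MT5Config, MT5Model
+
+    >>> # Initializing a configuration with the defaults (google/mt5-small scale)
+    >>> configuration = MT5Config()
+    >>> # Initializing a randomly weighted model from that configuration
+    >>> model = MT5Model(configuration)
+    >>> configuration = model.config
+    ```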
+ """ + + model_type = "mt5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } + + def __init__( + self, + vocab_size=250112, + d_model=512, + d_kv=64, + d_ff=1024, + num_layers=8, + num_decoder_layers=None, + num_heads=6, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", + is_encoder_decoder=True, + use_cache=True, + tokenizer_class="T5Tokenizer", + tie_word_embeddings=False, + pad_token_id=0, + eos_token_id=1, + decoder_start_token_id=0, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + is_encoder_decoder=is_encoder_decoder, + tokenizer_class=tokenizer_class, + tie_word_embeddings=tie_word_embeddings, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + # Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.inputs + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + # Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.default_onnx_opset + def default_onnx_opset(self) -> int: + return 13 + + @property + def atol_for_validation(self) -> float: + return 5e-4 + + +__all__ = ["MT5Config", "MT5OnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_flax_mt5.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_flax_mt5.py new file mode 100644 index 0000000000000000000000000000000000000000..13bd83b75034ba6de1260a0bd754f289979e4fa7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_flax_mt5.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2021 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax mT5 model.""" + +import jax.numpy as jnp + +from ...utils import logging +from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model +from .configuration_mt5 import MT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxMT5Model(FlaxT5Model): + r""" + This class overrides [`FlaxT5Model`]. 
Please check the superclass for the appropriate documentation alongside usage
+    examples.
+
+    Examples:
+
+    ```python
+    >>> from transformers import FlaxMT5Model, AutoTokenizer
+
+    >>> model = FlaxMT5Model.from_pretrained("google/mt5-small")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+
+    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
+    >>> summary = "Weiter Verhandlung in Syrien."
+    >>> inputs = tokenizer(article, return_tensors="np")
+
+    >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
+
+    >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
+    >>> hidden_states = outputs.last_hidden_state
+    ```"""
+
+    model_type = "mt5"
+    config_class = MT5Config
+
+
+class FlaxMT5EncoderModel(FlaxT5EncoderModel):
+    r"""
+    This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
+    alongside usage examples.
+
+    Examples:
+
+    ```python
+    >>> from transformers import FlaxMT5EncoderModel, AutoTokenizer
+
+    >>> model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+
+    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
+    >>> inputs = tokenizer(article, return_tensors="np")
+
+    >>> outputs = model(input_ids=inputs["input_ids"])
+    >>> hidden_states = outputs.last_hidden_state
+    ```"""
+
+    model_type = "mt5"
+    config_class = MT5Config
+
+
+class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
+    r"""
+    This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate
+    documentation alongside usage examples.
+
+    Examples:
+
+    ```python
+    >>> from transformers import FlaxMT5ForConditionalGeneration, AutoTokenizer
+
+    >>> model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+
+    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
+    >>> summary = "Weiter Verhandlung in Syrien."
+    >>> inputs = tokenizer(article, return_tensors="np")
+
+    >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
+
+    >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
+    >>> logits = outputs.logits
+    ```"""
+
+    model_type = "mt5"
+    config_class = MT5Config
+
+
+__all__ = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_mt5.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_mt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e57d0aadda2a6793bc955e730732deb97c35a11
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_mt5.py
@@ -0,0 +1,2453 @@
+# coding=utf-8
+# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch mT5 model."""
+
+import copy
+import math
+import os
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    DUMMY_INPUTS,
+    DUMMY_MASK,
+    add_start_docstrings,
+    auto_docstring,
+    is_torch_flex_attn_available,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+    logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.model_parallel_utils import assert_device_map, get_device_map
+from .configuration_mt5 import MT5Config
+
+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+logger = logging.get_logger(__name__)
+
+
+####################################################
+# This dict contains ids and associated url
+# for the pretrained weights provided with the models
+####################################################
+
+PARALLELIZE_DOCSTRING = r"""
+    This is an experimental feature and is subject to change at a moment's notice.
+
+    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
+    it will evenly distribute blocks across all devices.
+
+    Args:
+        device_map (`dict[int, list]`, *optional*):
+            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
+            automatically mapped to the first device (for esoteric reasons). That means that the first device should
+            have fewer attention modules mapped to it than other devices. For reference, the mt5 models have the
+            following number of attention modules:
+
+                - mt5-small: 6
+                - mt5-base: 12
+                - mt5-large: 24
+                - mt5-xl: 24
+                - mt5-xxl: 24
+
+    Example:
+
+    ```python
+    # Here is an example of a device map on a machine with 4 GPUs using mt5-xl, which has a total of 24 attention modules:
+    model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
+    device_map = {
+        0: [0, 1, 2],
+        1: [3, 4, 5, 6, 7, 8, 9],
+        2: [10, 11, 12, 13, 14, 15, 16],
+        3: [17, 18, 19, 20, 21, 22, 23],
+    }
+    model.parallelize(device_map)
+    ```
+"""
+DEPARALLELIZE_DOCSTRING = r"""
+    Moves the model to cpu from a model parallel state.
+
+    Example:
+
+    ```python
+    # On a 4 GPU machine with mt5-xl:
+    model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
+    device_map = {
+        0: [0, 1, 2],
+        1: [3, 4, 5, 6, 7, 8, 9],
+        2: [10, 11, 12, 13, 14, 15, 16],
+        3: [17, 18, 19, 20, 21, 22, 23],
+    }
+    model.parallelize(device_map)  # Splits the model across several devices
+    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
+    ```
+"""
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5
+class MT5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the MT5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # MT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+        # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
+        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+        # half-precision inputs is done in fp32
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->MT5
+class MT5DenseActDense(nn.Module):
+    def __init__(self, config: MT5Config):
+        super().__init__()
+        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
+        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = ACT2FN[config.dense_act_fn]
+
+    def forward(self, hidden_states):
+        hidden_states = self.wi(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->MT5
+class MT5DenseGatedActDense(nn.Module):
+    def __init__(self, config: MT5Config):
+        super().__init__()
+        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
+        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
+        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = ACT2FN[config.dense_act_fn]
+
+    def forward(self, hidden_states):
+        hidden_gelu = self.act(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states = self.dropout(hidden_states)
+
+        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->MT5
+class MT5LayerFF(nn.Module):
+    def __init__(self, config: MT5Config):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = MT5DenseGatedActDense(config)
+        else:
+            self.DenseReluDense = MT5DenseActDense(config)
+
+        self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5
+class MT5Attention(nn.Module):
+    def __init__(
+        self,
+        config: MT5Config,
+        has_relative_attention_bias=False,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None and self.is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+ ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
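+        For example (illustrative values): with bidirectional=True, num_buckets=32 and max_distance=128, each
+        direction gets 16 buckets, so a relative position of +3 falls in an exact bucket while +100 falls in a
+        logarithmic one.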
This should allow for more graceful generalization to longer sequences than the model has been trained on.
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
+        """Compute binned relative position bias"""
+        if device is None:
+            device = self.relative_attention_bias.weight.device
+        if cache_position is None:
+            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        else:
+            context_position = cache_position[:, None].to(device)
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states,
+        mask=None,
+        key_value_states=None,
+        position_bias=None,
+        past_key_values=None,
+        layer_head_mask=None,
+        query_length=None,
+        use_cache=False,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        """
+        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+        """
+        # Input is (batch_size, seq_length, dim)
+        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+        is_cross_attention = key_value_states is not None
+
+        query_states = self.q(hidden_states)
+        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+        # Check if an encoder-decoder model is being used.
Otherwise we'll get `DynamicCache` + is_updated = False + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + 
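+        # (Editor's note) outputs is (attn_output, position_bias[, attn_weights]); position_bias is
+        # returned so the other blocks in the stack can reuse it instead of recomputing the bias.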
return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 +class MT5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = MT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 +class MT5LayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 +class MT5Block(GradientCheckpointingLayer): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) + if self.is_decoder: + self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx)) + + self.layer.append(MT5LayerFF(config)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + 
layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->MT5 +class MT5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: MT5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + 
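+        # forward() below chains dropout -> dense -> tanh -> dropout -> out_proj,
+        # mapping d_model features to num_labels logits.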
self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring +# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->MT5, t5->mt5 +class MT5PreTrainedModel(PreTrainedModel): + config: MT5Config + load_tf_weights = load_tf_weights_in_mt5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _can_compile_fullgraph = True + + _no_split_modules = ["MT5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, MT5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, MT5ForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, MT5ClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, MT5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, MT5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + 
module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, MT5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id. " + "See MT5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5 +class MT5Stack(MT5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] + ) + self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`MT5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + else: + past_key_values = DynamicCache(config=self.config) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = 
position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch 
on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
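+            # (A toy sketch of what the `else` branch below builds instead, assuming
+            # sequence_length=3, target_length=4, cache_position=[0, 1, 2] and no 2D
+            # padding mask, with m = torch.finfo(dtype).min:
+            #
+            #     [[0, m, m, m],
+            #      [0, 0, m, m],
+            #      [0, 0, 0, m]]
+            #
+            # i.e. each query position attends only to itself and earlier positions.)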
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@auto_docstring +class MT5Model(MT5PreTrainedModel): + r""" + Examples: + + ```python + >>> from transformers import MT5Model, AutoTokenizer + + >>> model = MT5Model.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> inputs = tokenizer(article, return_tensors="pt") + >>> labels = tokenizer(text_target=summary, return_tensors="pt") + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) + >>> hidden_states = outputs.last_hidden_state + ```""" + + model_type = "mt5" + config: MT5Config + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = MT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = MT5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5Model.parallelize + def parallelize(self, device_map=None): + warnings.warn( + "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + # Copied from transformers.models.t5.modeling_t5.T5Model.forward with google-t5/->google/, T5->MT5, t5->mt5 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training). 
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5 + Training](./mt5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MT5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> model = MT5Model.from_pretrained("google/mt5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model. + >>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg. 
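+        >>> # _shift_right prepends config.decoder_start_token_id (the pad token for MT5)
+        >>> # and drops the last position, e.g. [a, b, c] -> [pad, a, b]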
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + MT5 Model with a `language modeling` head on top. + """ +) +class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin): + r""" + Examples: + + ```python + >>> from transformers import MT5ForConditionalGeneration, AutoTokenizer + + >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." 
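+    >>> # text_target fills inputs["labels"]; forward() then derives decoder_input_ids
+    >>> # internally by shifting the labels to the right with _shift_right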
+ >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> loss = outputs.loss + ```""" + + model_type = "mt5" + config: MT5Config + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = MT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = MT5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize + def parallelize(self, device_map=None): + warnings.warn( + "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder + def get_encoder(self): + return self.encoder + + @auto_docstring + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with google-t5/->google/, T5->MT5, t5->mt5 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + 
decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
+            should be able to pad the inputs on both the right and the left.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
+            Training](./mt5#training).
+        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in
+            `[-100, 0, ..., config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked); the loss is only computed for labels in
+            `[0, ..., config.vocab_size - 1]`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MT5ForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+        >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
+
+        >>> # training
+        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
+        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
+        >>> outputs = model(input_ids=input_ids, labels=labels)
+        >>> loss = outputs.loss
+        >>> logits = outputs.logits
+
+        >>> # inference
+        >>> input_ids = tokenizer(
+        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
+        ... ).input_ids  # Batch size 1
+        >>> outputs = model.generate(input_ids)
+        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+        >>> # studies have shown that owning a dog is good for you.
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            # Convert encoder inputs into embeddings if needed
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+            hidden_states = hidden_states.to(self.decoder.first_device)
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(self.decoder.first_device)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        sequence_output =
decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class MT5EncoderModel(MT5PreTrainedModel): + r""" + Examples: + + ```python + >>> from transformers import MT5EncoderModel, AutoTokenizer + + >>> model = MT5EncoderModel.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> input_ids = tokenizer(article, return_tensors="pt").input_ids + >>> outputs = model(input_ids) + >>> hidden_state = outputs.last_hidden_state + ```""" + + model_type = "mt5" + config: MT5Config + _tied_weights_keys = ["encoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = config + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = MT5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.parallelize + def parallelize(self, device_map=None): + warnings.warn( + "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.deparallelize + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @auto_docstring + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with google-t5/->google/, T5->MT5, t5->mt5 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training). + + Example: + + ```python + >>> from transformers import AutoTokenizer, MT5EncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> model = MT5EncoderModel.from_pretrained("google/mt5-small") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... 
).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs + + +@auto_docstring( + custom_intro=""" + MT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """ +) +class MT5ForSequenceClassification(MT5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.transformer = MT5Model(config) + self.classification_head = MT5ClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @auto_docstring + # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.forward with T5->MT5, t5->mt5 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5 + Training](./mt5#training). 
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # Copied from models.bart.modeling_bart.BartModel.forward different to other models, MT5 automatically creates + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + batch_size, _, hidden_size = sequence_output.shape + sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring +class MT5ForTokenClassification(MT5PreTrainedModel): + _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = MT5EncoderModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5 + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = 
None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
+            should be able to pad the inputs on both the right and the left.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for detail.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            # prepend the logits to the remaining outputs so the returned tuple stays flat
+            output = (logits,) + outputs[2:-1]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring
+class MT5ForQuestionAnswering(MT5PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5
+    def __init__(self, config: MT5Config):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.tie_encoder_decoder = False
+        self.encoder = MT5Stack(encoder_config, self.shared)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.tie_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = MT5Stack(decoder_config, self.shared)
+
+        self.num_labels = config.num_labels
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        self.model_parallel = False
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
+    def get_input_embeddings(self):
+        return self.shared
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
+    def get_encoder(self):
+        return self.encoder
+
+    @auto_docstring
+    # Copied from 
transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # Copied from models.bart.modeling_bart.BartModel.forward + # different to other models, T5 automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + 
decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +__all__ = [ + "MT5EncoderModel", + "MT5ForConditionalGeneration", + "MT5ForQuestionAnswering", + "MT5ForSequenceClassification", + "MT5ForTokenClassification", + "MT5Model", + "MT5PreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_tf_mt5.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_tf_mt5.py new file mode 100644 index 0000000000000000000000000000000000000000..6152aea0a5acad5d1d6db9b7ee044cbde2bca976 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/modeling_tf_mt5.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tensorflow mT5 model.""" + +from ...utils import logging +from ..t5.modeling_tf_t5 import TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model +from .configuration_mt5 import MT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" + + +class TFMT5Model(TFT5Model): + r""" + This class overrides [`TFT5Model`]. Please check the superclass for the appropriate documentation alongside usage + examples. + + Examples: + + ```python + >>> from transformers import TFMT5Model, AutoTokenizer + + >>> model = TFMT5Model.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> inputs = tokenizer(article, return_tensors="tf") + >>> labels = tokenizer(text_target=summary, return_tensors="tf") + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) + >>> hidden_states = outputs.last_hidden_state + ```""" + + model_type = "mt5" + config_class = MT5Config + + +class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration): + r""" + This class overrides [`TFT5ForConditionalGeneration`]. Please check the superclass for the appropriate + documentation alongside usage examples. + + Examples: + + ```python + >>> from transformers import TFMT5ForConditionalGeneration, AutoTokenizer + + >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." 
+ >>> summary = "Weiter Verhandlung in Syrien." + >>> inputs = tokenizer(article, text_target=summary, return_tensors="tf") + + >>> outputs = model(**inputs) + >>> loss = outputs.loss + ```""" + + model_type = "mt5" + config_class = MT5Config + + +class TFMT5EncoderModel(TFT5EncoderModel): + r""" + This class overrides [`TFT5EncoderModel`]. Please check the superclass for the appropriate documentation alongside + usage examples. + + Examples: + + ```python + >>> from transformers import TFMT5EncoderModel, AutoTokenizer + + >>> model = TFMT5EncoderModel.from_pretrained("google/mt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> input_ids = tokenizer(article, return_tensors="tf").input_ids + >>> outputs = model(input_ids) + >>> hidden_state = outputs.last_hidden_state + ```""" + + model_type = "mt5" + config_class = MT5Config + + +__all__ = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5.py new file mode 100644 index 0000000000000000000000000000000000000000..a3058816ff2032c41e2b0cdb6940705a7572faa0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5.py @@ -0,0 +1,24 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""mT5 tokenization file""" + +from ..t5 import T5Tokenizer + + +class MT5Tokenizer(T5Tokenizer): + pass + + +__all__ = ["MT5Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..8737088cc44206bea1e2f9d1793ae49692ae35e8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mt5/tokenization_mt5_fast.py @@ -0,0 +1,24 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
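The two tokenization modules added in this patch are empty subclasses, so `MT5Tokenizer` above and `MT5TokenizerFast` just below behave exactly like `T5Tokenizer` and `T5TokenizerFast`. A minimal sanity-check sketch (assuming network access to the `google/mt5-small` checkpoint and a local `sentencepiece` install), confirming that the alias exposes the same T5-style sentinel tokens used in the span-corruption example earlier:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")

# Sentinel tokens used for span corruption resolve to real vocabulary ids.
print(tokenizer.convert_tokens_to_ids("<extra_id_0>"))

# Round-trip a sentence: mT5 shares one SentencePiece vocabulary across languages.
ids = tokenizer("UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.").input_ids
print(tokenizer.decode(ids, skip_special_tokens=True))
```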
+"""mT5 tokenization file""" + +from ..t5 import T5TokenizerFast + + +class MT5TokenizerFast(T5TokenizerFast): + pass + + +__all__ = ["MT5TokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mvp/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/mvp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..beab37f65c1a8daa20e27a816db5e832c49f6fa0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mvp/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mvp import * + from .modeling_mvp import * + from .tokenization_mvp import * + from .tokenization_mvp_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/mvp/configuration_mvp.py b/venv/lib/python3.13/site-packages/transformers/models/mvp/configuration_mvp.py new file mode 100644 index 0000000000000000000000000000000000000000..c216e53ed81ab350e7648859ced14d2dbcda8dca --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mvp/configuration_mvp.py @@ -0,0 +1,183 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MVP model configuration""" + +import warnings + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MvpConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MvpModel`]. It is used to instantiate a MVP model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MVP [RUCAIBox/mvp](https://huggingface.co/RUCAIBox/mvp) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50267): + Vocabulary size of the MVP model. 
Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`MvpModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+        use_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not to use layer-wise prompts.
+        prompt_length (`int`, *optional*, defaults to 100):
+            The length of the prompt.
+        prompt_mid_dim (`int`, *optional*, defaults to 800):
+            Dimensionality of the "intermediate" layer in the prompt MLP.
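Before the configuration example that follows, a hedged sketch of how the three prompt knobs above fit together: with `use_prompt=True`, each layer appears to receive a `(2, num_heads, prompt_length, head_dim)` key/value prompt produced by an MLP of hidden width `prompt_mid_dim` (see `MvpPrompt` in `modeling_mvp.py` later in this patch). This instantiates a prompt-enabled config with random weights; the values simply mirror the documented defaults.

```python
from transformers import MvpConfig, MvpModel

# Prompt-enabled MVP configuration (random weights, not a pretrained checkpoint).
config = MvpConfig(use_prompt=True, prompt_length=100, prompt_mid_dim=800)
model = MvpModel(config)

# attribute_map resolves the generic names to the MVP-specific ones
# (hidden_size -> d_model=1024, num_attention_heads -> encoder_attention_heads=16).
print(config.hidden_size, config.num_attention_heads)
```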
+ Example: + + ```python + >>> from transformers import MvpConfig, MvpModel + + >>> # Initializing a MVP RUCAIBox/mvp style configuration + >>> configuration = MvpConfig() + + >>> # Initializing a model (with random weights) from the RUCAIBox/mvp style configuration + >>> model = MvpModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mvp" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=50267, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + is_encoder_decoder=True, + decoder_start_token_id=2, + forced_eos_token_id=2, + use_prompt=False, + prompt_length=100, + prompt_mid_dim=800, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.use_prompt = use_prompt + self.prompt_length = prompt_length + self.prompt_mid_dim = prompt_mid_dim + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " + "The config can simply be saved and uploaded again to be fixed." + ) + + +__all__ = ["MvpConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mvp/modeling_mvp.py b/venv/lib/python3.13/site-packages/transformers/models/mvp/modeling_mvp.py new file mode 100644 index 0000000000000000000000000000000000000000..6838f209cb4e51dc9fbf0679a20c85864809e8a4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mvp/modeling_mvp.py @@ -0,0 +1,1786 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MVP model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_mvp import MvpConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->Mvp +class MvpLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Mvp is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) + + return super().forward(position_ids + self.offset) + + +class MvpAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[bool] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + attn_prompt: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + 
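+        # At this point `key_states` / `value_states` have shape
+        # (bsz, num_heads, src_len, head_dim): freshly projected above, or
+        # re-read from the cross-attention cache on later decoding steps.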
+        if past_key_values is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = curr_past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+            if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                past_key_values.is_updated[self.layer_idx] = True
+
+        if attn_prompt is not None:
+            key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2)
+            value_states = torch.cat([attn_prompt[1].expand(bsz, -1, -1, -1), value_states], dim=2)
+            if attention_mask is not None:
+                prompt_mask = torch.zeros(bsz, 1, tgt_len, attn_prompt[0].size(1)).to(attention_mask.device)
+                attention_mask = torch.cat([prompt_mask, attention_mask], dim=-1)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.reshape(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class MvpEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MvpConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = MvpAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + self_attn_prompt: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape + `(2, encoder_attention_heads, pro_len, head_dim)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + attn_prompt=self_attn_prompt, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, attn_weights + + +class MvpDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MvpConfig, layer_idx=None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = MvpAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = MvpAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + self_attn_prompt: Optional[torch.Tensor] = None, + cross_attn_prompt: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape + `(2, decoder_attention_heads, pro_len, head_dim)`. + cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape + `(2, decoder_attention_heads, pro_len, head_dim)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + attn_prompt=self_attn_prompt, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + attn_prompt=cross_attn_prompt, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MVP +class MvpClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class MvpPrompt(nn.Module): + """Layer-wise prompt for encoder or decoder.""" + + def __init__(self, config, num_layers, num_heads): + super().__init__() + self.prompt_length = config.prompt_length + 
self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = config.d_model // num_heads + self.dropout = nn.Dropout(p=config.dropout) + self.prompt_embedding = nn.Embedding(config.prompt_length, config.d_model) + self.prompt_trans = nn.Sequential( + nn.Linear(config.d_model, config.prompt_mid_dim), + nn.GELU(), + nn.Linear(config.prompt_mid_dim, num_layers * 2 * config.d_model), + ) + + def forward(self, prompt_ids: torch.Tensor) -> tuple[torch.Tensor]: + prompt = self.prompt_trans(self.prompt_embedding(prompt_ids)) + prompt = prompt.view(self.prompt_length, self.num_layers * 2, self.num_heads, self.head_dim) + prompt = self.dropout(prompt) + prompt = prompt.permute([1, 2, 0, 3]).split(2) + return prompt + + +@auto_docstring +class MvpPreTrainedModel(PreTrainedModel): + config: MvpConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class MvpEncoder(MvpPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`MvpEncoderLayer`]. + + Args: + config: MvpConfig + embed_tokens (nn.Embedding): output embedding + use_prompt (bool): whether to use prompt + """ + + def __init__( + self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False + ): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = MvpLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([MvpEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.use_prompt = use_prompt + if use_prompt: + self.prompt_length = config.prompt_length + self.self_attn_prompt = MvpPrompt( + config, + config.encoder_layers, + config.encoder_attention_heads, + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the 
vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # layer-wise prompt + if self.use_prompt: + prompt_ids = torch.arange(self.prompt_length).to(self.device) + self_attn_prompt = self.self_attn_prompt(prompt_ids) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be 
specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class MvpDecoder(MvpPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MvpDecoderLayer`] + + Args: + config: MvpConfig + embed_tokens (nn.Embedding): output embedding + use_prompt (bool): whether to use prompt + """ + + def __init__( + self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = MvpLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([MvpDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.use_prompt = use_prompt + if use_prompt: + self.prompt_length = config.prompt_length + self.self_attn_prompt = MvpPrompt( + config, + config.decoder_layers, + config.decoder_attention_heads, + ) + self.cross_attn_prompt = MvpPrompt( + config, + config.decoder_layers, + config.decoder_attention_heads, + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape 
`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
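+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
+                decoding (see `past_key_values`).
+            cache_position (`torch.Tensor`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence, used to update the
+                cache in the correct position.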
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # layer-wise prompt
+        if self.use_prompt:
+            prompt_ids = torch.arange(self.prompt_length).to(self.device)
+            self_attn_prompt = self.self_attn_prompt(prompt_ids)
+            cross_attn_prompt = self.cross_attn_prompt(prompt_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states,  # as positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
+                cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None),
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
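+
+
+# Illustrative note (not part of the upstream file): a minimal sketch of how the
+# layer-wise prompts above are shaped and consumed. `MvpPrompt(config, L, H)` maps
+# `torch.arange(prompt_length)` to a tuple of L tensors, one per layer, each of
+# shape `(2, H, prompt_length, head_dim)`; index 0 along the first axis is the key
+# prompt and index 1 the value prompt, which `self_attn_prompt[idx]` hands to layer `idx`:
+#
+#     head_dim = config.d_model // config.decoder_attention_heads
+#     prompt = MvpPrompt(config, config.decoder_layers, config.decoder_attention_heads)
+#     per_layer = prompt(torch.arange(config.prompt_length))
+#     assert len(per_layer) == config.decoder_layers
+#     assert per_layer[0].shape == (2, config.decoder_attention_heads, config.prompt_length, head_dim)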
["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: MvpConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.use_prompt = config.use_prompt + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = MvpEncoder(config, self.shared, config.use_prompt) + self.decoder = MvpDecoder(config, self.shared, config.use_prompt) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def set_lightweight_tuning(self): + assert self.use_prompt, "If you want to use lightweight tuning, make sure that `use_prompt=True`." + + self.requires_grad_(False) + self.encoder.self_attn_prompt.requires_grad_(True) + self.decoder.self_attn_prompt.requires_grad_(True) + self.decoder.cross_attn_prompt.requires_grad_(True) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. 
Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + # different to other models, Mvp automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The MVP Model with a language modeling head. Can be used for various text generation tasks. 
+ """ +) +class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: MvpConfig): + super().__init__(config) + self.model = MvpModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_num_tokens) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def set_lightweight_tuning(self): + self.model.set_lightweight_tuning() + self.lm_head.requires_grad_(False) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. 
+
+            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
+            information on the default strategy.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example of summarization:
+
+        Fine-tuning a model
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, MvpForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
+        >>> model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp")
+
+        >>> inputs = tokenizer(
+        ...     "Summarize: You may want to stick it to your boss and leave your job, but don't do it if these are your reasons.",
+        ...     return_tensors="pt",
+        ... )
+        >>> labels = tokenizer("Bad Reasons To Quit Your Job", return_tensors="pt")["input_ids"]
+
+        >>> loss = model(**inputs, labels=labels).loss
+        >>> loss.backward()
+        ```
+
+        Inference after fine-tuning the model
+        ```python
+        >>> with torch.no_grad():
+        ...     generated_ids = model.generate(**inputs)
+
+        >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
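+
+
+# Illustrative note (not part of the upstream file): a minimal sketch of the
+# prompt-only "lightweight tuning" flow exposed by `set_lightweight_tuning` above.
+# With `use_prompt=True`, it freezes every weight except the layer-wise prompts
+# (the LM head is frozen too), so only `*attn_prompt*` parameters receive gradients:
+#
+#     model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp", use_prompt=True)
+#     model.set_lightweight_tuning()
+#     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+#     assert all("attn_prompt" in n for n in trainable)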
+
+
+@auto_docstring(
+    custom_intro="""
+    Mvp model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
+    tasks.
+    """
+)
+class MvpForSequenceClassification(MvpPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: MvpConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.model = MvpModel(config)
+        self.classification_head = MvpClassificationHead(
+            config.d_model,
+            config.d_model,
+            config.num_labels,
+            config.classifier_dropout,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_lightweight_tuning(self):
+        self.model.set_lightweight_tuning()
+        self.classification_head.requires_grad_(False)
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[list[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+            for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
+            information on the default strategy.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Example of single-label classification:
+
+        Fine-tuning a model on `num_labels` classes
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, MvpForSequenceClassification
+
+        >>> num_labels = 2  # for example, this is a binary classification task
+        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
+        >>> model = MvpForSequenceClassification.from_pretrained("RUCAIBox/mvp", num_labels=num_labels)
+
+        >>> inputs = tokenizer("Classify: Hello, my dog is cute", return_tensors="pt")
+        >>> labels = torch.tensor(1)  # the real label for inputs
+
+        >>> loss = model(**inputs, labels=labels).loss
+        >>> loss.backward()
+        ```
+
+        Inference after fine-tuning the model
+        ```python
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> predicted_class_id = logits.argmax()
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+            raise ValueError("All examples must have the same number of <eos> tokens.")
+        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+            :, -1, :
+        ]
+        logits = self.classification_head(sentence_representation)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            
decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring +class MvpForQuestionAnswering(MvpPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = MvpModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def set_lightweight_tuning(self): + self.model.set_lightweight_tuning() + self.qa_outputs.requires_grad_(False) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
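+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification
+            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification
+            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.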
+
+        Example:
+
+        Fine-tuning a model for extractive question answering; MVP also supports generative question answering
+        using [`MvpForConditionalGeneration`]
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, MvpForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
+        >>> model = MvpForQuestionAnswering.from_pretrained("RUCAIBox/mvp")
+
+        >>> inputs = tokenizer(
+        ...     "Answer the following question: Who was Jim Henson? [SEP] Jim Henson was a nice puppet",
+        ...     return_tensors="pt",
+        ... )
+        >>> target_start_index = torch.tensor([18])
+        >>> target_end_index = torch.tensor([19])
+
+        >>> loss = model(**inputs, start_positions=target_start_index, end_positions=target_end_index).loss
+        >>> loss.backward()
+        ```
+
+        Inference after fine-tuning the model
+        ```python
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> answer_start_index = outputs.start_logits.argmax()
+        >>> answer_end_index = outputs.end_logits.argmax()
+
+        >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+        >>> predict_answer = tokenizer.decode(predict_answer_tokens)
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (
+                start_logits,
+                end_logits,
+            ) + outputs[1:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return Seq2SeqQuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+# Copied from
transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Mvp +class MvpDecoderWrapper(MvpPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = MvpDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = MvpDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + def set_lightweight_tuning(self): + self.model.set_lightweight_tuning() + self.lm_head.requires_grad_(False) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
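+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+            (see `past_key_values`).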
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MvpForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp") + >>> model = MvpForCausalLM.from_pretrained("RUCAIBox/mvp", add_cross_attention=False) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 8, 50267] + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "MvpForCausalLM", + "MvpForConditionalGeneration", + "MvpForQuestionAnswering", + "MvpForSequenceClassification", + "MvpModel", + "MvpPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp.py b/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp.py new file mode 100644 index 0000000000000000000000000000000000000000..f6039df2dc02564378f6f83ef1770ded9000d31f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import os
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
+
+# See all MVP models at https://huggingface.co/models?filter=mvp
+
+
+@lru_cache
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+    whitespace/control characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your
+    vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K
+    for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want
+    lookup tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return the set of symbol pairs in a word.
+
+    Word is represented as a tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
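+
+
+# Illustrative note (not part of the upstream file): `bytes_to_unicode` yields a
+# reversible map from every byte value to a printable unicode character, which is
+# what lets the byte-level BPE below represent arbitrary text without unknown tokens:
+#
+#     b2u = bytes_to_unicode()
+#     assert len(b2u) == 256
+#     assert b2u[ord("A")] == "A"       # printable bytes map to themselves
+#     assert b2u[ord(" ")] == "\u0120"  # space becomes "Ġ", the leading-space marker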
+
+
+class MvpTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a MVP tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import MvpTokenizer
+
+    >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of
+            sequence. The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+            other word. (The MVP tokenizer detects the beginning of words by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e. include the space before it
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + 
VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. An MVP sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. MVP does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+            text = " " + text
+        return (text, kwargs)
+
+
+__all__ = ["MvpTokenizer"]
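Taken together, `bytes_to_unicode`, `bpe`, and `build_inputs_with_special_tokens` give the usual encode path. A minimal sketch of that flow, assuming the `RUCAIBox/mvp` checkpoint referenced in the docstring above is available:

```python
from transformers import MvpTokenizer

tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")

# Byte-level BPE: "Ġ" marks the leading space inside tokens
tokens = tokenizer.tokenize(" Hello world")  # e.g. ['ĠHello', 'Ġworld']
ids = tokenizer.convert_tokens_to_ids(tokens)

# <s> ... </s> wrapping, as described in build_inputs_with_special_tokens
ids_with_specials = tokenizer.build_inputs_with_special_tokens(ids)
print(ids_with_specials)  # [0, 20920, 232, 2], matching the docstring example
```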
diff --git a/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp_fast.py b/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0bc6b165f7687496999943d8e4293bd6458396
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/mvp/tokenization_mvp_fast.py
@@ -0,0 +1,274 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+
+from tokenizers import processors
+
+from ...tokenization_utils_base import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_mvp import MvpTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+# See all MVP models at https://huggingface.co/models?filter=mvp
+
+
+class MvpTokenizerFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a "fast" MVP tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
+    using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import MvpTokenizerFast
+
+    >>> tokenizer = MvpTokenizerFast.from_pretrained("RUCAIBox/mvp")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of
+            sequence. The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+            other word. (The MVP tokenizer detects the beginning of words by the preceding space.)
+        trim_offsets (`bool`, *optional*, defaults to `True`):
+            Whether the post processing step should trim offsets to avoid including whitespaces.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = MvpTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        trim_offsets=True,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e. includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            trim_offsets=trim_offsets,
+            **kwargs,
+        )
+
+        # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
+        tokenizer_component = "post_processor"
+        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+        if tokenizer_component_instance:
+            state = json.loads(tokenizer_component_instance.__getstate__())
+
+            # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+            if "sep" in state:
+                state["sep"] = tuple(state["sep"])
+            if "cls" in state:
+                state["cls"] = tuple(state["cls"])
+
+            changes_to_apply = False
+
+            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+                state["add_prefix_space"] = add_prefix_space
+                changes_to_apply = True
+
+            if state.get("trim_offsets", trim_offsets) != trim_offsets:
+                state["trim_offsets"] = trim_offsets
+                changes_to_apply = True
+
+            if changes_to_apply:
+                component_class = getattr(processors, state.pop("type"))
+                new_value = component_class(**state)
+                setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+    @property
+    def mask_token(self) -> str:
+        """
+        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
+        having been set.
+
+        MVP tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
+        comprise the space before the *<mask>*.
+        """
+        if self._mask_token is None:
+            if self.verbose:
+                logger.error("Using mask_token, but it is not set yet.")
+            return None
+        return str(self._mask_token)
+
+    @mask_token.setter
+    def mask_token(self, value):
+        """
+        Overriding the default behavior of the mask token to have it eat the space before it.
+
+        This is needed to preserve backward compatibility with all the previously used models based on Mvp.
+        """
+        # Mask token behaves like a normal word, i.e. includes the space before it,
+        # so we set lstrip to True
+        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
+        self._mask_token = value
+
+    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+
+        if is_split_into_words and not self.add_prefix_space:
+            raise ValueError(
+                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+                "to use it with pretokenized inputs."
+            )
+
+        return super()._batch_encode_plus(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+
+        if is_split_into_words and not self.add_prefix_space:
+            raise ValueError(
+                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+                "to use it with pretokenized inputs."
+            )
+
+        return super()._encode_plus(*args, **kwargs)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+        return tuple(files)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return output
+
+        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. MVP does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+
+__all__ = ["MvpTokenizerFast"]
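The `_batch_encode_plus`/`_encode_plus` guards above are why pretokenized input requires `add_prefix_space=True`. A short sketch of both sides of that check, again assuming the `RUCAIBox/mvp` checkpoint is available:

```python
from transformers import MvpTokenizerFast

words = ["Hello", "world"]

# The default instantiation refuses pretokenized input outright
tokenizer = MvpTokenizerFast.from_pretrained("RUCAIBox/mvp")
try:
    tokenizer(words, is_split_into_words=True)
except ValueError as err:
    print(err)  # asks for add_prefix_space=True

# With add_prefix_space=True every word gets a leading space, so each
# word is encoded as if it appeared mid-sentence
tokenizer = MvpTokenizerFast.from_pretrained("RUCAIBox/mvp", add_prefix_space=True)
print(tokenizer(words, is_split_into_words=True)["input_ids"])
```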
diff --git a/venv/lib/python3.13/site-packages/transformers/models/nemotron/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/nemotron/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd66ddd622e067fc6b2863ffdc745a9e09f2307
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/nemotron/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_nemotron import *
+    from .modeling_nemotron import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/nemotron/configuration_nemotron.py b/venv/lib/python3.13/site-packages/transformers/models/nemotron/configuration_nemotron.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a15bb96b0684417d11b669c2b2d97fe72194f9b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/nemotron/configuration_nemotron.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Nemotron model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class NemotronConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate a
+    Nemotron model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of Nemotron-8B,
+    e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 256000):
+            Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`NemotronModel`].
+        hidden_size (`int`, *optional*, defaults to 6144):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 24576):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 48):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        head_dim (`int`, *optional*):
+            Projection weights dimension in multi-head attention. Set to `hidden_size // num_attention_heads` if None.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
+            constructed by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.0134):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 3): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj and down_proj layers in the MLP layers. + + ```python + >>> from transformers import NemotronModel, NemotronConfig + + >>> # Initializing a Nemotron nemotron-15b style configuration + >>> configuration = NemotronConfig() + + >>> # Initializing a model from the nemotron-15b style configuration + >>> model = NemotronModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nemotron" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=6144, + intermediate_size=24576, + num_hidden_layers=32, + num_attention_heads=48, + head_dim=None, + num_key_value_heads=None, + hidden_act="relu2", + max_position_embeddings=4096, + initializer_range=0.0134, + norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=2, + eos_token_id=3, + tie_word_embeddings=False, + rope_theta=10000.0, + partial_rotary_factor=0.5, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + rope_config_validation(self) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["NemotronConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nemotron/modeling_nemotron.py b/venv/lib/python3.13/site-packages/transformers/models/nemotron/modeling_nemotron.py new file mode 100644 index 0000000000000000000000000000000000000000..b98f0a5ef2ac5444de7a3e4de12bcb9c43e5eca9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nemotron/modeling_nemotron.py @@ -0,0 +1,962 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Nemotron model.""" + +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Size, Tensor, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_nemotron import NemotronConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +def _cast_if_autocast_enabled(device_type, *args): + if not torch.is_autocast_enabled(): + return args + else: + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + return torch.amp.autocast_mode._cast(args, device_type, target_dtype) + + +class NemotronLayerNorm1P(nn.LayerNorm): + def __init__( + self, + normalized_shape: Union[int, list[int], Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) + + def forward(self, input: Tensor) -> Tensor: + device_type = input.device.type if input.device.type != "mps" else "cpu" + args = _cast_if_autocast_enabled( + device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps + ) + with torch.autocast(device_type=input.device.type, enabled=False): + return F.layer_norm(*args) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +class NemotronRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + # Ignore copy + def __init__( + self, + config: NemotronConfig, + device=None, + ): + super().__init__() + + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update 
# power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rot_dim = cos.shape[-1] + # If q_pass/k_pass is empty, rotary pos embedding is applied to all tensor q/k + q, q_pass = q[..., :rot_dim], q[..., rot_dim:] + k, k_pass = k[..., :rot_dim], k[..., rot_dim:] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1) + + +class NemotronMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class NemotronAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: NemotronConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = 
torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular +class NemotronFlashAttention2(NemotronAttention): + """ + Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
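+        # Illustrative note (added): for q_seqlen=2 against k_seqlen=4, bottom-right alignment
+        # (flash_attn>=2.1) lets the two queries see keys [0..2] and [0..3] respectively, while a
+        # top-left aligned mask would restrict them to keys [0] and [0..1].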
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if isinstance(past_key_values, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(NemotronRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular +class NemotronSdpaAttention(NemotronAttention): + """ + Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `NemotronAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "NemotronModel is using NemotronSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
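+        # Illustrative note (added): during a prefill with q_len=3 and no padding, causal_mask is None and
+        # is_causal=True, so SDPA builds the triangular mask internally; for a single decoded token
+        # (q_len=1) is_causal is False and no mask is needed, since one query may attend to every cached key.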
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +NEMOTRON_ATTENTION_CLASSES = { + "eager": NemotronAttention, + "flash_attention_2": NemotronFlashAttention2, + "sdpa": NemotronSdpaAttention, +} + + +# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# no longer copied after attention refactors +class NemotronDecoderLayer(GradientCheckpointingLayer): + # Ignore copy + def __init__(self, config: NemotronConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = NEMOTRON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = NemotronMLP(config) + self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class NemotronPreTrainedModel(PreTrainedModel): + config: NemotronConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["NemotronDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, NemotronLayerNorm1P): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +@auto_docstring +class NemotronModel(NemotronPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`NemotronDecoderLayer`] + + Args: + config: NemotronConfig + """ + + def __init__(self, config: NemotronConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [NemotronDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) + self.rotary_emb = NemotronRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+            )
+            use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
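+        # Illustrative note (added): in the common no-padding decode step, `_ignore_causal_mask_sdpa` below
+        # lets this method return None, so SDPA runs without an explicit mask; the dense 4D float mask is
+        # only materialized when padding, a static cache, or output_attentions forces it.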
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = NemotronModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, NemotronForCausalLM + + >>> model = NemotronForCausalLM.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class NemotronForSequenceClassification(GenericForSequenceClassification, NemotronPreTrainedModel): ... + + +class NemotronForQuestionAnswering(GenericForQuestionAnswering, NemotronPreTrainedModel): + base_model_prefix = "transformer" + + +class NemotronForTokenClassification(GenericForTokenClassification, NemotronPreTrainedModel): ... + + +__all__ = [ + "NemotronForQuestionAnswering", + "NemotronForCausalLM", + "NemotronModel", + "NemotronPreTrainedModel", + "NemotronForSequenceClassification", + "NemotronForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/nllb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdb326098a30020f0d9d2cc4c083090bce70d5a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
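Before the NLLB files that follow, a small self-contained sketch of the `logits_to_keep` slicing used in `NemotronForCausalLM.forward` above may help; the tensor sizes below are invented for illustration, and only the slicing logic mirrors the code.

```python
import torch
import torch.nn as nn

# Toy dimensions, chosen only for illustration.
batch, seq_len, hidden, vocab = 2, 10, 32, 100
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)

# logits_to_keep == 0 keeps every position (slice(-0, None) is the full range);
# an integer k > 0 keeps only the last k positions, which is all generation needs.
for logits_to_keep in (0, 1):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = lm_head(hidden_states[:, slice_indices, :])
    print(logits_to_keep, logits.shape)  # torch.Size([2, 10, 100]) then torch.Size([2, 1, 100])
```

Passing a tensor of indices instead of an int selects arbitrary positions; the same branch handles that case by using the tensor directly as the index.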
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_nllb import * + from .tokenization_nllb_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb.py b/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb.py new file mode 100644 index 0000000000000000000000000000000000000000..4962a642bb3181d0d0e918530681ba7ff40ef70e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 
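As a brief aside, here is a minimal sketch of what the `SPIECE_UNDERLINE` marker defined above looks like in practice; the model path is a placeholder, and the round trip mirrors what `_tokenize` and `convert_tokens_to_string` do further down in this file.

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("sentencepiece.bpe.model")  # placeholder path to an NLLB SentencePiece model

# Pieces carry "▁" as a word-boundary marker, e.g. ['▁UN', '▁Chief', '▁Says'].
pieces = sp.encode("UN Chief Says", out_type=str)

# Joining the pieces and replacing the marker with a space typically reconstructs the text.
print("".join(pieces).replace("▁", " ").strip())
```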
'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip + + +@requires(backends=("sentencepiece",)) +class NllbTokenizer(PreTrainedTokenizer): + """ + Construct an NLLB tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and ` + ` for target language documents. + + Examples: + + ```python + >>> from transformers import NllbTokenizer + + >>> tokenizer = NllbTokenizer.from_pretrained( + ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. 
+ src_lang (`str`, *optional*): + The language to use as source language for translation. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + sp_model_kwargs (`dict[str, str]`): + Additional keyword arguments to pass to the model initialization. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: list[int] = [] + suffix_tokens: list[int] = [] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + tokenizer_file=None, + src_lang=None, + tgt_lang=None, + sp_model_kwargs: Optional[dict[str, Any]] = None, + additional_special_tokens=None, + legacy_behaviour=False, + **kwargs, + ): + if additional_special_tokens is None: + additional_special_tokens = FAIRSEQ_LANGUAGE_CODES + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + # Mask token behave like a normal word, i.e. include the space before it + mask_token = ( + AddedToken(mask_token, normalized=True, lstrip=True, special=True) + if isinstance(mask_token, str) + else mask_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.legacy_behaviour = legacy_behaviour + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ---- + # fairseq | '' | '' | '' | '' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' + # spm | '' | '' | '' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s' + + # unk token needs to be in the vocab with correct index + self._added_tokens_decoder = {0: bos_token, 1: pad_token, 2: eos_token, 3: unk_token} + # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + self.sp_model_size = len(self.sp_model) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + tokenizer_file=tokenizer_file, + src_lang=src_lang, + tgt_lang=tgt_lang, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + legacy_behaviour=legacy_behaviour, + **kwargs, + ) + + self._src_lang = src_lang if src_lang is not None else "eng_Latn" + self.cur_lang_code_id = self.convert_tokens_to_ids(self._src_lang) + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + 
self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + return len(self.sp_model) + self.fairseq_offset + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
+ + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + spm_id = self.sp_model.PieceToId(token) + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "eng_Latn", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "fra_Latn", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + - In legacy mode: No prefix and suffix=[eos, src_lang_code]. 
+ - In default mode: Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + - In legacy mode: No prefix and suffix=[eos, tgt_lang_code]. + - In default mode: Prefix=[tgt_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + +__all__ = ["NllbTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb_fast.py b/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..5300b3942b5d261e5e83be8c55f1a41da3427b09 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb/tokenization_nllb_fast.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
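A hedged usage sketch tying together the language-code handling above: the source language code is placed around the encoder input by `set_src_lang_special_tokens`, and the target language code is passed to `generate` as `forced_bos_token_id`, as `_build_translation_inputs` does. The checkpoint name is taken from the docstring examples in this file; treat the snippet as illustrative rather than a tested recipe.

```python
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

checkpoint = "facebook/nllb-200-distilled-600M"  # from the docstring examples above
tokenizer = NllbTokenizer.from_pretrained(checkpoint, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# In the default (non-legacy) mode the encoder input is: src_lang_code, tokens..., eos.
inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")

# Forcing the first decoder token to the target language code mirrors _build_translation_inputs.
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
    max_new_tokens=40,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```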
+ +import os +from shutil import copyfile +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_nllb import NllbTokenizer +else: + NllbTokenizer = None + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + + +FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip + + +class NllbTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" NLLB tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. 
+ + The tokenization method is ` ` for source language documents, and ` + ` for target language documents. + + Examples: + + ```python + >>> from transformers import NllbTokenizerFast + + >>> tokenizer = NllbTokenizerFast.from_pretrained( + ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + src_lang (`str`, *optional*): + The language to use as source language for translation. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = NllbTokenizer + + prefix_tokens: list[int] = [] + suffix_tokens: list[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + src_lang=None, + tgt_lang=None, + additional_special_tokens=None, + legacy_behaviour=False, + **kwargs, + ): + if additional_special_tokens is None: + additional_special_tokens = FAIRSEQ_LANGUAGE_CODES + + self.vocab_file = vocab_file + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = ( + AddedToken(mask_token, normalized=True, lstrip=True, special=True) + if isinstance(mask_token, str) + else mask_token + ) + self.legacy_behaviour = legacy_behaviour + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + src_lang=src_lang, + tgt_lang=tgt_lang, + mask_token=mask_token, + additional_special_tokens=additional_special_tokens, + legacy_behaviour=legacy_behaviour, + **kwargs, + ) + + self._src_lang = src_lang if src_lang is not None else "eng_Latn" + self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang) + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. The special tokens depend on calling set_lang. + + An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
+ + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "eng_Latn", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "fra_Latn", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + - In legacy mode: No prefix and suffix=[eos, src_lang_code]. + - In default mode: Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + - In legacy mode: No prefix and suffix=[eos, tgt_lang_code]. + - In default mode: Prefix=[tgt_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["NllbTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e21bde18010d3c81455b0a66eed8263c47a0e5de --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_nllb_moe import * + from .modeling_nllb_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/configuration_nllb_moe.py b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/configuration_nllb_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..61cb2197c84d906e41833cf87eb2bebfb4fbd984 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/configuration_nllb_moe.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2023, HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NLLB-MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class NllbMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NllbMoeModel`]. It is used to instantiate an + NLLB-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the NLLB-MoE + [facebook/nllb-moe-54b](https://huggingface.co/facebook/nllb-moe-54b) architecture. 
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 128112):
+            Vocabulary size of the NllbMoe model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`NllbMoeModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        second_expert_policy (`str`, *optional*, defaults to `"all"`):
+            The policy used to sample the probability of a token being routed to a second expert.
+        normalize_router_prob_before_dropping (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize the router probabilities before applying a mask based on the experts' capacity
+            (capacity dropping).
+        batch_prioritized_routing (`bool`, *optional*, defaults to `False`):
+            Whether or not to order the tokens by their router probabilities before capacity dropping. This means that
+            the tokens that have the highest probabilities will be routed before other tokens that might be further in
+            the sequence.
+ moe_eval_capacity_token_fraction (`float`, *optional*, defaults to 1.0): + Fraction of tokens as capacity during validation, if set to negative, uses the same as training. Should be + in range: (0.0, 1.0]. + num_experts (`int`, *optional*, defaults to 128): + Number of experts for each NllbMoeSparseMlp layer. + expert_capacity (`int`, *optional*, defaults to 64): + Number of tokens that can be stored in each expert. + encoder_sparse_step (`int`, *optional*, defaults to 4): + Frequency of the sparse layers in the encoder. 4 means that one out of 4 layers will be sparse. + decoder_sparse_step (`int`, *optional*, defaults to 4): + Frequency of the sparse layers in the decoder. 4 means that one out of 4 layers will be sparse. + router_dtype (`str`, *optional*, default to `"float32"`): + The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the + *selective precision* discussion in [the paper](https://huggingface.co/papers/2101.03961). + router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): + Whether to ignore padding tokens when routing. if `False`, the padding tokens are not routed to any + experts. + router_bias (`bool`, *optional*, defaults to `False`): + Whether or not the classifier of the router should have a bias. + moe_token_dropout (`float`, *optional*, default to 0.2): + Masking rate for MoE expert output masking (EOM), which is implemented via a Dropout2d on the expert + outputs. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the router logits. Only set to `True` to get the auxiliary loss when training. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). 
+ + Example: + + ```python + >>> from transformers import NllbMoeModel, NllbMoeConfig + + >>> # Initializing a NllbMoe facebook/nllb-moe-54b style configuration + >>> configuration = NllbMoeConfig() + + >>> # Initializing a model from the facebook/nllb-moe-54b style configuration + >>> model = NllbMoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nllb-moe" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=128112, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=1024, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + scale_embedding=True, + router_bias=False, + router_dtype="float32", + router_ignore_padding_tokens=False, + num_experts=128, + expert_capacity=64, + encoder_sparse_step=4, + decoder_sparse_step=4, + router_z_loss_coef=0.001, + router_aux_loss_coef=0.001, + second_expert_policy="all", + normalize_router_prob_before_dropping=False, + batch_prioritized_routing=False, + moe_eval_capacity_token_fraction=1.0, + moe_token_dropout=0.2, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + output_router_logits=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.router_z_loss_coef = router_z_loss_coef + self.router_aux_loss_coef = router_aux_loss_coef + self.decoder_sparse_step = decoder_sparse_step + self.encoder_sparse_step = encoder_sparse_step + self.num_experts = num_experts + self.expert_capacity = expert_capacity + self.router_bias = router_bias + if router_dtype not in ["float32", "float16", "bfloat16"]: + raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") + self.router_dtype = router_dtype + + self.router_ignore_padding_tokens = router_ignore_padding_tokens + self.batch_prioritized_routing = batch_prioritized_routing + self.second_expert_policy = second_expert_policy + self.normalize_router_prob_before_dropping = normalize_router_prob_before_dropping + self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction + self.moe_token_dropout = moe_token_dropout + self.output_router_logits = output_router_logits + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + 
decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["NllbMoeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/modeling_nllb_moe.py b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/modeling_nllb_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..f0131b6b999b9ab75a557de1a4a5a73147ed0724 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -0,0 +1,1775 @@ +# coding=utf-8 +# Copyright 2023 NllbMoe Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch NLLB-MoE model.""" + +import math +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + MoEModelOutput, + MoEModelOutputWithPastAndCrossAttentions, + Seq2SeqMoEModelOutput, + Seq2SeqMoEOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_nllb_moe import NllbMoeConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. 
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token. + + Returns: + The auxiliary loss. + """ + if router_probs is None: + return 0 + + num_experts = router_probs.shape[-1] + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + + if len(expert_indices.shape) == 2: + expert_indices = expert_indices.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe +class NllbMoeScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class NllbMoeSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
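+
+        Illustrative sketch of the computation (hypothetical values: `padding_idx=1`, a sequence of length 3 and no
+        cached past):
+
+        ```python
+        >>> import torch
+
+        >>> padding_idx, sequence_length, past_key_values_length = 1, 3, 0
+        >>> torch.arange(padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long) + past_key_values_length
+        tensor([2, 3, 4])
+        ```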
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +class NllbMoeTop2Router(nn.Module): + """ + Router using tokens choose top-2 experts assignment. + + This router uses the same mechanism as in NLLB-MoE from the fairseq repository. Items are sorted by router_probs + and then routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee + that each token is processed by an expert**, or that each expert receives at least one token. + + The router combining weights are also returned to make sure that the states that are not updated will be masked. + + """ + + def __init__(self, config: NllbMoeConfig): + super().__init__() + self.num_experts = config.num_experts + self.expert_capacity = config.expert_capacity + self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.router_ignore_padding_tokens = config.router_ignore_padding_tokens + self.dtype = getattr(torch, config.router_dtype) + + self.second_expert_policy = config.second_expert_policy + self.normalize_router_prob_before_dropping = config.normalize_router_prob_before_dropping + self.batch_prioritized_routing = config.batch_prioritized_routing + self.moe_eval_capacity_token_fraction = config.moe_eval_capacity_token_fraction + + def _cast_classifier(self): + r""" + `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an + instance of the `Linear8bitLt` class by checking special attributes. + """ + if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")): + self.classifier = self.classifier.to(self.dtype) + + def normalize_router_probabilities(self, router_probs, top_1_mask, top_2_mask): + top_1_max_probs = (router_probs * top_1_mask).sum(dim=1) + top_2_max_probs = (router_probs * top_2_mask).sum(dim=1) + denom_s = torch.clamp(top_1_max_probs + top_2_max_probs, min=torch.finfo(router_probs.dtype).eps) + top_1_max_probs = top_1_max_probs / denom_s + top_2_max_probs = top_2_max_probs / denom_s + return top_1_max_probs, top_2_max_probs + + def route_tokens( + self, + router_logits: torch.Tensor, + input_dtype: torch.dtype = torch.float32, + padding_mask: Optional[torch.LongTensor] = None, + ) -> tuple: + """ + Computes the `dispatch_mask` and the `dispatch_weights` for each experts. The masks are adapted to the expert + capacity. 
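+
+        Minimal sketch of the top-1 / top-2 selection (illustrative only: the logits are hypothetical, and the
+        capacity handling and second-expert policies implemented below are omitted):
+
+        ```python
+        >>> import torch
+
+        >>> router_logits = torch.tensor([[0.1, 2.0, 0.3, -1.0]])  # 1 token, 4 experts
+        >>> top_1_mask = torch.nn.functional.one_hot(router_logits.argmax(dim=-1), num_classes=4)
+        >>> masked_logits = router_logits.masked_fill(top_1_mask.bool(), float("-inf"))
+        >>> top_2_mask = torch.nn.functional.one_hot(masked_logits.argmax(dim=-1), num_classes=4)
+        >>> top_1_mask.tolist(), top_2_mask.tolist()
+        ([[0, 1, 0, 0]], [[0, 0, 1, 0]])
+        ```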
+ """ + nb_tokens = router_logits.shape[0] + # Apply Softmax and cast back to the original `dtype` + router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(input_dtype) + top_1_expert_index = torch.argmax(router_probs, dim=-1) + top_1_mask = torch.nn.functional.one_hot(top_1_expert_index, num_classes=self.num_experts) + + if self.second_expert_policy == "sampling": + gumbel = torch.distributions.gumbel.Gumbel(0, 1).rsample + router_logits += gumbel(router_logits.shape).to(router_logits.device) + + # replace top_1_expert_index with min values + logits_except_top_1 = router_logits.masked_fill(top_1_mask.bool(), float("-inf")) + top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1) + top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts) + + if self.normalize_router_prob_before_dropping: + top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities( + router_probs, top_1_mask, top_2_mask + ) + + if self.second_expert_policy == "random": + top_2_max_probs = (router_probs * top_2_mask).sum(dim=1) + sampled = (2 * top_2_max_probs) > torch.rand_like(top_2_max_probs.float()) + top_2_mask = top_2_mask * sampled.repeat(self.num_experts, 1).transpose(1, 0) + + if padding_mask is not None and not self.router_ignore_padding_tokens: + if len(padding_mask.shape) == 4: + # only get the last causal mask + padding_mask = padding_mask[:, :, -1, :].reshape(-1)[-nb_tokens:] + non_padding = ~padding_mask.bool() + top_1_mask = top_1_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype) + top_2_mask = top_2_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype) + + if self.batch_prioritized_routing: + # sort tokens based on their routing probability + # to make sure important tokens are routed, first + importance_scores = -1 * router_probs.max(dim=1)[0] + sorted_top_1_mask = top_1_mask[importance_scores.argsort(dim=0)] + sorted_cumsum1 = (torch.cumsum(sorted_top_1_mask, dim=0) - 1) * sorted_top_1_mask + locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)] + + sorted_top_2_mask = top_2_mask[importance_scores.argsort(dim=0)] + sorted_cumsum2 = (torch.cumsum(sorted_top_2_mask, dim=0) - 1) * sorted_top_2_mask + locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)] + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(top_1_mask, dim=0, keepdim=True) + + else: + locations1 = torch.cumsum(top_1_mask, dim=0) - 1 + locations2 = torch.cumsum(top_2_mask, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(top_1_mask, dim=0, keepdim=True) + + if not self.training and self.moe_eval_capacity_token_fraction > 0: + self.expert_capacity = math.ceil(self.moe_eval_capacity_token_fraction * nb_tokens) + else: + capacity = 2 * math.ceil(nb_tokens / self.num_experts) + self.expert_capacity = capacity if self.expert_capacity is None else self.expert_capacity + + # Remove locations outside capacity from ( cumsum < capacity = False will not be routed) + top_1_mask = top_1_mask * torch.lt(locations1, self.expert_capacity) + top_2_mask = top_2_mask * torch.lt(locations2, self.expert_capacity) + + if not self.normalize_router_prob_before_dropping: + top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities( + router_probs, top_1_mask, top_2_mask + ) + + # Calculate combine_weights and dispatch_mask + gates1 = top_1_max_probs[:, None] * top_1_mask + gates2 = top_2_max_probs[:, None] * top_2_mask + router_probs = gates1 + 
gates2
+
+        return top_1_mask, router_probs
+
+    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> tuple:
+        r"""
+        The hidden states are reshaped to simplify the computation of the router probabilities (combining weights for
+        each expert).
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
+        Returns:
+            top_1_mask (`torch.Tensor` of shape (batch_size, sequence_length)):
+                Index tensor of shape [batch_size, sequence_length] corresponding to the expert selected for each token
+                using the top1 probabilities of the router.
+            router_probabilities (`torch.Tensor` of shape (batch_size, sequence_length, num_experts)):
+                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
+                token and expert. Used for routing tokens to experts.
+            router_logits (`torch.Tensor` of shape (batch_size, sequence_length, num_experts)):
+                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
+                This is used later for computing router z-loss.
+        """
+        self.input_dtype = hidden_states.dtype
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
+        hidden_states = hidden_states.to(self.dtype)
+        self._cast_classifier()
+        router_logits = self.classifier(hidden_states)
+        top_1_mask, router_probs = self.route_tokens(router_logits, self.input_dtype, padding_mask)
+        return top_1_mask, router_probs
+
+
+class NllbMoeDenseActDense(nn.Module):
+    def __init__(self, config: NllbMoeConfig, ffn_dim: int):
+        super().__init__()
+        self.fc1 = nn.Linear(config.d_model, ffn_dim)
+        self.fc2 = nn.Linear(ffn_dim, config.d_model)
+        self.dropout = nn.Dropout(config.activation_dropout)
+        self.act = ACT2FN[config.activation_function]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        if (
+            isinstance(self.fc2.weight, torch.Tensor)
+            and hidden_states.dtype != self.fc2.weight.dtype
+            and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8)
+        ):
+            hidden_states = hidden_states.to(self.fc2.weight.dtype)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class NllbMoeSparseMLP(nn.Module):
+    r"""
+    Implementation of the NLLB-MoE sparse MLP module.
+    """
+
+    def __init__(self, config: NllbMoeConfig, ffn_dim: int, expert_class: nn.Module = NllbMoeDenseActDense):
+        super().__init__()
+        self.router = NllbMoeTop2Router(config)
+        self.moe_token_dropout = config.moe_token_dropout
+        self.token_dropout = nn.Dropout(self.moe_token_dropout)
+        self.num_experts = config.num_experts
+
+        self.experts = nn.ModuleDict()
+        for idx in range(self.num_experts):
+            self.experts[f"expert_{idx}"] = expert_class(config, ffn_dim)
+
+    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = False):
+        r"""
+        The goal of this forward pass is to have the same number of operations as the equivalent `NllbMoeDenseActDense`
+        (mlp) layer. This means that all of the hidden states should be processed at most twice (since we are using a
+        top_2 gating mechanism). This means that we keep the complexity to O(batch_size x sequence_length x hidden_dim)
+        instead of O(num_experts x batch_size x sequence_length x hidden_dim).
+
+        1- Get the `router_probs` from the `router`.
The shape of the `router_mask` is `(batch_size X sequence_length, + num_expert)` and corresponds to the boolean version of the `router_probs`. The inputs are masked using the + `router_mask`. + + 2- Dispatch the hidden_states to its associated experts. The router probabilities are used to weight the + contribution of each experts when updating the masked hidden states. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`): + The hidden states + padding_mask (`torch.Tensor`, *optional*, defaults to `False`): + Attention mask. Can be in the causal form or not. + + Returns: + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`): + Updated hidden states + router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`): + Needed for computing the loss + + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + + top_1_mask, router_probs = self.router(hidden_states, padding_mask) + router_mask = router_probs.bool() + hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim) + masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask) + for idx, expert in enumerate(self.experts.values()): + token_indices = router_mask[:, idx] + combining_weights = router_probs[token_indices, idx] + expert_output = expert(masked_hidden_states[idx, token_indices]) + if self.moe_token_dropout > 0: + if self.training: + expert_output = self.token_dropout(expert_output) + else: + expert_output *= 1 - self.moe_token_dropout + masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output) + hidden_states = masked_hidden_states.sum(dim=0).reshape(batch_size, sequence_length, hidden_dim) + + top_1_expert_index = torch.argmax(top_1_mask, dim=-1) + return hidden_states, (router_probs, top_1_expert_index) + + +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states +class NllbMoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, + config: Optional[NllbMoeConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + 
if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = encoder_hidden_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = 
ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class NllbMoeEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): + super().__init__() + self.embed_dim = config.d_model + self.is_sparse = is_sparse + self.self_attn = NllbMoeAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + if not self.is_sparse: + self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.encoder_ffn_dim) + else: + self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.encoder_ffn_dim) + self.ff_layer_norm = nn.LayerNorm(config.d_model) + self.ff_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + output_router_logits: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ff_layer_norm(hidden_states) + if self.is_sparse: + hidden_states, router_states = self.ffn(hidden_states, attention_mask) + else: + # router_states set to None to track which layers have None gradients. 
+ hidden_states, router_states = self.ffn(hidden_states), None + + hidden_states = self.ff_dropout(hidden_states) + + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if output_router_logits: + outputs += (router_states,) + + return outputs + + +class NllbMoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + self.is_sparse = is_sparse + self.self_attn = NllbMoeAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = NllbMoeAttention( + self.embed_dim, + config.decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + if not self.is_sparse: + self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.decoder_ffn_dim) + else: + self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.decoder_ffn_dim) + self.ff_layer_norm = nn.LayerNorm(config.d_model) + self.ff_dropout = nn.Dropout(config.activation_dropout) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + layer_head_mask (`torch.FloatTensor`): + mask for attention heads in a given layer of size `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): + mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. + past_key_values (`Cache`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
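+
+        The layer returns a tuple whose first element is always the updated hidden states; `(self_attn_weights,
+        cross_attn_weights)` are appended when `output_attentions=True`, and the router states of the (possibly
+        sparse) feed-forward block are appended when `output_router_logits=True`.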
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + hidden_states = self.ff_layer_norm(hidden_states) + if self.is_sparse: + hidden_states, router_states = self.ffn(hidden_states, attention_mask) + else: + hidden_states, router_states = self.ffn(hidden_states), None + + hidden_states = self.ff_dropout(hidden_states) + + hidden_states = residual + hidden_states + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if output_router_logits: + outputs += (router_states,) + + return outputs + + +@auto_docstring +class NllbMoePreTrainedModel(PreTrainedModel): + config: NllbMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"] + # TODO: If anyone is up to it to make sure tests pass etc + # Flash attention has problems due to not preparing masks the same way as eager/sdpa + # SDPA has more flaky logits which requires more time to look into tests + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class NllbMoeEncoder(NllbMoePreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`NllbMoeEncoderLayer`]. 
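+
+    Every `config.encoder_sparse_step`-th layer uses a sparse MoE feed-forward block ([`NllbMoeSparseMLP`]) instead of
+    a dense one. A small illustration of the selection rule used in `__init__` (the values are hypothetical):
+
+    ```python
+    >>> encoder_sparse_step, encoder_layers = 4, 12
+    >>> [(i + 1) % encoder_sparse_step == 0 for i in range(encoder_layers)]  # True marks the sparse layers
+    [False, False, False, True, False, False, False, True, False, False, False, True]
+    ```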
+ + Args: + config: + NllbMoeConfig + embed_tokens (nn.Embedding): + output embedding + """ + + def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + sparse_step = config.encoder_sparse_step + self.layers = nn.ModuleList() + for i in range(config.encoder_layers): + is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False + self.layers.append(NllbMoeEncoderLayer(config, is_sparse)) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input_ids, inputs_embeds) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_states = () if output_hidden_states else None + all_router_probs = () if output_router_logits else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = torch.rand([]) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_router_logits: + all_router_probs += (layer_outputs[-1],) + + last_hidden_state = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states += (last_hidden_state,) + + if not return_dict: + return tuple( + v for v in [last_hidden_state, encoder_states, all_attentions, all_router_probs] if v is not None + ) + + return MoEModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=encoder_states, + attentions=all_attentions, + router_probs=all_router_probs, + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class NllbMoeDecoder(NllbMoePreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`NllbMoeDecoderLayer`] + + Args: + config: + NllbMoeConfig + embed_tokens (nn.Embedding): + output embedding + """ + + def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + + sparse_step = config.decoder_sparse_step + self.layers = nn.ModuleList() + for i in range(config.decoder_layers): + is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False + self.layers.append(NllbMoeDecoderLayer(config, is_sparse, layer_idx=i)) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = True, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+                )
+                use_cache = False
+
+        # initialize `past_key_values`
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+        if use_cache and isinstance(past_key_values, tuple):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        attention_mask = self._update_causal_mask(
+            attention_mask,
+            input_shape,
+            inputs_embeds,
+            past_key_values_length,
+        )
+        encoder_attention_mask = self._update_cross_attn_mask(
+            encoder_hidden_states,
+            encoder_attention_mask,
+            input_shape,
+            inputs_embeds,
+        )
+
+        # embed positions
+        positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + positions
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_probs = () if output_router_logits else None
+        all_cross_attentions = () if output_attentions else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+ ) + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = self.training and dropout_probability < self.layerdrop + if not skip_the_layer or synced_gpus: + layer_head_mask = head_mask[idx] if head_mask is not None else None + cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + + # under fsdp or deepspeed zero3 all gpus must run in sync + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if skip_the_layer: + continue + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + if output_router_logits: + all_router_probs += (layer_outputs[-1],) + + hidden_states = self.layer_norm(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + all_cross_attentions, + all_router_probs, + ] + if v is not None + ) + return MoEModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + router_probs=all_router_probs, + ) + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +@auto_docstring +class NllbMoeModel(NllbMoePreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: NllbMoeConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = NllbMoeEncoder(config, self.shared) + self.decoder = NllbMoeDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = 
None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[tuple[torch.Tensor], Seq2SeqMoEModelOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            NllbMoe uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, NllbMoeModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
+        >>> model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
+
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+        ...
).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for NllbMoeModel + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, MoEModelOutput): + encoder_outputs = MoEModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqMoEModelOutput( + past_key_values=decoder_outputs.past_key_values, + cross_attentions=decoder_outputs.cross_attentions, + last_hidden_state=decoder_outputs.last_hidden_state, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + decoder_hidden_states=decoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + decoder_attentions=decoder_outputs.attentions, + encoder_router_logits=encoder_outputs.router_probs, + decoder_router_logits=decoder_outputs.router_probs, + ) + + +@auto_docstring( + custom_intro=""" + The NllbMoe Model with a language modeling head. Can be used for summarization. 
+ """ +) +class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: NllbMoeConfig): + super().__init__(config) + self.model = NllbMoeModel(config) + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.router_z_loss_coef = config.router_z_loss_coef + self.router_aux_loss_coef = config.router_aux_loss_coef + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqMoEOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + NllbMoe uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example Translation: + + ```python + >>> from transformers import AutoTokenizer, NllbMoeForConditionalGeneration + + >>> model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b") + + >>> text_to_translate = "Life is like a box of chocolates" + >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt") + + >>> # translate to French + >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fra_Latn")) + >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + encoder_aux_loss = None + decoder_aux_loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # TODO: check in the config whether the router loss is enabled + + if output_router_logits: + encoder_router_logits = outputs[-1] + decoder_router_logits = outputs[3 if output_attentions else 4] + + # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder + encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits) + encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes) + + decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_router_logits) + decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if output_router_logits and labels is not None: + aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss) + loss = loss + aux_loss + + output = (loss,) if loss is not None else () + if not return_dict: + output += (lm_logits,) + if output_router_logits: # only return the loss if they are not None + output += ( + encoder_aux_loss, + decoder_aux_loss, + *outputs[1:], + ) + else: + output += outputs[1:] + + return output + + return Seq2SeqMoEOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + cross_attentions=outputs.cross_attentions, + encoder_aux_loss=encoder_aux_loss, + decoder_aux_loss=decoder_aux_loss, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + decoder_hidden_states=outputs.decoder_hidden_states, +
encoder_attentions=outputs.encoder_attentions, + decoder_attentions=outputs.decoder_attentions, + encoder_router_logits=outputs.encoder_router_logits, + decoder_router_logits=outputs.decoder_router_logits, + ) + + def _unpack_router_logits(self, router_outputs): + total_router_logits = [] + total_expert_indexes = [] + for router_output in router_outputs: + if router_output is not None: + router_logits, expert_indexes = router_output + total_router_logits.append(router_logits) + total_expert_indexes.append(expert_indexes) + + total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None + total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None + return total_router_logits, total_expert_indexes + + +__all__ = [ + "NllbMoeForConditionalGeneration", + "NllbMoeModel", + "NllbMoePreTrainedModel", + "NllbMoeTop2Router", + "NllbMoeSparseMLP", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nystromformer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c32df6d8db909207e4867e119d2a7fca4a0bd907 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_nystromformer import * + from .modeling_nystromformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/nystromformer/configuration_nystromformer.py b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/configuration_nystromformer.py new file mode 100644 index 0000000000000000000000000000000000000000..96a48b99fda4cd7ff056ceeb441889653682bec1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/configuration_nystromformer.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2022 UW-Madison and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Nystromformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class NystromformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NystromformerModel`]. It is used to instantiate + an Nystromformer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Nystromformer + [uw-madison/nystromformer-512](https://huggingface.co/uw-madison/nystromformer-512) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30000): + Vocabulary size of the Nystromformer model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`NystromformerModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`NystromformerModel`]. + segment_means_seq_len (`int`, *optional*, defaults to 64): + Sequence length used in segment-means. + num_landmarks (`int`, *optional*, defaults to 64): + The number of landmark (or Nystrom) points to use in Nystrom approximation of the softmax self-attention + matrix. + conv_kernel_size (`int`, *optional*, defaults to 65): + The kernel size of depthwise convolution used in Nystrom approximation. + inv_coeff_init_option (`bool`, *optional*, defaults to `False`): + Whether or not to use exact coefficient computation for the initial values for the iterative method of + calculating the Moore-Penrose inverse of a matrix. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. 
+ + Example: + + ```python + >>> from transformers import NystromformerModel, NystromformerConfig + + >>> # Initializing a Nystromformer uw-madison/nystromformer-512 style configuration + >>> configuration = NystromformerConfig() + + >>> # Initializing a model from the uw-madison/nystromformer-512 style configuration + >>> model = NystromformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nystromformer" + + def __init__( + self, + vocab_size=30000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu_new", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=510, + type_vocab_size=2, + segment_means_seq_len=64, + num_landmarks=64, + conv_kernel_size=65, + inv_coeff_init_option=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.segment_means_seq_len = segment_means_seq_len + self.num_landmarks = num_landmarks + self.conv_kernel_size = conv_kernel_size + self.inv_coeff_init_option = inv_coeff_init_option + self.layer_norm_eps = layer_norm_eps + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +__all__ = ["NystromformerConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/nystromformer/modeling_nystromformer.py b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/modeling_nystromformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb1fad2401948957a7c11b63e2b3d16b279c733 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/nystromformer/modeling_nystromformer.py @@ -0,0 +1,1017 @@ +# coding=utf-8 +# Copyright 2022 UW-Madison The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Nystromformer model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + auto_docstring, + logging, +) +from .configuration_nystromformer import NystromformerConfig + + +logger = logging.get_logger(__name__) + + +class NystromformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), + persistent=False, + ) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class NystromformerSelfAttention(nn.Module): + def __init__(self, config, 
position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.num_landmarks = config.num_landmarks + self.seq_len = config.segment_means_seq_len + self.conv_kernel_size = config.conv_kernel_size + + if config.inv_coeff_init_option: + self.init_option = "exact" # any value other than "original" selects the exact Z_0 initialization in iterative_inv + else: + self.init_option = "original" + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + + if self.conv_kernel_size is not None: + self.conv = nn.Conv2d( + in_channels=self.num_attention_heads, + out_channels=self.num_attention_heads, + kernel_size=(self.conv_kernel_size, 1), + padding=(self.conv_kernel_size // 2, 0), + bias=False, + groups=self.num_attention_heads, + ) + + # Function to approximate Moore-Penrose inverse via the iterative method + def iterative_inv(self, mat, n_iter=6): + identity = torch.eye(mat.size(-1), device=mat.device) + key = mat + + # The entries of key are positive and ||key||_{\infty} = 1 due to softmax + if self.init_option == "original": + # This original implementation is more conservative to compute coefficient of Z_0. + value = 1 / torch.max(torch.sum(key, dim=-2)) * key.transpose(-1, -2) + else: + # This is the exact coefficient computation, 1 / ||key||_1, of initialization of Z_0, leading to faster convergence.
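+            # (Either initialization scales `key.transpose(-1, -2)` so that the cubic, Newton-Schulz-style
+            # update below, value <- 0.25 * value @ (13*I - KV @ (15*I - KV @ (7*I - KV))) with KV = key @ value,
+            # converges toward the Moore-Penrose pseudoinverse of `key`.)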
+ value = 1 / torch.max(torch.sum(key, dim=-2), dim=-1).values[:, :, None, None] * key.transpose(-1, -2) + + for _ in range(n_iter): + key_value = torch.matmul(key, value) + value = torch.matmul( + 0.25 * value, + 13 * identity + - torch.matmul(key_value, 15 * identity - torch.matmul(key_value, 7 * identity - key_value)), + ) + return value + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size)) + key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size)) + + if self.num_landmarks == self.seq_len: + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in NystromformerModel forward() function) + attention_scores = attention_scores + attention_mask + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + context_layer = torch.matmul(attention_probs, value_layer) + + else: + q_landmarks = query_layer.reshape( + -1, + self.num_attention_heads, + self.num_landmarks, + self.seq_len // self.num_landmarks, + self.attention_head_size, + ).mean(dim=-2) + k_landmarks = key_layer.reshape( + -1, + self.num_attention_heads, + self.num_landmarks, + self.seq_len // self.num_landmarks, + self.attention_head_size, + ).mean(dim=-2) + + kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) + kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) + + attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in NystromformerModel forward() function) + attention_scores = attention_scores + attention_mask + + kernel_3 = nn.functional.softmax(attention_scores, dim=-1) + attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) + new_value_layer = torch.matmul(kernel_3, value_layer) + context_layer = torch.matmul(attention_probs, new_value_layer) + + if self.conv_kernel_size is not None: + context_layer += self.conv(value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class NystromformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states =
self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NystromformerAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = NystromformerSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = NystromformerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + self_outputs = self.self(hidden_states, attention_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Nystromformer +class NystromformerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Nystromformer +class NystromformerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NystromformerLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = NystromformerAttention(config) + self.add_cross_attention = config.add_cross_attention + self.intermediate = NystromformerIntermediate(config) + self.output = NystromformerOutput(config) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = 
apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class NystromformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([NystromformerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Nystromformer +class NystromformerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Nystromformer +class NystromformerLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = NystromformerPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
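+        # (The actual tying to the input embedding matrix is performed by `PreTrainedModel.tie_weights()`
+        # through `get_output_embeddings`, so the bias below is the only new parameter introduced by this head.)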
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Nystromformer +class NystromformerOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = NystromformerLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@auto_docstring +class NystromformerPreTrainedModel(PreTrainedModel): + config: NystromformerConfig + base_model_prefix = "nystromformer" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class NystromformerModel(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = NystromformerEmbeddings(config) + self.encoder = NystromformerEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
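+        # `get_extended_attention_mask` broadcasts the 2D `(batch_size, seq_length)` mask to a shape that is
+        # broadcastable over all heads and converts masked (0) positions into large negative additive biases,
+        # so they contribute nothing after the softmax inside the attention layers.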
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring +class NystromformerForMaskedLM(NystromformerPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + self.nystromformer = NystromformerModel(config) + self.cls = NystromformerOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
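+
+        Example (an illustrative sketch using the public `uw-madison/nystromformer-512` checkpoint):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, NystromformerForMaskedLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
+        >>> model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")
+
+        >>> inputs = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # retrieve the prediction at the masked position
+        >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+        >>> predicted_token = tokenizer.decode(logits[0, mask_index].argmax(dim=-1))
+        ```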
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class NystromformerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring( + custom_intro=""" + Nyströmformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class NystromformerForSequenceClassification(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.nystromformer = NystromformerModel(config) + self.classifier = NystromformerClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class NystromformerForMultipleChoice(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.nystromformer = NystromformerModel(config) + self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class NystromformerForTokenClassification(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.nystromformer = NystromformerModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class NystromformerForQuestionAnswering(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.nystromformer = NystromformerModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + +
if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "NystromformerForMaskedLM", + "NystromformerForMultipleChoice", + "NystromformerForQuestionAnswering", + "NystromformerForSequenceClassification", + "NystromformerForTokenClassification", + "NystromformerLayer", + "NystromformerModel", + "NystromformerPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/olmoe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/olmoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e18d77cbad11bf218b6cd304f703b4d7935eef --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/olmoe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_olmoe import * + from .modeling_olmoe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/olmoe/configuration_olmoe.py b/venv/lib/python3.13/site-packages/transformers/models/olmoe/configuration_olmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..864d06b64d7ad48d904ab549b4d107a7832666d3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/olmoe/configuration_olmoe.py @@ -0,0 +1,182 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OLMoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class OlmoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924). 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OlmoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 16): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + clip_qkv (`float`, *optional*): + If not `None`, elements of query, key and value attention states are clipped so that their + absolute value does not exceed this value. + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 64): + Number of routed experts. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.01): + The aux loss factor for the total loss. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the topk probabilities. + + ```python + >>> from transformers import OlmoeModel, OlmoeConfig + + >>> # Initializing a OLMoE 7B A1B style configuration + >>> configuration = OlmoeConfig() + + >>> # Initializing a model from the OLMoE 7B A1B style configuration + >>> model = OlmoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "olmoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=2048, + intermediate_size=2048, + num_hidden_layers=16, + num_attention_heads=16, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + num_experts_per_tok=8, + num_experts=64, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
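+ # (illustrative, assumed values: a legacy dict like {"type": "linear", "factor": 2.0} becomes {"rope_type": "linear", "factor": 2.0})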
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["OlmoeConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/olmoe/modeling_olmoe.py b/venv/lib/python3.13/site-packages/transformers/models/olmoe/modeling_olmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..4070d3b2b48072f41dd033b1d10b1d0f9595e69c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/olmoe/modeling_olmoe.py @@ -0,0 +1,1097 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OLMoE model.""" + +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_olmoe import OlmoeConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
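+ + Example (a minimal sketch with random dummy logits; shapes follow the docstring above, values are illustrative only): + + ```python + >>> import torch + >>> # router logits from two layers, each of shape [batch_size * sequence_length, num_experts] = [4, 8] + >>> gate_logits = (torch.randn(4, 8), torch.randn(4, 8)) + >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2) + ```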
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + device_index = routing_weights.device.index if routing_weights.device.index is not None else 0 + rank = routing_weights.shape[1] * int(device_index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) + return overall_loss * num_experts + + +class OlmoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + OlmoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe +class OlmoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: OlmoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + 
self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
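+ + Example (a minimal sketch with dummy tensors; shapes assume the default [batch, heads, seq, head_dim] layout described above): + + ```python + >>> import torch + >>> q = torch.randn(1, 2, 4, 8) # [batch_size, num_heads, seq_len, head_dim] + >>> k = torch.randn(1, 2, 4, 8) + >>> cos, sin = torch.randn(1, 4, 8), torch.randn(1, 4, 8) # [batch_size, seq_len, head_dim] + >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin) + ```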
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe +class OlmoeMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class OlmoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = OlmoeRMSNorm( + (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return 
attn_output, attn_weights + + + class OlmoeFlashAttention2(OlmoeAttention): + """ + OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stay + untouched. The only required change is on the forward pass, where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x num_heads x head_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view.
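+ # i.e. go from the cache layout [batch_size, num_heads, seq_len, head_dim] back to [batch_size, seq_len, num_heads, head_dim]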
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, we usually cast the layer norms to float32 for training stability reasons, + # so the input hidden states get silently cast to float32. Hence, we need to + # cast them back to the correct dtype just to be sure everything works as expected. + # This might slow down training & inference, so it is recommended not to cast the LayerNorms + # in fp32. (OlmoeRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seem to have been silently cast to float32; this might be related to" + f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back" + f" to {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + + class OlmoeSdpaAttention(OlmoeAttention): + """ + OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `OlmoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt it + to the SDPA API. + """ + + # Adapted from OlmoeAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
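+ # is_causal=True lets SDPA build the causal mask internally; with an explicit mask, or a single query token during decoding, no implicit causal masking is needed here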
+ is_causal = causal_mask is None and q_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None + + + OLMOE_ATTENTION_CLASSES = { + "eager": OlmoeAttention, + "flash_attention_2": OlmoeFlashAttention2, + "sdpa": OlmoeSdpaAttention, + } + + + class OlmoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) + self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be selected + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only supports torch tensors for indexing, so we'll use + # the `top_x` tensor here.
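+ # Shape note: expert_mask[expert_idx] is [top_k, num_tokens], so torch.where returns (idx, top_x) where idx is the + # top-k slot and top_x the flat token index; routing_weights[top_x, idx] is then this expert's gate weight per token.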
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class OlmoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: OlmoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = OlmoeSparseMoeBlock(config) + self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@auto_docstring +class OlmoePreTrainedModel(PreTrainedModel): + config: OlmoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OlmoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, OlmoeRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@auto_docstring +class OlmoeModel(OlmoePreTrainedModel): + def __init__(self, config: OlmoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = OlmoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) 
+ use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + # router logits go last so that `outputs[-1]` is the router logits when `output_router_logits=True` + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
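+ # ("inverted form" here means an additive mask: 0.0 on positions to attend to and a large negative value, + # torch.finfo(dtype).min, on masked positions, as built below)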
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OlmoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, OlmoeForCausalLM + + >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? 
Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fd80c2bcc6b268cd47a7afe321364ce3ed1aa5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_owlvit import * + from .feature_extraction_owlvit import * + from .image_processing_owlvit import * + from .image_processing_owlvit_fast import * + from .modeling_owlvit import * + from .processing_owlvit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/configuration_owlvit.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/configuration_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..d4873ff4a08b3f53cfb4dc8cd63ba01e7a2e96c1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/configuration_owlvit.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OWL-ViT model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional + + +if TYPE_CHECKING: + from ...processing_utils import ProcessorMixin + from ...utils import TensorType + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OwlViTTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTTextModel`]. It is used to instantiate an + OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OwlViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the OWL-ViT text model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`OwlViTTextModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 16): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token in the input sequences. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the input sequences. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the input sequences. + + Example: + + ```python + >>> from transformers import OwlViTTextConfig, OwlViTTextModel + + >>> # Initializing an OwlViTTextModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTTextConfig() + + >>> # Initializing an OwlViTTextModel from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlvit_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=16, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=0, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + + class OwlViTVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTVisionModel`]. It is used to instantiate + an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OWL-ViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information.
+ + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 768): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import OwlViTVisionConfig, OwlViTVisionModel + + >>> # Initializing a OwlViTVisionModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTVisionConfig() + + >>> # Initializing a OwlViTVisionModel model from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlvit_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=768, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +class OwlViTConfig(PretrainedConfig): + r""" + [`OwlViTConfig`] is the configuration class to store the configuration of an [`OwlViTModel`]. It is used to + instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWL-ViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original OWL-ViT + implementation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not the model should return a dictionary. If `False`, returns a tuple. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "owlvit" + sub_configs = {"text_config": OwlViTTextConfig, "vision_config": OwlViTVisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + projection_dim=512, + logit_scale_init_value=2.6592, + return_dict=True, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the OwlViTTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. initializing the OwlViTVisionConfig with default values.") + + self.text_config = OwlViTTextConfig(**text_config) + self.vision_config = OwlViTVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.return_dict = return_dict + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: dict, vision_config: dict, **kwargs): + r""" + Instantiate a [`OwlViTConfig`] (or a derived class) from owlvit text model configuration and owlvit vision + model configuration. 
+ + Returns: + [`OwlViTConfig`]: An instance of a configuration object + """ + config_dict = {} + config_dict["text_config"] = text_config + config_dict["vision_config"] = vision_config + + return cls.from_dict(config_dict, **kwargs) + + +class OwlViTOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("logits_per_image", {0: "batch"}), + ("logits_per_text", {0: "batch"}), + ("text_embeds", {0: "batch"}), + ("image_embeds", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + processor: "ProcessorMixin", + batch_size: int = -1, + seq_length: int = -1, + framework: Optional["TensorType"] = None, + ) -> Mapping[str, Any]: + text_input_dict = super().generate_dummy_inputs( + processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + ) + image_input_dict = super().generate_dummy_inputs( + processor.image_processor, batch_size=batch_size, framework=framework + ) + return {**text_input_dict, **image_input_dict} + + @property + def default_onnx_opset(self) -> int: + return 14 + + +__all__ = ["OwlViTConfig", "OwlViTOnnxConfig", "OwlViTTextConfig", "OwlViTVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/feature_extraction_owlvit.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/feature_extraction_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3a8d0b145cc54e8ccebf7acc84538236fac644 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/feature_extraction_owlvit.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for OwlViT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_owlvit import OwlViTImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class OwlViTFeatureExtractor(OwlViTImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class OwlViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. 
Please" + " use OwlViTImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["OwlViTFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9c6cfdeaa8aa0532539f03e3e8e1e1aa3d9619 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for OwlViT""" + +import warnings +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + center_crop, + center_to_corners_format, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_torch_available, logging +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTObjectDetectionOutput + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +def _upcast(t): + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def _scale_boxes(boxes, target_sizes): + """ + Scale batch of bounding boxes to the target sizes. + + Args: + boxes (`torch.Tensor` of shape `(batch_size, num_boxes, 4)`): + Bounding boxes to scale. Each box is expected to be in (x1, y1, x2, y2) format. + target_sizes (`list[tuple[int, int]]` or `torch.Tensor` of shape `(batch_size, 2)`): + Target sizes to scale the boxes to. Each target size is expected to be in (height, width) format. + + Returns: + `torch.Tensor` of shape `(batch_size, num_boxes, 4)`: Scaled bounding boxes. 
+ """ + + if isinstance(target_sizes, (list, tuple)): + image_height = torch.tensor([i[0] for i in target_sizes]) + image_width = torch.tensor([i[1] for i in target_sizes]) + elif isinstance(target_sizes, torch.Tensor): + image_height, image_width = target_sizes.unbind(1) + else: + raise TypeError("`target_sizes` must be a list, tuple or torch.Tensor") + + scale_factor = torch.stack([image_width, image_height, image_width, image_height], dim=1) + scale_factor = scale_factor.unsqueeze(1).to(boxes.device) + boxes = boxes * scale_factor + return boxes + + +def box_area(boxes): + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +@requires(backends=("vision",)) +class OwlViTImageProcessor(BaseImageProcessor): + r""" + Constructs an OWL-ViT image processor. + + This image processor inherits from [`ImageProcessingMixin`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the shorter edge of the input to a certain `size`. + size (`dict[str, int]`, *optional*, defaults to {"height": 768, "width": 768}): + The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a + sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized + to (size, size). + resample (`int`, *optional*, defaults to `Resampling.BICUBIC`): + An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`, + `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`, + `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set + to `True`. + do_center_crop (`bool`, *optional*, defaults to `False`): + Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the + image is padded with 0's and then center cropped. + crop_size (`int`, *optional*, defaults to {"height": 768, "width": 768}): + The size to use for center cropping the image. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input by a certain factor. + rescale_factor (`float`, *optional*, defaults to `1/255`): + The factor to use for rescaling the image. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with `image_mean` and `image_std`. Desired output size when applying + center-cropping. 
Only has an effect if `do_center_crop` is set to `True`. + image_mean (`list[int]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`list[int]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=None, + resample=PILImageResampling.BICUBIC, + do_center_crop=False, + crop_size=None, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs, + ): + size = size if size is not None else {"height": 768, "width": 768} + size = get_size_dict(size, default_to_square=True) + + crop_size = crop_size if crop_size is not None else {"height": 768, "width": 768} + crop_size = get_size_dict(crop_size, default_to_square=True) + + # Early versions of the OWL-ViT config on the hub had "rescale" as a flag. This clashes with the + # vision image processor method `rescale` as it would be set as an attribute during the super().__init__ + # call. This is for backwards compatibility. + if "rescale" in kwargs: + rescale_val = kwargs.pop("rescale") + kwargs["do_rescale"] = rescale_val + + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to a certain size. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + The size to resize the image to. Must contain height and width keys. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use when resizing the input. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size, default_to_square=True) + if "height" not in size or "width" not in size: + raise ValueError("size dictionary must contain height and width keys") + + return resize( + image, + (size["height"], size["width"]), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def center_crop( + self, + image: np.ndarray, + crop_size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to a certain size. + + Args: + image (`np.ndarray`): + Image to center crop. + crop_size (`dict[str, int]`): + The size to center crop the image to. Must contain height and width keys. 
+ data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + crop_size = get_size_dict(crop_size, default_to_square=True) + if "height" not in crop_size or "width" not in crop_size: + raise ValueError("crop_size dictionary must contain height and width keys") + + return center_crop( + image, + (crop_size["height"], crop_size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Prepares an image or batch of images for the model. + + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Expects a single or batch of images with pixel values + ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether or not to resize the input. If `True`, will resize the input to the size specified by `size`. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + The size to resize the input to. Only has an effect if `do_resize` is set to `True`. 
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + The resampling filter to use when resizing the input. Only has an effect if `do_resize` is set to + `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether or not to center crop the input. If `True`, will center crop the input to the size specified by + `crop_size`. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + The size to center crop the input to. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether or not to rescale the input. If `True`, will rescale the input by dividing it by + `rescale_factor`. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + The factor to rescale the input by. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether or not to normalize the input. If `True`, will normalize the input by subtracting `image_mean` + and dividing by `image_std`. + image_mean (`Union[float, list[float]]`, *optional*, defaults to `self.image_mean`): + The mean to subtract from the input when normalizing. Only has an effect if `do_normalize` is set to + `True`. + image_std (`Union[float, list[float]]`, *optional*, defaults to `self.image_std`): + The standard deviation to divide the input by when normalizing. Only has an effect if `do_normalize` is + set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
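+
+        Example (a minimal sketch using a synthetic image; any PIL, NumPy or torch image works the same way):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import OwlViTImageProcessor
+
+        >>> image_processor = OwlViTImageProcessor()
+        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+        >>> inputs = image_processor.preprocess(image, return_tensors="pt")
+        >>> inputs["pixel_values"].shape  # resized to the default 768x768, channels-first
+        torch.Size([1, 3, 768, 768])
+        ```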
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image, crop_size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_inputs + + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. 
+ """ + # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + FutureWarning, + ) + + logits, boxes = outputs.logits, outputs.pred_boxes + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + def post_process_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, list[tuple]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. 
+ + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. + """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. + scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["OwlViTImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..1e458f964a04cd9bf12712f82f4691be374d497d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/image_processing_owlvit_fast.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for OwlViT""" + +import warnings +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_transforms import center_to_corners_format +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import TensorType, auto_docstring, logging +from .image_processing_owlvit import _scale_boxes, box_iou + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTObjectDetectionOutput + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class OwlViTImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 768, "width": 768} + default_to_square = True + crop_size = {"height": 768, "width": 768} + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = None + do_convert_rgb = None + model_input_names = ["pixel_values"] + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. 
+ """ + # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + FutureWarning, + ) + + logits, boxes = outputs.logits, outputs.pred_boxes + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection + def post_process_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, list[tuple]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. 
+ """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. + """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. 
+ scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["OwlViTImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/modeling_owlvit.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/modeling_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..3971b1376d9c10b945fa17fd9f23fcba5ff65e38 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/modeling_owlvit.py @@ -0,0 +1,1642 @@ +# coding=utf-8 +# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OWL-ViT model.""" + +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Optional, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + auto_docstring, + filter_out_non_signature_kwargs, + is_vision_available, + logging, + torch_int, +) +from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +logger = logging.get_logger(__name__) + + +# See all OwlViT models at https://huggingface.co/models?filter=owlvit + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit +def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +@auto_docstring +class OwlViTOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`OwlViTVisionModel`]. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.loss.loss_for_object_detection._upcast +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.loss.loss_for_object_detection.box_area +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.loss.loss_for_object_detection.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`OwlViTForObjectDetection`]. + """ +) +class OwlViTObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. 
The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + class_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`OwlViTForObjectDetection.image_guided_detection`]. + """ +) +class OwlViTImageGuidedObjectDetectionOutput(ModelOutput): + r""" + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual target image in the batch + (disregarding possible padding). 
You can use [`~OwlViTImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual query image in the batch + (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. + """ + + logits: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + query_image_embeds: Optional[torch.FloatTensor] = None + target_pred_boxes: Optional[torch.FloatTensor] = None + query_pred_boxes: Optional[torch.FloatTensor] = None + class_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class OwlViTVisionEmbeddings(nn.Module): + def __init__(self, config: OwlViTVisionConfig): + super().__init__() + self.patch_size = config.patch_size + self.config = config + self.embed_dim = config.hidden_size + self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + + self.num_patches = (config.image_size // config.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class OwlViTTextEmbeddings(nn.Module): + def __init__(self, config: OwlViTTextConfig): + super().__init__() + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class OwlViTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible 
by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
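+ # (i.e., to keep the returned attention weights connected to the autograd graph)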
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # For int8 compatibility, sometimes the `attn_probs` are in `fp32`
+ attn_probs = attn_probs.to(value_states.dtype)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
+class OwlViTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT
+class OwlViTEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = OwlViTAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = OwlViTMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
+ `(batch, 1, tgt_len, src_len)` where masked positions are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
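+ Returns:
+ `tuple(torch.FloatTensor)`: the transformed `hidden_states`, plus the attention weights when
+ `output_attentions=True`.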
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +@auto_docstring +class OwlViTPreTrainedModel(PreTrainedModel): + config: OwlViTConfig + base_model_prefix = "owlvit" + supports_gradient_checkpointing = True + _no_split_modules = ["OwlViTEncoderLayer"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, OwlViTTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, OwlViTVisionEmbeddings): + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, OwlViTAttention): + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, OwlViTMLP): + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, OwlViTModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * factor, + ) + module.logit_scale.data.fill_(self.config.logit_scale_init_value) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=factor) + if module.bias is not None: + module.bias.data.zero_() + + +class OwlViTEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`OwlViTEncoderLayer`]. 
+
+ Args:
+ config: OwlViTConfig
+ """
+
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
+ Causal mask for the text model, with large negative values at positions that may not be attended to.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class OwlViTTextTransformer(nn.Module):
+ def __init__(self, config: OwlViTTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = OwlViTTextEmbeddings(config)
+ self.encoder = OwlViTEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) ->
Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
+ # OWLVIT's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # take features from the end-of-text (EOT) token embedding (the EOT token has the highest token id in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ pooled_output = last_hidden_state[
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
+ ]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class OwlViTTextModel(OwlViTPreTrainedModel):
+ config: OwlViTTextConfig
+
+ def __init__(self, config: OwlViTTextConfig):
+ super().__init__(config)
+ self.text_model = OwlViTTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`].
See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, OwlViTTextModel
+
+ >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> inputs = processor(
+ ... text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
+ ... )
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ # Get embeddings for all text queries in all batch samples
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class OwlViTVisionTransformer(nn.Module):
+ def __init__(self, config: OwlViTVisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = OwlViTVisionEmbeddings(config)
+ self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.encoder = OwlViTEncoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Cast the input to the expected `dtype`
+ expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
+ pixel_values = pixel_values.to(expected_input_dtype)
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.pre_layernorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class OwlViTVisionModel(OwlViTPreTrainedModel):
+ config: OwlViTVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: OwlViTVisionConfig):
+ super().__init__(config)
+ self.vision_model = OwlViTVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool =
False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, OwlViTVisionModel
+
+ >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+
+@auto_docstring
+class OwlViTModel(OwlViTPreTrainedModel):
+ config: OwlViTConfig
+
+ def __init__(self, config: OwlViTConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, OwlViTTextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type OwlViTTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, OwlViTVisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type OwlViTVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = OwlViTTextTransformer(text_config)
+ self.vision_model = OwlViTVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+ self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`OwlViTTextModel`].
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> inputs = processor(
+ ... text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
+ ... )
+ >>> with torch.inference_mode():
+ ...
text_features = model.get_text_features(**inputs)
+ ```"""
+ # Get embeddings for all text queries in all batch samples
+ text_outputs: BaseModelOutputWithPooling = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
+ text_features = self.text_projection(text_outputs.pooler_output)
+
+ return text_features
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.Tensor,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`OwlViTVisionModel`].
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> from transformers.image_utils import load_image
+ >>> from transformers import AutoProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> with torch.inference_mode():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_features = self.visual_projection(vision_outputs.pooler_output)
+
+ return image_features
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_base_image_embeds: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, OwlViTOutput]:
+ r"""
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ return_base_image_embeds (`bool`, *optional*):
+ Whether or not to return the base image embeddings.
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
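+ # (when `return_loss=True`, the symmetric CLIP-style contrastive loss is computed
+ # further below from `logits_per_text` via `owlvit_loss`)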
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + # Get embeddings for all text queries in all batch samples + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + # normalized features + image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True) + text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True) + + # cosine similarity as logits and set it on the correct device + logit_scale = self.logit_scale.exp().to(image_embeds.device) + + logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = owlvit_loss(logits_per_text) + + text_embeds = text_embeds_norm + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return OwlViTOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class OwlViTBoxPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig, out_dim: int = 4): + super().__init__() + + width = config.vision_config.hidden_size + self.dense0 = nn.Linear(width, width) + self.dense1 = nn.Linear(width, width) + self.gelu = nn.GELU() + self.dense2 = nn.Linear(width, out_dim) + + def forward(self, image_features: torch.Tensor) -> torch.FloatTensor: + output = self.dense0(image_features) + output = self.gelu(output) + output = self.dense1(output) + output = self.gelu(output) + output = self.dense2(output) + return output + + +class OwlViTClassPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + + out_dim = config.text_config.hidden_size + self.query_dim = config.vision_config.hidden_size + + self.dense0 = nn.Linear(self.query_dim, out_dim) + self.logit_shift = nn.Linear(self.query_dim, 1) + self.logit_scale = nn.Linear(self.query_dim, 1) + self.elu = nn.ELU() + + def forward( + self, + image_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor], + query_mask: Optional[torch.Tensor], + ) -> tuple[torch.FloatTensor]: + image_class_embeds = self.dense0(image_embeds) + if query_embeds is None: + device = image_class_embeds.device + batch_size, num_patches = image_class_embeds.shape[:2] + pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device) + return (pred_logits, image_class_embeds) + + # Normalize image and text features + image_class_embeds = image_class_embeds / 
(torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6) + query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) + + # Get class predictions + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + + # Apply a learnable shift and scale to logits + logit_shift = self.logit_shift(image_embeds) + logit_scale = self.logit_scale(image_embeds) + logit_scale = self.elu(logit_scale) + 1 + pred_logits = (pred_logits + logit_shift) * logit_scale + + if query_mask is not None: + if query_mask.ndim > 1: + query_mask = torch.unsqueeze(query_mask, dim=-2) + + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) + pred_logits = pred_logits.to(torch.float32) + + return (pred_logits, image_class_embeds) + + +class OwlViTForObjectDetection(OwlViTPreTrainedModel): + config: OwlViTConfig + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + self.owlvit = OwlViTModel(config) + self.class_head = OwlViTClassPredictionHead(config) + self.box_head = OwlViTBoxPredictionHead(config) + + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) + self.sigmoid = nn.Sigmoid() + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) + + @staticmethod + def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: + # Create grid coordinates using torch + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) + xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") + + # Stack the coordinates and divide by their respective patch counts + box_coordinates = torch.stack((xx, yy), dim=-1) + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height + + # Flatten (h, w, 2) -> (h*w, 2) + box_coordinates = box_coordinates.view(-1, 2) + + return box_coordinates + + @lru_cache(maxsize=2) + def compute_box_bias( + self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: + if feature_map is not None: + raise ValueError("feature_map has been deprecated as an input. 
Please pass in num_patches instead") + # The box center is biased to its position on the feature grid + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) + box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) + + # Unnormalize xy + box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) + + # The box size is biased to the patch size + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height + box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) + + # Compute box bias + box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) + return box_bias + + def box_predictor( + self, + image_feats: torch.FloatTensor, + feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + """ + Args: + image_feats: + Features extracted from the image, returned by the `image_text_embedder` method. + feature_map: + A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. + Returns: + pred_boxes: + List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. + """ + # Bounding box detection head [batch_size, num_boxes, 4]. + pred_boxes = self.box_head(image_feats) + + # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) + pred_boxes += box_bias + pred_boxes = self.sigmoid(pred_boxes) + return pred_boxes + + def class_predictor( + self, + image_feats: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, + query_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor]: + """ + Args: + image_feats: + Features extracted from the `image_text_embedder`. + query_embeds: + Text query embeddings. + query_mask: + Must be provided with query_embeddings. A mask indicating which query embeddings are valid. 
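+ Returns:
+ pred_logits:
+ Class logits of shape [batch_size, num_patches, num_queries], one score per image patch and query.
+ image_class_embeds:
+ The per-patch class embeddings the logits are computed from.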
+ """ + (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask) + + return (pred_logits, image_class_embeds) + + def image_text_embedder( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> tuple[torch.FloatTensor]: + # Encode text and image + outputs = self.owlvit( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Get image embeddings + last_hidden_state = outputs.vision_model_output[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + text_embeds = outputs[-4] + + return (text_embeds, image_embeds, outputs) + + def image_embedder( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> tuple[torch.FloatTensor]: + # Get OwlViTModel vision embeddings (same as CLIP) + vision_outputs = self.owlvit.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Apply post_layernorm to last_hidden_state, return non-projected output + last_hidden_state = vision_outputs[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + + return (image_embeds, vision_outputs) + + def embed_image_query( + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + _, class_embeds = self.class_predictor(query_image_features) + pred_boxes = 
self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) + pred_boxes_as_corners = center_to_corners_format(pred_boxes) + + # Loop over query images + best_class_embeds = [] + best_box_indices = [] + pred_boxes_device = pred_boxes_as_corners.device + + for i in range(query_image_features.shape[0]): + each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device) + each_query_pred_boxes = pred_boxes_as_corners[i] + ious, _ = box_iou(each_query_box, each_query_pred_boxes) + + # If there are no overlapping boxes, fall back to generalized IoU + if torch.all(ious[0] == 0.0): + ious = generalized_box_iou(each_query_box, each_query_pred_boxes) + + # Use an adaptive threshold to include all boxes within 80% of the best IoU + iou_threshold = torch.max(ious) * 0.8 + + selected_inds = (ious[0] >= iou_threshold).nonzero() + if selected_inds.numel(): + selected_embeddings = class_embeds[i][selected_inds.squeeze(1)] + mean_embeds = torch.mean(class_embeds[i], axis=0) + mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings) + best_box_ind = selected_inds[torch.argmin(mean_sim)] + best_class_embeds.append(class_embeds[i][best_box_ind]) + best_box_indices.append(best_box_ind) + + if best_class_embeds: + query_embeds = torch.stack(best_class_embeds) + box_indices = torch.stack(best_box_indices) + else: + query_embeds, box_indices = None, None + + return query_embeds, box_indices, pred_boxes + + @auto_docstring + def image_guided_detection( + self, + pixel_values: torch.FloatTensor, + query_pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> OwlViTImageGuidedObjectDetectionOutput: + r""" + query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values of query image(s) to be detected. Pass in one query image per target image. + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + >>> from transformers import AutoProcessor, OwlViTForObjectDetection + + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg" + >>> query_image = Image.open(requests.get(query_url, stream=True).raw) + >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model.image_guided_detection(**inputs) + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.Tensor([image.size[::-1]]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_image_guided_detection( + ... outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes + ... ) + >>> i = 0 # Retrieve predictions for the first image + >>> boxes, scores = results[i]["boxes"], results[i]["scores"] + >>> for box, score in zip(boxes, scores): + ... box = [round(i, 2) for i in box.tolist()] + ... 
print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}") + Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39] + Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Compute feature maps for the input and query images + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] + feature_map, vision_outputs = self.image_embedder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) + # Get top class embedding and best box index for each query image in batch + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) + + # Predict object boxes + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + if not return_dict: + output = ( + feature_map, + query_feature_map, + target_pred_boxes, + query_pred_boxes, + pred_logits, + class_embeds, + vision_outputs.to_tuple(), + ) + output = tuple(x for x in output if x is not None) + return output + + return OwlViTImageGuidedObjectDetectionOutput( + image_embeds=feature_map, + query_image_embeds=query_feature_map, + target_pred_boxes=target_pred_boxes, + query_pred_boxes=query_pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=None, + vision_model_output=vision_outputs, + ) + + @auto_docstring + def forward( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> OwlViTObjectDetectionOutput: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids). + output_hidden_states (`bool`, *optional*): + Whether or not to return the last hidden state. See `text_model_last_hidden_state` and + `vision_model_last_hidden_state` under returned tensors for more detail. 
+ + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + + >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection + + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text_labels = [["a photo of a cat", "a photo of a dog"]] + >>> inputs = processor(text=text_labels, images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.tensor([(image.height, image.width)]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_grounded_object_detection( + ... outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels + ... ) + >>> # Retrieve predictions for the first image for the corresponding text queries + >>> result = results[0] + >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"] + >>> for box, score, text_label in zip(boxes, scores, text_labels): + ... box = [round(i, 2) for i in box.tolist()] + ... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] + Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Embed images and text queries + query_embeds, feature_map, outputs = self.image_text_embedder( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + # Text and vision model outputs + text_outputs = outputs.text_model_output + vision_outputs = outputs.vision_model_output + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] + max_text_queries = input_ids.shape[0] // batch_size + query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1]) + + # If first token is 0, then this is a padded query [batch_size, num_queries]. 
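+ # (rows whose first token id is 0 correspond to padded queries; `query_mask` keeps
+ # only the real text queries when computing class logits)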
+ input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1]) + query_mask = input_ids[..., 0] > 0 + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask) + + # Predict object boxes + pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + if not return_dict: + output = ( + pred_logits, + pred_boxes, + query_embeds, + feature_map, + class_embeds, + text_outputs.to_tuple(), + vision_outputs.to_tuple(), + ) + output = tuple(x for x in output if x is not None) + return output + + return OwlViTObjectDetectionOutput( + image_embeds=feature_map, + text_embeds=query_embeds, + pred_boxes=pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["OwlViTModel", "OwlViTPreTrainedModel", "OwlViTTextModel", "OwlViTVisionModel", "OwlViTForObjectDetection"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/owlvit/processing_owlvit.py b/venv/lib/python3.13/site-packages/transformers/models/owlvit/processing_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..ae39d2b6b30783fb7281b56f874f8a840b129732 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/owlvit/processing_owlvit.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for OWL-ViT +""" + +import warnings +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTObjectDetectionOutput + + +class OwlViTImagesKwargs(ImagesKwargs, total=False): + query_images: Optional[ImageInput] + + +class OwlViTProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: OwlViTImagesKwargs + _defaults = { + "text_kwargs": { + "padding": "max_length", + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "np", + }, + } + + +class OwlViTProcessor(ProcessorMixin): + r""" + Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] + into a single processor that inherits both the image processor and tokenizer functionalities. See the + [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information. + + Args: + image_processor ([`OwlViTImageProcessor`], *optional*): + The image processor is a required input. 
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "OwlViTImageProcessor"
+ tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+ feature_extractor = None
+ if "feature_extractor" in kwargs:
+ warnings.warn(
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+ " instead.",
+ FutureWarning,
+ )
+ feature_extractor = kwargs.pop("feature_extractor")
+
+ image_processor = image_processor if image_processor is not None else feature_extractor
+
+ super().__init__(image_processor, tokenizer)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[OwlViTProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several text(s) and image(s) for the model. This method forwards the `text` and
+ `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`,
+ `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The query image to be prepared, one query image is expected per target image to be queried. Each image
+ can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
+ should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`.
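+
+ Example (a minimal usage sketch; the checkpoint name and image URL are illustrative):
+
+ ```python
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import OwlViTProcessor
+
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
+ ```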
+ """ + output_kwargs = self._merge_kwargs( + OwlViTProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + query_images = output_kwargs["images_kwargs"].pop("query_images", None) + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] + + if text is None and query_images is None and images is None: + raise ValueError( + "You have to specify at least one text or query image or image. All three cannot be none." + ) + + data = {} + if text is not None: + if isinstance(text, str) or (isinstance(text, list) and not isinstance(text[0], list)): + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] + + elif isinstance(text, list) and isinstance(text[0], list): + encodings = [] + + # Maximum number of queries across batch + max_num_queries = max(len(text_single) for text_single in text) + + # Pad all batch samples to max number of text queries + for text_single in text: + if len(text_single) != max_num_queries: + text_single = text_single + [" "] * (max_num_queries - len(text_single)) + + encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"]) + encodings.append(encoding) + else: + raise TypeError("Input text should be a string, a list of strings or a nested list of strings") + + if return_tensors == "np": + input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "jax" and is_flax_available(): + import jax.numpy as jnp + + input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "pt" and is_torch_available(): + import torch + + input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) + attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) + + elif return_tensors == "tf" and is_tf_available(): + import tensorflow as tf + + input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) + + else: + raise ValueError("Target return tensor type could not be returned") + + data["input_ids"] = input_ids + data["attention_mask"] = attention_mask + + if query_images is not None: + query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values + # Query images always override the text prompt + data = {"query_pixel_values": query_pixel_values} + + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data["pixel_values"] = image_features.pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTImageProcessor.post_process`]. Please refer to the docstring + of this method for more information. + """ + return self.image_processor.post_process(*args, **kwargs) + + def post_process_object_detection(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer + to the docstring of this method for more information. + """ + warnings.warn( + "`post_process_object_detection` method is deprecated for OwlVitProcessor and will be removed in v5. 
" + "Use `post_process_grounded_object_detection` instead.", + FutureWarning, + ) + return self.image_processor.post_process_object_detection(*args, **kwargs) + + def post_process_grounded_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, list[tuple]]] = None, + text_labels: Optional[list[list[str]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + text_labels (`list[list[str]]`, *optional*): + List of lists of text labels for each image in the batch. If unset, "text_labels" in output will be + set to `None`. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + - "text_labels": The text labels for each predicted bounding box on the image. + """ + output = self.image_processor.post_process_object_detection( + outputs=outputs, threshold=threshold, target_sizes=target_sizes + ) + + if text_labels is not None and len(text_labels) != len(output): + raise ValueError("Make sure that you pass in as many lists of text labels as images") + + # adding text labels to the output + if text_labels is not None: + for image_output, image_text_labels in zip(output, text_labels): + object_text_labels = [image_text_labels[i] for i in image_output["labels"]] + image_output["text_labels"] = object_text_labels + else: + for image_output in output: + image_output["text_labels"] = None + + return output + + def post_process_image_guided_detection( + self, + outputs: "OwlViTImageGuidedObjectDetectionOutput", + threshold: float = 0.0, + nms_threshold: float = 0.3, + target_sizes: Optional[Union[TensorType, list[tuple]]] = None, + ): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. 
+ - "labels": Set to `None`. + """ + return self.image_processor.post_process_image_guided_detection( + outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes + ) + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["OwlViTProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/paligemma/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/paligemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9048afe6adbdc0ad36007e02f60e899cae677c55 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/paligemma/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_paligemma import * + from .modeling_paligemma import * + from .processing_paligemma import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/paligemma/configuration_paligemma.py b/venv/lib/python3.13/site-packages/transformers/models/paligemma/configuration_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ee4b3b45c2637d804f7eafa7f9cc7afa2638f8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/paligemma/configuration_paligemma.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PaliGemmamodel configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class PaliGemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate an + PaliGemmamodel according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`PaliGemmaVisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + image_token_index (`int`, *optional*, defaults to 256000): + The image token index to encode the image prompt. + vocab_size (`int`, *optional*, defaults to 257152): + Vocabulary size of the PaliGemmamodel. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`] + projection_dim (`int`, *optional*, defaults to 2048): + Dimension of the multimodal projection space. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden layer of the Language model. + + Example: + + ```python + >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a PaliGemma config + >>> text_config = GemmaConfig() + + >>> # Initializing a PaliGemma paligemma-3b-224 style configuration + >>> configuration = PaliGemmaConfig(vision_config, text_config) + + >>> # Initializing a model from the paligemma-3b-224 style configuration + >>> model = PaliGemmaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "paligemma" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=256000, + vocab_size=257152, + projection_dim=2048, + hidden_size=2048, + **kwargs, + ): + self.image_token_index = image_token_index + self.projection_dim = projection_dim + self.hidden_size = hidden_size + self.vision_config = vision_config + self.is_encoder_decoder = False + + if isinstance(self.vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model") + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["siglip_vision_model"]( + intermediate_size=4096, + hidden_size=1152, + patch_size=14, + image_size=224, + num_hidden_layers=27, + num_attention_heads=16, + vocab_size=257152, + vision_use_head=False, + ) + + self.text_config = text_config + if isinstance(self.text_config, dict): + text_config["model_type"] = 
text_config.get("model_type", "gemma") + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + self.text_config = CONFIG_MAPPING["gemma"]( + hidden_size=2048, + num_hidden_layers=18, + intermediate_size=16384, + num_attention_heads=8, + num_key_value_heads=1, + is_encoder_decoder=False, + vocab_size=vocab_size, + ) + self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2 + self.vision_config.projection_dim = projection_dim + super().__init__(**kwargs) + + +__all__ = ["PaliGemmaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/paligemma/modeling_paligemma.py b/venv/lib/python3.13/site-packages/transformers/models/paligemma/modeling_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..abd8595e24abb7d3850a19144ebe67707b85f7ae --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/paligemma/modeling_paligemma.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PaliGemmamodel.""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from ..auto import AutoModel +from .configuration_paligemma import PaliGemmaConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Paligemma outputs, with hidden states and attentions. + """ +) +class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PaliGemma causal language model (or autoregressive) outputs. + """ +) +class PaliGemmaCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, config: PaliGemmaConfig): + super().__init__() + self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True) + + def forward(self, image_features): + hidden_states = self.linear(image_features) + + return hidden_states + + +@auto_docstring +class PaliGemmaPreTrainedModel(PreTrainedModel): + config: PaliGemmaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + + _can_compile_fullgraph = False + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of PaliGemmaisn't meant for training from scratch - only + # inference and fine-tuning + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., + """ +) +class PaliGemmaModel(PaliGemmaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch + accepts_loss_kwargs = False + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.text_config_dtype = self.config.get_text_config().dtype or self.dtype + self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def 
set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def _update_causal_mask( + self, + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + is_training: Optional[bool] = None, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + is_training = is_training if is_training is not None else self.training + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.text_config_dtype).min + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=self.text_config_dtype, + device=cache_position.device, + ) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
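To make the training-time branch of `_update_causal_mask` above concrete: positions whose `token_type_ids` are 0 (image plus prompt prefix) may attend to each other bidirectionally, while suffix positions stay causal. The following is a standalone toy re-derivation of that rule, not a call into the method itself:

```python
# Toy reconstruction of the prefix-LM mask built above: prefix keys
# (token_type_ids == 0) are visible to every query, suffix keys stay causal.
import torch

seq_len = 6
min_dtype = torch.finfo(torch.float32).min
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1]])  # 4 prefix tokens, 2 suffix tokens

mask = torch.full((seq_len, seq_len), min_dtype)
mask = torch.triu(mask, diagonal=1)          # fully causal to start with
mask = mask[None, None, :, :].clone()        # (batch, 1, query_len, key_len)
mask.masked_fill_(token_type_ids[:, None, None, :] == 0, 0.0)  # unmask prefix keys

print((mask[0, 0] == 0).int())
# Every row sees keys 0-3 (the prefix); keys 4-5 are only visible causally.
```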
+ """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features / (self.config.text_config.hidden_size**0.5) + return image_features + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, PaligemmaModelOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return PaligemmaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., + """ +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PaliGemmaConfig): + 
super().__init__(config) + self.model = PaliGemmaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.model.get_image_features(pixel_values) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + is_static_hybrid_cache = isinstance(past_key_values, StaticCache) and any(past_key_values.is_sliding) + if cache_position[0] == 0 and is_static_hybrid_cache: + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self.model._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/paligemma/processing_paligemma.py b/venv/lib/python3.13/site-packages/transformers/models/paligemma/processing_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..242627a0eb71a98f71ca0aadf260cc51df3fb085 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/paligemma/processing_paligemma.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for PaliGemma. 
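Assuming the static helper above is callable exactly as defined (it is a plain `@staticmethod`, so no model instance is needed), a quick sanity check of how a 2D padding mask is expanded to the 4D additive form:

```python
# Sanity check: a 2D padding mask becomes a (batch, 1, query_len, key_len)
# additive mask with future and padded positions set to the dtype minimum.
import torch
from transformers import PaliGemmaForConditionalGeneration

mask_2d = torch.tensor([[1, 1, 1, 0]])  # last position is padding
mask_4d = PaliGemmaForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=mask_2d,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask_4d.shape)               # torch.Size([1, 1, 4, 4])
print((mask_4d == 0).int()[0, 0])  # lower triangle, minus the padded key column
```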
+""" + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import ( + ImagesKwargs, + MultiModalData, + ProcessingKwargs, + ProcessorMixin, + TextKwargs, + Unpack, +) +from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + +IMAGE_TOKEN = "" +EXTRA_TOKENS = [f"4}>" for i in range(1024)] + [f"3}>" for i in range(128)] + + +class PaliGemmaTextKwargs(TextKwargs): + suffix: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] + + +class PaliGemmaImagesKwargs(ImagesKwargs): + do_convert_rgb: Optional[bool] + + +class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: PaliGemmaTextKwargs + images_kwargs: PaliGemmaImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": { + "data_format": "channels_first", + }, + } + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +def _is_str_or_image(elem): + return isinstance(elem, (str)) or is_image_or_image_url(elem) + + +def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images): + """ + Builds a string from the input prompt and image tokens. + For example, for the call: + build_string_from_input( + prompt="Prefix str" + bos_token="", + image_seq_len=3, + image_token="", + ) + The output will be: + "Initial str" + Args: + prompt (`list[Union[str, ImageInput]]`): The input prompt. + bos_token (`str`): The beginning of sentence token. + image_seq_len (`int`): The length of the image sequence. + image_token (`str`): The image token. + num_images (`int`): Number of images in the prompt. + """ + return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" + + +class PaliGemmaProcessor(ProcessorMixin): + r""" + Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. + + [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`GemmaTokenizerFast`]. See the + [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`GemmaTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") + tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + **kwargs, + ): + if not hasattr(image_processor, "image_seq_length"): + raise ValueError("Image processor is missing an `image_seq_length` attribute.") + + self.image_seq_length = image_processor.image_seq_length + + if not hasattr(tokenizer, "image_token"): + image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + self.image_token = IMAGE_TOKEN + else: + self.image_token_id = tokenizer.image_token_id + self.image_token = tokenizer.image_token + + tokenizer.add_tokens(EXTRA_TOKENS) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[PaliGemmaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + The usage for PaliGemma fine-tuning preparation is slightly different than usual. suffix passed are suffixes to + the prompt in `text`, and will be placed after the prompt. This is because attention is handled differently for + the prefix and the suffix. For instance, + ```python + image = PIL_cow_image + prompt = "answer en Where is the cow standing?" + suffix = "on the beach" + inputs = processor(text=prompt, images=image, suffix=suffix) + ``` + Here `inputs` will contain the `input_ids` and `token_type_ids` that follow + ```python + inputs["input_ids"][:, 256:] + # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]]) + inputs["token_type_ids"][:, 256:] + tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]) + ``` + Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type. + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+            suffix (`str`, `list[str]`, `list[list[str]]`):
+                The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
+                for more information. If your prompt is "<image> What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench".
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
+              is provided, the `input_ids` will also contain the suffix input ids.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+            - **labels** -- Labels compatible with training if `suffix` is not None
+        """
+
+        output_kwargs = self._merge_kwargs(
+            PaliGemmaProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        suffix = output_kwargs["text_kwargs"].pop("suffix", None)
+
+        return_token_type_ids = suffix is not None
+
+        if images is None:
+            raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
+        if text is None:
+            logger.warning_once(
+                "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
+            )
+            text = ""
+
+        if _is_str_or_image(text):
+            text = [text]
+        elif isinstance(text, list) and _is_str_or_image(text[0]):
+            pass
+
+        if text is not None and images is not None:
+            if not any(IMAGE_TOKEN in sample for sample in text):
+                logger.warning(
+                    "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
+                    "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
+                    "add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images "
+                    "each text has and add special tokens."
+                )
+
+                if isinstance(text, list) and isinstance(images, list):
+                    if len(images) != len(text):
+                        raise ValueError(
+                            f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
+ ) + + # make a nested list of lists to be able to iterate over the images and text below + if is_valid_image(images): + images = [[images]] + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + images = [[image] for image in images] + elif not ( + isinstance(images, (list, tuple)) + and isinstance(images[0], (list, tuple)) + and is_valid_image(images[0][0]) + ): + raise ValueError("images must be an image, list of images or list of list of images") + + input_strings = [ + build_string_from_input( + prompt=prompt, + bos_token=self.tokenizer.bos_token, + image_seq_len=self.image_seq_length, + image_token=IMAGE_TOKEN, + num_images=len(image_list) if isinstance(image_list, list) else 1, + ) + for prompt, image_list in zip(text, images) + ] + else: + expanded_samples = [] + for sample in text: + expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) + bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN) + bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0 + expanded_sample = ( + expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:] + ) + expanded_samples.append(expanded_sample) + input_strings = [f"{sample}\n" for sample in expanded_samples] + + if suffix is not None and _is_str_or_image(suffix): + suffix = [suffix] + if suffix is not None: + suffix = [sfx + self.tokenizer.eos_token for sfx in suffix] + pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + inputs = self.tokenizer( + input_strings, + text_pair=suffix, + return_token_type_ids=return_token_type_ids, + **output_kwargs["text_kwargs"], + ) + self._check_special_mm_tokens(input_strings, inputs, modalities=["image"]) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = np.array(inputs["input_ids"]) + labels[np.array(inputs["token_type_ids"]) == 0] = -100 + return_data.update({"labels": labels}) + + if return_mm_token_type_ids: + array_ids = np.array(return_data["input_ids"]) + mm_token_type_ids = np.zeros_like(return_data["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + return_data["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data=return_data, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (list[list[str]], *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
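Because the suffix-to-labels path in `__call__` above is the part most often gotten wrong when fine-tuning, here is a compact sketch of the expected round trip. The checkpoint id `google/paligemma-3b-pt-224` and the image URL are assumptions of this example:

```python
# Sketch: fine-tuning inputs for PaliGemma. Prefix positions (token_type_ids
# == 0) get label -100 and are ignored by the loss; suffix tokens keep real ids.
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text="answer en Where is the cat standing?",
    images=image,
    suffix="snow",
    return_tensors="pt",
)
print(inputs["labels"][0, :5])           # tensor([-100, -100, -100, -100, -100])
print((inputs["labels"] != -100).sum())  # number of suffix (+ EOS) tokens
```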
+ """ + vision_data = {} + if image_sizes is not None: + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + return MultiModalData(**vision_data) + + +__all__ = ["PaliGemmaProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/patchtst/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/patchtst/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a045bfb1820424f77c9b798a3c027569eed3eb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/patchtst/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_patchtst import * + from .modeling_patchtst import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/patchtst/configuration_patchtst.py b/venv/lib/python3.13/site-packages/transformers/models/patchtst/configuration_patchtst.py new file mode 100644 index 0000000000000000000000000000000000000000..17cabe491c022be1047eb2189cb54bcdfec78848 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/patchtst/configuration_patchtst.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PatchTST model configuration""" + +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class PatchTSTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`PatchTSTModel`]. It is used to instantiate an + PatchTST model according to the specified arguments, defining the model architecture. + [ibm/patchtst](https://huggingface.co/ibm/patchtst) architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + num_input_channels (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. + context_length (`int`, *optional*, defaults to 32): + The context length of the input sequence. + distribution_output (`str`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model when loss is "nll". Could be either "student_t", "normal" or + "negative_binomial". + loss (`str`, *optional*, defaults to `"mse"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared + error "mse". + patch_length (`int`, *optional*, defaults to 1): + Define the patch length of the patchification process. + patch_stride (`int`, *optional*, defaults to 1): + Define the stride of the patchification process. + num_hidden_layers (`int`, *optional*, defaults to 3): + Number of hidden layers. + d_model (`int`, *optional*, defaults to 128): + Dimensionality of the transformer layers. + num_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + share_embedding (`bool`, *optional*, defaults to `True`): + Sharing the input embedding across all channels. + channel_attention (`bool`, *optional*, defaults to `False`): + Activate channel attention block in the Transformer to allow channels to attend each other. + ffn_dim (`int`, *optional*, defaults to 512): + Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + norm_type (`str` , *optional*, defaults to `"batchnorm"`): + Normalization at each Transformer layer. Can be `"batchnorm"` or `"layernorm"`. + norm_eps (`float`, *optional*, defaults to 1e-05): + A value added to the denominator for numerical stability of normalization. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention probabilities. + positional_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability in the positional embedding layer. + path_dropout (`float`, *optional*, defaults to 0.0): + The dropout path in the residual block. + ff_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability used between the two layers of the feed-forward networks. + bias (`bool`, *optional*, defaults to `True`): + Whether to add bias in the feed-forward networks. + activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (string) in the Transformer.`"gelu"` and `"relu"` are supported. + pre_norm (`bool`, *optional*, defaults to `True`): + Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is + applied after residual block. + positional_encoding_type (`str`, *optional*, defaults to `"sincos"`): + Positional encodings. Options `"random"` and `"sincos"` are supported. + use_cls_token (`bool`, *optional*, defaults to `False`): + Whether cls token is used. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal weight initialization distribution. + share_projection (`bool`, *optional*, defaults to `True`): + Sharing the projection layer across different channels in the forecast head. 
+ scaling (`Union`, *optional*, defaults to `"std"`): + Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the + scaler is set to "mean". + do_mask_input (`bool`, *optional*): + Whether to apply masking during pretraining. + mask_type (`str`, *optional*, defaults to `"random"`): + Masking type. Only `"random"` and `"forecast"` are currently supported. + random_mask_ratio (`float`, *optional*, defaults to 0.5): + Masking ratio applied to mask the input data during random pretraining. + num_forecast_mask_patches (`int` or `list`, *optional*, defaults to `[2]`): + Number of patches to be masked at the end of each batch sample. If it is an integer, + all the samples in the batch will have the same number of masked patches. If it is a list, + samples in the batch will be randomly masked by numbers defined in the list. This argument is only used + for forecast pretraining. + channel_consistent_masking (`bool`, *optional*, defaults to `False`): + If channel consistent masking is `True`, all the channels will have the same masking pattern. + unmasked_channel_indices (`list`, *optional*): + Indices of channels that are not masked during pretraining. Values in the list are numbers between 1 and + `num_input_channels`. + mask_value (`int`, *optional*, defaults to 0): + Values in the masked patches will be filled by `mask_value`. + pooling_type (`str`, *optional*, defaults to `"mean"`): + Pooling of the embedding. `"mean"`, `"max"` and `None` are supported. + head_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the head. + prediction_length (`int`, *optional*, defaults to 24): + The prediction horizon that the model will output. + num_targets (`int`, *optional*, defaults to 1): + Number of targets for regression and classification tasks. For classification, it is the number of + classes. + output_range (`list`, *optional*): + Output range for the regression task. The range of output values can be set to enforce the model to + produce values within a range. + num_parallel_samples (`int`, *optional*, defaults to 100): + The number of samples generated in parallel for probabilistic prediction.
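The patching arguments above jointly determine how many patches the encoder sees: further down in `modeling_patchtst.py`, `PatchTSTPatchify` computes `num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1` and drops any leftover steps at the start of the window. A quick worked check of that arithmetic; the concrete sizes are illustrative, not defaults or recommendations:

```python
# Worked example of the patch-count formula used by PatchTSTPatchify below;
# the concrete sizes are illustrative only.
context_length, patch_length, patch_stride = 512, 12, 12

num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
assert num_patches == 42  # (512 - 12) // 12 + 1

# Patches cover the trailing `covered` steps; the first 8 steps are dropped
# (this offset is `sequence_start` in PatchTSTPatchify).
covered = patch_length + patch_stride * (num_patches - 1)
assert covered == 504 and context_length - covered == 8
```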
+ + + ```python + >>> from transformers import PatchTSTConfig, PatchTSTModel + + >>> # Initializing an PatchTST configuration with 12 time steps for prediction + >>> configuration = PatchTSTConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = PatchTSTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "patchtst" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_attention_heads", + "num_hidden_layers": "num_hidden_layers", + } + + def __init__( + self, + # time series specific configuration + num_input_channels: int = 1, + context_length: int = 32, + distribution_output: str = "student_t", + loss: str = "mse", + # PatchTST arguments + patch_length: int = 1, + patch_stride: int = 1, + # Transformer architecture configuration + num_hidden_layers: int = 3, + d_model: int = 128, + num_attention_heads: int = 4, + share_embedding: bool = True, + channel_attention: bool = False, + ffn_dim: int = 512, + norm_type: str = "batchnorm", + norm_eps: float = 1e-05, + attention_dropout: float = 0.0, + positional_dropout: float = 0.0, + path_dropout: float = 0.0, + ff_dropout: float = 0.0, + bias: bool = True, + activation_function: str = "gelu", + pre_norm: bool = True, + positional_encoding_type: str = "sincos", + use_cls_token: bool = False, + init_std: float = 0.02, + share_projection: bool = True, + scaling: Optional[Union[str, bool]] = "std", + # mask pretraining + do_mask_input: Optional[bool] = None, + mask_type: str = "random", + random_mask_ratio: float = 0.5, + num_forecast_mask_patches: Optional[Union[list[int], int]] = [2], + channel_consistent_masking: Optional[bool] = False, + unmasked_channel_indices: Optional[list[int]] = None, + mask_value: int = 0, + # head + pooling_type: str = "mean", + head_dropout: float = 0.0, + prediction_length: int = 24, + num_targets: int = 1, + output_range: Optional[list] = None, + # distribution head + num_parallel_samples: int = 100, + **kwargs, + ): + # time series specific configuration + self.context_length = context_length + self.num_input_channels = num_input_channels # n_vars + self.loss = loss + self.distribution_output = distribution_output + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.d_model = d_model + self.num_attention_heads = num_attention_heads + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.attention_dropout = attention_dropout + self.share_embedding = share_embedding + self.channel_attention = channel_attention + self.norm_type = norm_type + self.norm_eps = norm_eps + self.positional_dropout = positional_dropout + self.path_dropout = path_dropout + self.ff_dropout = ff_dropout + self.bias = bias + self.activation_function = activation_function + self.pre_norm = pre_norm + self.positional_encoding_type = positional_encoding_type + self.use_cls_token = use_cls_token + self.init_std = init_std + self.scaling = scaling + + # PatchTST parameters + self.patch_length = patch_length + self.patch_stride = patch_stride + + # Mask pretraining + self.do_mask_input = do_mask_input + self.mask_type = mask_type + self.random_mask_ratio = random_mask_ratio # for random masking + self.num_forecast_mask_patches = num_forecast_mask_patches # for forecast masking + self.channel_consistent_masking = channel_consistent_masking + self.unmasked_channel_indices = unmasked_channel_indices + self.mask_value = 
mask_value + + # general head params + self.pooling_type = pooling_type + self.head_dropout = head_dropout + + # For prediction head + self.share_projection = share_projection + self.prediction_length = prediction_length + + # For prediction and regression head + self.num_parallel_samples = num_parallel_samples + + # Regression + self.num_targets = num_targets + self.output_range = output_range + + super().__init__(**kwargs) + + +__all__ = ["PatchTSTConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/patchtst/modeling_patchtst.py b/venv/lib/python3.13/site-packages/transformers/models/patchtst/modeling_patchtst.py new file mode 100644 index 0000000000000000000000000000000000000000..1912e318f8fdb9f5e4e1daf72a2fbb6c907911b4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/patchtst/modeling_patchtst.py @@ -0,0 +1,1957 @@ +# coding=utf-8 +# Copyright 2023 IBM & Hugging Face. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PatchTST model.""" + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2CLS +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ModelOutput, auto_docstring, logging +from .configuration_patchtst import PatchTSTConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST +class PatchTSTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PatchTSTConfig] = 
None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, None + + +class PatchTSTBatchNorm(nn.Module): + """ + Compute batch normalization over the sequence length (time) dimension. 
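`PatchTSTBatchNorm` (implemented just below) has to reconcile two layout conventions: activations flow through the encoder as `(batch_size, sequence_length, d_model)`, while `nn.BatchNorm1d` expects the normalized channel dimension at axis 1, hence the transpose on the way in and out. A standalone shape check of that pattern, with arbitrary sizes:

```python
# Shape check for the transpose / BatchNorm1d / transpose pattern used by
# PatchTSTBatchNorm below; the sizes are arbitrary.
import torch
from torch import nn

batch, seq_len, d_model = 4, 32, 128
x = torch.randn(batch, seq_len, d_model)

bn = nn.BatchNorm1d(d_model)
out = bn(x.transpose(1, 2)).transpose(1, 2)  # BatchNorm1d wants (N, C, L)
assert out.shape == (batch, seq_len, d_model)
```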
+ """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps) + + def forward(self, inputs: torch.Tensor): + """ + Parameters: + inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`): + input for Batch norm calculation + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length, d_model)` + """ + output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length) + output = self.batchnorm(output) + return output.transpose(1, 2) + + +def random_masking( + inputs: torch.Tensor, + mask_ratio: float, + unmasked_channel_indices: Optional[list] = None, + channel_consistent_masking: bool = False, + mask_value: int = 0, +): + """random_masking: Mask the input considering the control variables. + + Args: + inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`): + The input tensor to mask. + mask_ratio (`float`): + Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1. + unmasked_channel_indices (list, *optional*): + Indices of channels that will not be masked. + channel_consistent_masking (bool, *optional*, defaults to `False`): + When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary + across channels. + mask_value (int, *optional*, defaults to 0): + Define the value of masked patches for pretraining. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x + n] + """ + if mask_ratio < 0 or mask_ratio >= 1: + raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.") + + batch_size, num_channels, sequence_length, num_features = inputs.shape + device = inputs.device + + len_keep = int(sequence_length * (1 - mask_ratio)) + + if channel_consistent_masking: + noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L + noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time + else: + # noise in [0, 1], bs x num_channels x L + noise = torch.rand(batch_size, num_channels, sequence_length, device=device) + + # mask: [bs x num_channels x num_patch] + mask = torch.ones(batch_size, num_channels, sequence_length, device=device) + mask[:, :, :len_keep] = 0 + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L] + + mask = torch.gather(mask, dim=-1, index=ids_restore) + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +def forecast_masking( + inputs: torch.Tensor, + num_forecast_mask_patches: Union[list, int], + unmasked_channel_indices: Optional[list] = None, + mask_value: int = 0, +): + """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches. + If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list. + + Parameters: + inputs (`torch.Tensor`): + Input of shape `(bs, num_channels, num_patch, patch_length)` + num_forecast_mask_patches (`list`): + Number of patches to be masked at the end of each batch sample. e.g. 
4 or [3, 5]. + unmasked_channel_indices (`list`, *optional*): + Indices of channels that are not masked. + mask_value (`int`, *optional*, defaults to 0): + Values in the masked patches will be filled by `mask_value`. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs, + num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)` + """ + + if isinstance(num_forecast_mask_patches, int): + num_forecast_mask_patches = [num_forecast_mask_patches] + forecast_mask_ratios = [1 for _ in num_forecast_mask_patches] + + batch_size, num_channels, sequence_length, num_features = inputs.shape + mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device) + + t_list = [] + total_length = 0 + total_ratio = sum(forecast_mask_ratios) + + for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios): + if patch_length <= 0 or patch_length >= sequence_length: + raise ValueError( + f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches." + ) + temp_len = int(batch_size * ratio / total_ratio) + t_list.append([patch_length, ratio, temp_len]) + total_length += temp_len + + t_list = sorted(t_list, key=lambda x: x[2]) + + if total_length < batch_size: + t_list[0][2] = t_list[0][2] + (batch_size - total_length) + elif total_length > batch_size: + t_list[-1][2] = t_list[-1][2] + (total_length - batch_size) + + batch1 = 0 + for patch_len, _, temp_len in t_list: + batch2 = batch1 + temp_len + mask[batch1:batch2, :, -patch_len:] = 1 + batch1 = batch2 + + perm = torch.randperm(mask.shape[0]) + mask = mask[perm] + + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +class PatchTSTPatchify(nn.Module): + """ + A class to patchify the time series sequence into different patches + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.sequence_length = config.context_length + self.patch_length = config.patch_length + self.patch_stride = config.patch_stride + + if self.sequence_length <= self.patch_length: + raise ValueError( + f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})" + ) + + # get the number of patches + self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 + new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1) + self.sequence_start = self.sequence_length - new_sequence_length + + def forward(self, past_values: torch.Tensor): + """ + Parameters: + past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*): + Input for patchification + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + sequence_length = past_values.shape[-2] + if sequence_length != self.sequence_length: + raise ValueError( + f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})." 
+ ) + # output: [bs x new_sequence_length x num_channels] + output = past_values[:, self.sequence_start :, :] + # output: [bs x num_patches x num_input_channels x patch_length] + output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride) + # output: [bs x num_input_channels x num_patches x patch_length] + output = output.transpose(-2, -3).contiguous() + return output + + +class PatchTSTMasking(nn.Module): + """ + Class to perform random or forecast masking. + + Parameters: + config (`PatchTSTConfig`): model config + Returns: + x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.random_mask_ratio = config.random_mask_ratio + self.channel_consistent_masking = config.channel_consistent_masking + self.mask_type = config.mask_type + self.num_forecast_mask_patches = config.num_forecast_mask_patches + self.unmasked_channel_indices = config.unmasked_channel_indices + self.mask_value = config.mask_value + if self.unmasked_channel_indices is not None: + self.unmasked_channel_indices = sorted(self.unmasked_channel_indices) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input + + Return: + masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + + """ + if self.mask_type == "random": + masked_input, mask = random_masking( + inputs=patch_input, + mask_ratio=self.random_mask_ratio, + unmasked_channel_indices=self.unmasked_channel_indices, + channel_consistent_masking=self.channel_consistent_masking, + mask_value=self.mask_value, + ) + elif self.mask_type == "forecast": + masked_input, mask = forecast_masking( + inputs=patch_input, + num_forecast_mask_patches=self.num_forecast_mask_patches, + unmasked_channel_indices=self.unmasked_channel_indices, + mask_value=self.mask_value, + ) + else: + raise ValueError(f"Invalid mask type {self.mask_type}.") + + # mask: [bs x num_input_channels x num_patch] + mask = mask.bool() + return masked_input, mask + + +class PatchTSTEncoderLayer(nn.Module): + """ + PatchTST encoder layer + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.channel_attention = config.channel_attention + # Multi-Head attention + self.self_attn = PatchTSTAttention( + embed_dim=config.d_model, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + # Add & Norm of the sublayer 1 + self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer1 = PatchTSTBatchNorm(config) + elif config.norm_type == "layernorm": + self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + # Add & Norm of the sublayer 2 + if self.channel_attention: + self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer2 = PatchTSTBatchNorm(config) + elif 
config.norm_type == "layernorm": + self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + # Position-wise Feed-Forward + self.ff = nn.Sequential( + nn.Linear(config.d_model, config.ffn_dim, bias=config.bias), + ACT2CLS[config.activation_function](), + nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(), + nn.Linear(config.ffn_dim, config.d_model, bias=config.bias), + ) + + # Add & Norm of sublayer 3 + self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer3 = PatchTSTBatchNorm(config) + elif config.norm_type == "layernorm": + self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + self.pre_norm = config.pre_norm + + def forward(self, hidden_state: torch.Tensor, output_attentions: Optional[bool] = None): + """ + Parameters: + hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*): + Past values of the time series + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + Return: + `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)` + + """ + batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape + + # First sublayer: attention across time + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer1(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path1(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm - Standard Transformer from BERT + attn_output, attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = self.norm_sublayer1(hidden_state + self.dropout_path1(attn_output)) + + # hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + # second sublayer: attention across variable at any given time + if self.channel_attention: + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.transpose(2, 1).contiguous() + # hidden_state: [(bs*sequence_length) x num_channels x d_model] + hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model) + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer2(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path2(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*sequence_length) x 
num_channels x d_model] + hidden_state = self.norm_sublayer2(hidden_state + self.dropout_path2(attn_output)) + + # Reshape hidden state + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model) + # hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.transpose(1, 2).contiguous() + + # Third sublayer: mixing across hidden + # hidden_state: [(batch_size*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + if self.pre_norm: + ## Norm and Position-wise Feed-Forward and Add residual connection + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state))) + else: + ## Position-wise Feed-Forward and Add residual connection and Norm + # Add: residual connection with residual dropout + hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state))) + + # [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + outputs = (hidden_state,) + if output_attentions: + outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,) + + return outputs + + +@auto_docstring +class PatchTSTPreTrainedModel(PreTrainedModel): + config: PatchTSTConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + """ + Initialize weights + """ + if isinstance(module, PatchTSTPositionalEncoding): + # get the number of patches + num_patches = ( + max(self.config.context_length, self.config.patch_length) - self.config.patch_length + ) // self.config.patch_stride + 1 + # initialize cls_token + if self.config.use_cls_token: + nn.init.normal_(module.cls_token, std=0.02) + num_patches += 1 + # initialize positional encoding + module.position_enc = module._init_pe(self.config, num_patches) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, PatchTSTBatchNorm): + module.batchnorm.bias.data.zero_() + module.batchnorm.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PatchTSTEncoder)): + module.gradient_checkpointing = value + + +class PatchTSTEmbedding(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.num_input_channels = config.num_input_channels + self.share_embedding = config.share_embedding + # Input encoding: projection of feature vectors onto a d-dim vector space + if self.share_embedding: + self.input_embedding = nn.Linear(config.patch_length, config.d_model) + else: + self.input_embedding = nn.ModuleList() + for _ in range(config.num_input_channels): + self.input_embedding.append(nn.Linear(config.patch_length, config.d_model)) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input for embedding + return: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)` + """ + # Input encoding + 
num_input_channels = patch_input.shape[1] + if num_input_channels != self.num_input_channels: + raise ValueError( + f"The defined number of input channels ({self.num_input_channels}) in the config " + f"has to be the same as the number of channels in the batch input ({num_input_channels})" + ) + if self.share_embedding: + embeddings = self.input_embedding(patch_input) # x: [bs x num_channels x num_patches x d_model] + else: + embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)] + embeddings = torch.stack(embeddings, dim=1) + return embeddings + + +class PatchTSTPositionalEncoding(nn.Module): + """ + Class for positional encoding + """ + + def __init__(self, config: PatchTSTConfig, num_patches: int): + super().__init__() + self.use_cls_token = config.use_cls_token + self.num_input_channels = config.num_input_channels + if config.use_cls_token: + # cls_token: [1 x num_input_channels x 1 x d_model] + self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model)) + num_patches += 1 + # positional encoding: [num_patches x d_model] + self.position_enc = self._init_pe(config, num_patches) + # Positional dropout + self.positional_dropout = ( + nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity() + ) + + @staticmethod + def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter: + # Positional encoding + if config.positional_encoding_type == "random": + position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True) + elif config.positional_encoding_type == "sincos": + position_enc = torch.zeros(num_patches, config.d_model) + position = torch.arange(0, num_patches).unsqueeze(1) + div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model)) + position_enc[:, 0::2] = torch.sin(position * div_term) + position_enc[:, 1::2] = torch.cos(position * div_term) + position_enc = position_enc - position_enc.mean() + position_enc = position_enc / (position_enc.std() * 10) + position_enc = nn.Parameter(position_enc, requires_grad=False) + else: + raise ValueError( + f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'." 
+ ) + return position_enc + + def forward(self, patch_input: torch.Tensor): + if self.use_cls_token: + # patch_input: [bs x num_channels x num_patches x d_model] + patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :]) + # append cls token where cls_token: [1 x num_channels x 1 x d_model] + cls_token = self.cls_token + self.position_enc[:1, :] + # get the same copy of cls_token for all the samples in batch: [bs x num_channels x 1 x d_model] + cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1) + # hidden_state: [bs x num_channels x (num_patches+1) x d_model] + hidden_state = torch.cat((cls_tokens, patch_input), dim=2) + else: + # hidden_state: [bs x num_channels x num_patches x d_model] + hidden_state = self.positional_dropout(patch_input + self.position_enc) + return hidden_state + + +class PatchTSTEncoder(PatchTSTPreTrainedModel): + """ + PatchTST Encoder + """ + + def __init__(self, config: PatchTSTConfig, num_patches: int): + super().__init__(config) + self.gradient_checkpointing = False + + # Input embedding: projection of feature vectors onto a d-dim vector space + self.embedder = PatchTSTEmbedding(config) + # Positional encoding + self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches) + # Encoder + self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for i in range(config.num_hidden_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + patch_input: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + ) -> BaseModelOutput: + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Past values of the time series + output_hidden_states (bool, optional): Indicates if hidden states should be outputted. + output_attentions (bool, optional): Indicates if attentions should be outputted. + + return: + `BaseModelOutput` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Input embedding + patch_input = self.embedder(patch_input) + # Positional encoding + hidden_state = self.positional_encoder(patch_input) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_state,) + + layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions) + # get hidden state. hidden_state shape is [bs x num_channels x num_patches x d_model] + # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token + hidden_state = layer_outputs[0] + # append attention matrix at each layer + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + # return past_values, hidden_states + return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model's outputs, with potential hidden states. + """ +) +class PatchTSTModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, num_patches, d_model)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*): + Bool masked tensor indicating which patches are masked + loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length + scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length + patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Patched input to the Transformer + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + mask: Optional[torch.FloatTensor] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + patch_input: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`PatchTSTForPretraining`]. + """ +) +class PatchTSTForPretrainingOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + MSE loss, computed over the masked patches. + prediction_output (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Prediction outputs of the time series modeling heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`PatchTSTForRegression`]. + """ +) +class PatchTSTForRegressionOutput(ModelOutput): + r""" + loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`): + Regression outputs of the time series modeling heads. + """ + + loss: Optional[torch.FloatTensor] = None + regression_outputs: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`PatchTSTForPrediction`]. + """ +) +class PatchTSTForPredictionOutput(ModelOutput): + r""" + loss (*optional*, returned when `future_values` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`): + Prediction outputs of the time series modeling heads. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads.
+ loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length + scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length + """ + + loss: Optional[torch.FloatTensor] = None + prediction_outputs: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`PatchTSTForClassification`]. + """ +) +class PatchTSTForClassificationOutput(ModelOutput): + r""" + loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `(1,)`): + Classification loss (cross entropy). + prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`): + Prediction scores of the PatchTST modeling head (scores before SoftMax). + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for time series model's prediction outputs that contains the sampled values from the chosen + distribution. + """ +) +class SamplePatchTSTOutput(ModelOutput): + r""" + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`): + Sampled values from the chosen distribution. + """ + + sequences: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
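The `nll` helper above works with any `torch.distributions.Distribution`, which is what the distribution heads imported at the top of this file (`StudentTOutput`, `NormalOutput`, `NegativeBinomialOutput`) produce. A minimal usage sketch with hand-picked Student-t parameters; the values are illustrative stand-ins, not model output:

```python
# Minimal usage sketch for the `nll` helper above; the distribution parameters
# and targets are illustrative stand-ins for a distribution head's output.
import torch
from torch.distributions import StudentT

dist = StudentT(df=torch.tensor(3.0), loc=torch.tensor(0.0), scale=torch.tensor(1.0))
target = torch.tensor([0.5, -1.2, 2.0])

loss = -dist.log_prob(target)  # exactly what nll(dist, target) computes
assert loss.shape == target.shape and bool((loss > 0).all())
```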
+ """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. + if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +class PatchTSTScaler(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + if config.scaling == "mean" or config.scaling is True: + self.scaler = PatchTSTMeanScaler(config) + elif config.scaling == "std": + self.scaler = PatchTSTStdScaler(config) + else: + self.scaler = PatchTSTNOPScaler(config) + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Input for scaler calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
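`PatchTSTScaler` above simply dispatches on `config.scaling`: `"mean"` (or `True`) selects the mean scaler, `"std"` the standard scaler, and anything else the no-op scaler. Under the defaults used above (`dim=1`, `keepdim=True`, everything observed), the std path reduces to ordinary per-channel standardization over the time axis; a compact numeric re-check, with illustrative shapes and values:

```python
# Re-check of the PatchTSTStdScaler arithmetic above with dim=1, keepdim=True,
# minimum_scale=1e-5 and a fully observed input; sizes are illustrative.
import torch

data = torch.randn(2, 16, 3)          # (batch_size, sequence_length, num_channels)
observed = torch.ones_like(data)

denom = observed.sum(1, keepdim=True).clamp_min(1.0)
loc = (data * observed).sum(1, keepdim=True) / denom
scale = torch.sqrt((((data - loc) * observed) ** 2).sum(1, keepdim=True) / denom + 1e-5)
scaled = (data - loc) / scale

# Per channel: approximately zero mean over the time axis after scaling.
torch.testing.assert_close(scaled.mean(1), torch.zeros(2, 3), atol=1e-5, rtol=0.0)
```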
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + data, loc, scale = self.scaler(data, observed_indicator) + return data, loc, scale + + +@auto_docstring +class PatchTSTModel(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + self.scaler = PatchTSTScaler(config) + self.patchifier = PatchTSTPatchify(config) + self.do_mask_input = config.do_mask_input + # get num_patches information from PatchTSTPatchify + num_patches = self.patchifier.num_patches + + if self.do_mask_input: + self.masking = PatchTSTMasking(config) + else: + self.masking = nn.Identity() + self.encoder = PatchTSTEncoder(config, num_patches=num_patches) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTModelOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*): + Future target values associated with the `past_values` + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False) + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import PatchTSTModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain") + + >>> # during training, one provides both past and future values + >>> outputs = model( + ... past_values=batch["past_values"], + ... future_values=batch["future_values"], + ...
) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + # x: tensor [bs x sequence_length x num_input_channels] + scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask) + + # patched_values: [bs x num_input_channels x num_patches x patch_length] for pretrain + patched_values = self.patchifier(scaled_past_values) + if self.do_mask_input: + masked_values, mask = self.masking(patched_values) + else: + masked_values, mask = self.masking(patched_values), None + + encoder_output = self.encoder( + patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + + if not return_dict: + outputs = (encoder_output.last_hidden_state, encoder_output.hidden_states, encoder_output.attentions) + outputs = outputs + (mask, loc, scale, patched_values) + return tuple(v for v in outputs if v is not None) + + return PatchTSTModelOutput( + last_hidden_state=encoder_output.last_hidden_state, + hidden_states=encoder_output.hidden_states, + attentions=encoder_output.attentions, + mask=mask, + loc=loc, + scale=scale, + patch_input=patched_values, + ) + + +class PatchTSTMaskPretrainHead(nn.Module): + """ + Pretraining head for mask modelling + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + self.linear = nn.Linear(config.d_model, config.patch_length) + self.use_cls_token = config.use_cls_token + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True + + """ + embedding = self.linear(self.dropout(embedding)) # [bs x num_channels x num_patches x patch_length] + if self.use_cls_token: + embedding = embedding[:, :, 1:, :] # remove the first cls token + return embedding + + +@auto_docstring( + custom_intro=""" + The PatchTST for pretrain model. 
+ """ +) +class PatchTSTForPretraining(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + config.do_mask_input = True + self.model = PatchTSTModel(config=config) + self.head = PatchTSTMaskPretrainHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForPretrainingOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import PatchTSTConfig, PatchTSTForPretraining + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> # Config for random mask pretraining + >>> config = PatchTSTConfig( + ... num_input_channels=7, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... mask_type='random', + ... random_mask_ratio=0.4, + ... use_cls_token=True, + ... ) + >>> # Config for forecast mask pretraining + >>> config = PatchTSTConfig( + ... num_input_channels=7, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... mask_type='forecast', + ... num_forecast_mask_patches=5, + ... use_cls_token=True, + ... 
) + >>> model = PatchTSTForPretraining(config) + + >>> # during pretraining, one only provides past values; masked patches are reconstructed + >>> outputs = model(past_values=batch["past_values"]) + + >>> loss = outputs.loss + >>> loss.backward() + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_values: [bs x num_channels x num_patches x d_model] or + # [bs x num_channels x (num_patches+1) x d_model] if use cls_token + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + + # last_hidden_state: [bs x num_channels x num_patches x patch_length] or + # [bs x num_channels x (num_patches+1) x patch_length] if use cls_token + x_hat = self.head(model_output.last_hidden_state) + + # calculate masked_loss + loss = nn.MSELoss(reduction="none") + loss_val = loss(x_hat, model_output.patch_input) + masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10) + + encoder_states = model_output.hidden_states + if not return_dict: + outputs = (x_hat,) + model_output[1:-4] + outputs = (masked_loss,) + outputs if masked_loss is not None else outputs + return outputs + return PatchTSTForPretrainingOutput( + loss=masked_loss, prediction_output=x_hat, hidden_states=encoder_states, attentions=model_output.attentions + ) + + +class PatchTSTClassificationHead(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + self.flatten = nn.Flatten(start_dim=1) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets) + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, num_targets)` + + """ + if self.use_cls_token: + # use the first output token, pooled_embedding: bs x num_channels x d_model + pooled_embedding = embedding[:, :, 0, :] + elif self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet") + # pooled_embedding: bs x num_channels * d_model + pooled_embedding = self.flatten(pooled_embedding) + # output: bs x n_classes + output = self.linear(self.dropout(pooled_embedding)) + return output + + +@auto_docstring( + custom_intro=""" + The PatchTST for classification model.
+ """ +) +class PatchTSTForClassification(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + # Turn off masking + if config.do_mask_input: + logger.warning("Setting `do_mask_input` parameter to False.") + config.do_mask_input = False + + self.model = PatchTSTModel(config) + self.head = PatchTSTClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + target_values: Optional[torch.Tensor] = None, + past_observed_mask: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForClassificationOutput]: + r""" + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + target_values (`torch.Tensor`, *optional*): + Labels associates with the `past_values` + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Examples: + + ```python + >>> from transformers import PatchTSTConfig, PatchTSTForClassification + + >>> # classification task with two input channel2 and 3 classes + >>> config = PatchTSTConfig( + ... num_input_channels=2, + ... num_targets=3, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... use_cls_token=True, + ... ) + >>> model = PatchTSTForClassification(config=config) + + >>> # during inference, one only provides past values + >>> past_values = torch.randn(20, 512, 2) + >>> outputs = model(past_values=past_values) + >>> labels = outputs.prediction_logits + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + if target_values is not None: + loss = nn.CrossEntropyLoss() + loss_val = loss(y_hat, target_values) + + if not return_dict: + outputs = (y_hat,) + model_output[1:-3] + outputs = (loss_val,) + outputs if loss_val is not None else outputs + return outputs + return PatchTSTForClassificationOutput( + loss=loss_val, + prediction_logits=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PatchTST for regression Model. + """ +) +class PatchTSTPredictionHead(nn.Module): + def __init__(self, config: PatchTSTConfig, num_patches: int, distribution_output=None): + r""" + num_patches (`int`): + The number of patches in the input sequence. + distribution_output (`DistributionOutput`, *optional*): + The distribution output layer for probabilistic forecasting. If None, a linear output layer is used. 
+ """ + super().__init__() + + self.share_projection = config.share_projection + self.num_input_channels = config.num_input_channels + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + if self.pooling_type or self.use_cls_token: + head_dim = config.d_model + else: + head_dim = config.d_model * num_patches + + if not self.share_projection: + # if each channel has its own head + self.projections = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.num_input_channels): + self.flattens.append(nn.Flatten(start_dim=2)) + if distribution_output is None: + # use linear head + self.projections.append(nn.Linear(head_dim, config.prediction_length)) + else: + # use distribution head + self.projections.append(distribution_output.get_parameter_projection(head_dim)) + self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()) + else: + # all the channels share the same head + self.flatten = nn.Flatten(start_dim=2) + if distribution_output is None: + # use linear head + self.projection = nn.Linear(head_dim, config.prediction_length) + else: + # use distribution head + self.projection = distribution_output.get_parameter_projection(head_dim) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, forecast_len, num_channels)` + + """ + if self.use_cls_token: + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding[:, :, 0, :] + else: + if self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + # pooled_embedding: [bs x num_channels x num_patches x d_model] + pooled_embedding = embedding + + if not self.share_projection: + output = [] + for i in range(self.num_input_channels): + # pooled_embedding: [bs x (d_model * num_patches)] or [bs x d_model)] + pooled_embedding = self.flattens[i](pooled_embedding[:, i, :]) + pooled_embedding = self.dropouts[i](pooled_embedding) + # pooled_embedding: [bs x forecast_len] + # or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head + pooled_embedding = self.projections[i](pooled_embedding) + output.append(pooled_embedding) + # output: [bs x num_channels x forecast_len] + output = torch.stack(output, dim=1) + else: + # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model)] + pooled_embedding = self.flatten(pooled_embedding) + pooled_embedding = self.dropout(pooled_embedding) + # output: [bs x num_channels x forecast_len] or + # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head + output = self.projection(pooled_embedding) + + if isinstance(output, tuple): + # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels]) + output = tuple(z.transpose(2, 1) for z in output) + else: + output = output.transpose(2, 1) # [bs x forecast_len x num_channels] + return output + + +@auto_docstring( + custom_intro=""" + The PatchTST for 
+    """
+)
+class PatchTSTForPrediction(PatchTSTPreTrainedModel):
+    def __init__(self, config: PatchTSTConfig):
+        super().__init__(config)
+
+        # Turn off masking
+        if config.do_mask_input:
+            logger.warning("Setting `do_mask_input` parameter to False.")
+            config.do_mask_input = False
+
+        self.model = PatchTSTModel(config)
+
+        if config.loss == "mse":
+            self.distribution_output = None
+        else:
+            if config.distribution_output == "student_t":
+                self.distribution_output = StudentTOutput(dim=config.prediction_length)
+            elif config.distribution_output == "normal":
+                self.distribution_output = NormalOutput(dim=config.prediction_length)
+            elif config.distribution_output == "negative_binomial":
+                self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
+            else:
+                raise ValueError(f"Unknown distribution output {config.distribution_output}")
+
+        self.head = PatchTSTPredictionHead(
+            config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, PatchTSTForPredictionOutput]:
+        r"""
+        Parameters:
+            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
+                Input sequence to the model
+            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+                in `[0, 1]`:
+
+                - 1 for values that are **observed**,
+                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+            future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
+                Future target values associated with the `past_values`
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the output attention of all layers
+            return_dict (`bool`, *optional*):
+                Whether or not to return a `ModelOutput` instead of a plain tuple.
+
+        Returns:
+            `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
+            `config.return_dict`=False)
+
+        Examples:
+
+        ```python
+        >>> from huggingface_hub import hf_hub_download
+        >>> import torch
+        >>> from transformers import PatchTSTConfig, PatchTSTForPrediction
+
+        >>> file = hf_hub_download(
+        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
+        ... )
+        >>> batch = torch.load(file)
+
+        >>> # Prediction task with 7 input channels and a prediction length of 96
+        >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")
+
+        >>> # during training, one provides both past and future values
+        >>> outputs = model(
+        ...     past_values=batch["past_values"],
+        ...     future_values=batch["future_values"],
+        ...
) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values, the model outputs future values + >>> outputs = model(past_values=batch["past_values"]) + >>> prediction_outputs = outputs.prediction_outputs + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # get model output + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + # get output head + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + + if self.distribution_output: + y_hat_out = y_hat + else: + y_hat_out = y_hat * model_output.scale + model_output.loc + + if future_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution( + y_hat, loc=model_output.loc, scale=model_output.scale + ) + loss_val = nll(distribution, future_values) + # take average of the loss + loss_val = weighted_average(loss_val) + else: + loss = nn.MSELoss(reduction="mean") + loss_val = loss(y_hat_out, future_values) + + loc = model_output.loc + scale = model_output.scale + + if not return_dict: + outputs = (y_hat_out,) + model_output[1:-1] + outputs = (loss_val,) + outputs if loss_val is not None else outputs + return outputs + return PatchTSTForPredictionOutput( + loss=loss_val, + prediction_outputs=y_hat_out, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + loc=loc, + scale=scale, + ) + + @torch.no_grad() + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)` + for multivariate predictions. 
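+            If the model has no distribution head (i.e. it was trained with the `"mse"` loss), a single
+            deterministic forecast is returned instead of random samples.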
+ """ + # get number of samples + num_parallel_samples = self.config.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + future_values=None, + past_observed_mask=past_observed_mask, + output_hidden_states=False, + ) + if self.distribution_output: + # get distribution + distribution = self.distribution_output.distribution( + outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale + ) + # get samples: list of [bs x forecast_len x num_channels] + samples = [distribution.sample() for _ in range(num_parallel_samples)] + # samples: [bs x num_samples x forecast_len x num_channels] + samples = torch.stack(samples, dim=1) + else: + samples = outputs.prediction_outputs.unsqueeze(1) + + return SamplePatchTSTOutput(sequences=samples) + + +class PatchTSTRegressionHead(nn.Module): + """ + Regression head + """ + + def __init__(self, config: PatchTSTConfig, distribution_output=None): + super().__init__() + self.y_range = config.output_range + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + self.distribution_output = distribution_output + + head_dim = config.num_input_channels * config.d_model + + self.flatten = nn.Flatten(start_dim=1) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + + if distribution_output is None: + self.projection = nn.Linear(head_dim, config.num_targets) + else: + self.projection = distribution_output.get_parameter_projection(head_dim) + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, output_dim)` + + """ + if self.use_cls_token: + # use the first output token, pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding[:, :, 0, :] + elif self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet") + # flatten the input + # pooled_embedding: bs x (num_channels * d_model) + pooled_embedding = self.dropout(self.flatten(pooled_embedding)) + # projection + # output: bs x output_dim or a tuple of this shape for distribution head + output = self.projection(pooled_embedding) + # apply sigmoid to bound the output if required + if (self.distribution_output is None) & (self.y_range is not None): # linear head + output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0] + return output + + +@auto_docstring( + custom_intro=""" + The PatchTST for regression model. 
+ """ +) +class PatchTSTForRegression(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + # Turn off masking + if config.do_mask_input: + logger.warning("Setting `do_mask_input` parameter to False.") + config.do_mask_input = False + + self.model = PatchTSTModel(config) + if config.loss == "mse": + self.distribution_output = None + else: + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.num_targets) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.num_targets) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.num_targets) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.head = PatchTSTRegressionHead(config, self.distribution_output) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + target_values: Optional[torch.Tensor] = None, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForRegressionOutput]: + r""" + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + target_values (`torch.Tensor` of shape `(bs, num_input_channels)`): + Target values associates with the `past_values` + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Examples: + + ```python + >>> from transformers import PatchTSTConfig, PatchTSTForRegression + + >>> # Regression task with 6 input channels and regress 2 targets + >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression") + + >>> # during inference, one only provides past values, the model outputs future values + >>> past_values = torch.randn(20, 512, 6) + >>> outputs = model(past_values=past_values) + >>> regression_outputs = outputs.regression_outputs + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + # get output head. 
y_hat is of shape [bs x num_targets] or tuple of this shape + y_hat = self.head(model_output.last_hidden_state) + + loss = None + if target_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution(y_hat) + # y_hat should be a 2-tuple, each with dimension [bs, num_targets] + y_hat = tuple(item.view(-1, self.config.num_targets) for item in y_hat) + loss = nll(distribution, target_values) + # take average of the loss + loss = weighted_average(loss) + else: + loss = nn.MSELoss(reduction="mean") + loss = loss(y_hat, target_values) + + if not return_dict: + # hidden_states, attentions, mask + outputs = (y_hat,) + model_output[1:-3] + outputs = (loss,) + outputs if loss is not None else outputs + return outputs + return PatchTSTForRegressionOutput( + loss=loss, + regression_outputs=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + ) + + @torch.no_grad() + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, num_targets)`. + """ + # get number of samples + num_parallel_samples = self.config.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + target_values=None, + past_observed_mask=past_observed_mask, + output_hidden_states=False, + ) + + # get distribution + distribution = self.distribution_output.distribution(outputs.regression_outputs) + # get samples: list of [bs x num_targets] + samples = [distribution.sample() for _ in range(num_parallel_samples)] + # samples: [bs x num_samples x num_targets] + samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets) + return SamplePatchTSTOutput(sequences=samples) + + +__all__ = [ + "PatchTSTModel", + "PatchTSTPreTrainedModel", + "PatchTSTForPrediction", + "PatchTSTForPretraining", + "PatchTSTForRegression", + "PatchTSTForClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_pegasus import *
+    from .modeling_flax_pegasus import *
+    from .modeling_pegasus import *
+    from .modeling_tf_pegasus import *
+    from .tokenization_pegasus import *
+    from .tokenization_pegasus_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/configuration_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/configuration_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c27f7d44d952d769ae362b000f87b3dd6a1cfd5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate a
+    PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the PEGASUS
+    [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the PEGASUS model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`PegasusModel`] or [`TFPegasusModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 1):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+ + Example: + + ```python + >>> from transformers import PegasusConfig, PegasusModel + + >>> # Initializing a PEGASUS google/pegasus-large style configuration + >>> configuration = PegasusConfig() + + >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration + >>> model = PegasusModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pegasus" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=50265, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=0, + scale_embedding=False, + pad_token_id=0, + eos_token_id=1, + forced_eos_token_id=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + +__all__ = ["PegasusConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf0ae492407c17bbc678b82b3d19e34da7d8a75 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py @@ -0,0 +1,1532 @@ +# coding=utf-8 +# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax PEGASUS model.""" + +import math +import random +from functools import partial +from typing import Callable, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + add_start_docstrings_to_model_forward, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, logging, replace_return_docstrings +from .configuration_pegasus import PegasusConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/pegasus-large" +_CONFIG_FOR_DOC = "PegasusConfig" + +PEGASUS_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`PegasusConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +PEGASUS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
+            paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
+    Args:
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
+            paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+    """
+    Shift input ids one token to the right.
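+
+    For example, with `decoder_start_token_id=0`, `[[5, 6, 7]]` becomes `[[0, 5, 6]]`; any `-100`
+    label placeholders in the shifted ids are then replaced by `pad_token_id`.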
+ """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions +def create_sinusoidal_positions(n_pos, dim): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + sentinel = dim // 2 + dim % 2 + out = np.zeros_like(position_enc) + out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) + out[:, sentinel:] = np.cos(position_enc[:, 1::2]) + + return jnp.array(out) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus +class FlaxPegasusAttention(nn.Module): + config: PegasusConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
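+        # This branch is taken both when the cache is being initialized (`init_cache=True`) and on every
+        # subsequent decoding step once `cached_key` exists.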
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus +class FlaxPegasusEncoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus +class FlaxPegasusEncoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers 
= [ + FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus +class FlaxPegasusDecoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + 
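+        # Cross-attention only runs when encoder hidden states are provided, i.e. when this layer is used
+        # as part of the seq2seq decoder rather than in a decoder-only setting.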
cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus +class FlaxPegasusDecoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxPegasusEncoder(nn.Module): + config: PegasusConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = 
nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) + self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) + # explicitly cast the positions here, since self.embed_positions are not registered as parameters + embed_pos = embed_pos.astype(inputs_embeds.dtype) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_state,) + + if not return_dict: + outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxPegasusDecoder(nn.Module): + config: PegasusConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) + + self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = jnp.take(self.embed_positions, position_ids, axis=0) + # explicitly cast the positions here, since self.embed_positions are not registered as parameters + positions = positions.astype(inputs_embeds.dtype) + + 
hidden_states = inputs_embeds + positions + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_state,) + + if not return_dict: + outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus +class FlaxPegasusModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): + config_class = PegasusConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: PegasusConfig, + input_shape: tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + 
**kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
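A minimal sketch of how this cache is typically threaded through `decode` during incremental generation (the checkpoint name, `max_length`, and the single start token below are illustrative assumptions, not part of this file):

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
encoder_outputs = model.encode(**inputs)

batch_size = inputs["input_ids"].shape[0]
max_length = 32  # assumed decoding budget; fixes the static size of the cache
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)

# first decoding step: a single decoder_start_token_id per sequence
decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    decoder_attention_mask=jnp.ones((batch_size, max_length), dtype="i4"),
    decoder_position_ids=jnp.zeros((batch_size, 1), dtype="i4"),
    past_key_values=past_key_values,
)
past_key_values = outputs.past_key_values  # updated cache, fed back in on the next step
```

In practice `generate()` drives this loop through `prepare_inputs_for_generation` / `update_inputs_for_generation` further below, so the cache rarely needs to be handled by hand.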
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", + PEGASUS_START_DOCSTRING, +) +class FlaxPegasusModel(FlaxPegasusPreTrainedModel): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxPegasusModule + + +append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus +class FlaxPegasusForConditionalGenerationModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + 
attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING +) +class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): + module_class = FlaxPegasusForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + deterministic: bool = True, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias.astype(self.dtype) + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add 
updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + +FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large') + + >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np') + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids']).sequences + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + >>> TXT = "My friends are but they eat too many carbs." 
+ + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) + >>> values, predictions = jax.lax.top_k(probs) + + >>> tokenizer.decode(predictions).split() + ``` +""" + +overwrite_call_docstring( + FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) + + +__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3a8005acac814393789cfa6146f413af2863bd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_pegasus.py @@ -0,0 +1,1671 @@ +# coding=utf-8 +# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PEGASUS model.""" + +import copy +import math +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_pegasus import PegasusConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. 
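For concreteness, a small worked example of the shift (the token ids, `pad_token_id=0`, and `decoder_start_token_id=1` below are made-up illustrative values):

```python
import torch

# labels as produced by a typical collator, with -100 marking positions to ignore
labels = torch.tensor([[45, 78, 2, -100, -100]])

shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1)
# tensor([[ 1, 45, 78,  2,  0]])
# the start token is prepended, the last label is dropped, and any -100 that was
# copied over is replaced by pad_token_id so the decoder never sees -100
```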
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus +class PegasusSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) + + +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus +class PegasusAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PegasusConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + 
self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for 
cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS +class PegasusEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PegasusConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PegasusAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, attn_weights + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS +class PegasusDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PegasusAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PegasusAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. 
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class PegasusPreTrainedModel(PreTrainedModel): + config: PegasusConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, PegasusSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask 
if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PegasusEncoder(PegasusPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PegasusEncoderLayer`]. + + Args: + config: PegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + self.config.max_position_embeddings, + self.config.d_model, + self.padding_idx, + ) + self.embed_positions._init_weight() + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. 
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class PegasusDecoder(PegasusPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`] + + Args: + config: PegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
+        """
+        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+        self.config.max_position_embeddings = new_num_position_embeddings
+
+        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+            self.config.max_position_embeddings,
+            self.config.d_model,
+            self.padding_idx,
+        )
+        self.embed_positions._init_weight()
+        self.embed_positions.to(self.device)
+
+    def get_position_embeddings(self) -> nn.Embedding:
+        """
+        Returns the position embeddings matrix
+        """
+        return self.embed_positions
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        cache_position=None,
+    ):
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding token indices of encoder `input_ids`. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+                cache in the correct position and to infer the complete sequence length.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input)
+
+        # important to apply scale outside of `if` in case users pass `embeds`
+        inputs_embeds = inputs_embeds * self.embed_scale
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # initialize `past_key_values`
+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+                if encoder_hidden_states is not None
+                else DynamicCache(config=self.config)
+            )
+        if use_cache and isinstance(past_key_values, tuple):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        batch_size, seq_length = inputs_embeds.size()[:-1]
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+            )
+
+        if attention_mask is None and not is_torchdynamo_compiling():
+            # required mask seq length can be calculated via length of past cache
+            mask_seq_length = past_key_values_length + seq_length
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+        self_attn_cache = (
+            past_key_values.self_attention_cache
+            if isinstance(past_key_values, EncoderDecoderCache)
+            else past_key_values
+        )
+
+        causal_mask = self._update_causal_mask(
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            self_attn_cache,
+        )
+        encoder_attention_mask = self._update_cross_attn_mask(
+            encoder_hidden_states,
+            encoder_attention_mask,
+            input_shape,
+            inputs_embeds,
+        )
+
+        # embed positions
+        positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position)
+        hidden_states = inputs_embeds + positions
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class PegasusModel(PegasusPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PegasusConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = PegasusEncoder(config, self.shared) + self.decoder = PegasusDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
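+
+        Example (a minimal sketch; `model` here is assumed to be any loaded [`PegasusModel`]):
+
+        ```python
+        >>> model.resize_position_embeddings(2048)  # rebuilds the sinusoidal tables at the new length
+        ```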
+ """ + self.config.max_position_embeddings = new_num_position_embeddings + self.encoder.resize_position_embeddings(new_num_position_embeddings) + self.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings()) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, PegasusModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + >>> model = PegasusModel.from_pretrained("google/pegasus-large") + + >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt") + >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt") + >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 4, 1024] + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PEGASUS Model with a language modeling head. Can be used for summarization. 
+ """ +) +class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PegasusConfig): + super().__init__(config) + self.model = PegasusModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
+ """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.encoder.resize_position_embeddings(new_num_position_embeddings) + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings()) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example Summarization: + + ```python + >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration + + >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... 
) + >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "California's largest electricity provider has turned off power to hundreds of thousands of customers." + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus +class PegasusDecoderWrapper(PegasusPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
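+    Its `forward` simply delegates every call to the wrapped [`PegasusDecoder`].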
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PegasusDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PegasusDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.model.decoder.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + @auto_docstring + # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, PegasusForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..d159fc00138dcb6ca2f14213857b6b21e97c6865 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py @@ -0,0 +1,1573 @@ +# coding=utf-8 +# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 Pegasus model."""
+
+from __future__ import annotations
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFSeq2SeqLMOutput,
+    TFSeq2SeqModelOutput,
+)
+
+# Public API
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+    start_tokens = tf.fill(
+        (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
+    )
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100,
+        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
+        shifted_input_ids,
+    )
+
+    # Verify that `labels` has only positive values and -100
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+    """
+    Make causal mask used for uni-directional self-attention.
+    """
+    bsz = input_ids_shape[0]
+    tgt_len = input_ids_shape[1]
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+    mask_cond = tf.range(shape_list(mask)[-1])
+
+    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+    if past_key_values_length > 0:
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
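+
+    Masked positions (zeros in the input) become `LARGE_NEGATIVE`; unmasked positions (ones) become 0, so the
+    result can be added directly to raw attention scores.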
+ """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus +class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + super().__init__(**kwargs) + + if embedding_dim % 2 != 0: + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") + + self.embedding_dim = embedding_dim + self.num_positions = num_positions + + def build(self, input_shape: tf.TensorShape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + + weight = self._init_weight(self.num_positions, self.embedding_dim) + + self.weight = self.add_weight( + name="embeddings", + shape=[self.num_positions, self.embedding_dim], + ) + weight = tf.cast(weight, dtype=self.weight.dtype) + + self.weight.assign(weight) + + super().build(input_shape) + + @staticmethod + def _init_weight(n_pos: int, dim: int): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + table = np.zeros_like(position_enc) + # index 0 is all zero + table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) + table[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) + # convert to tensor + table = tf.convert_to_tensor(table) + tf.stop_gradient(table) + return table + + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None + ): + """Input is expected to be of size [bsz x seqlen].""" + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, position_ids) + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus +class TFPegasusAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus +class TFPegasusEncoderLayer(keras.layers.Layer): + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) 
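+        # Pegasus uses pre-norm blocks: in `call` below, LayerNorm runs before self-attention and the FFN,
+        # and each residual connection adds back the un-normalized input.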
+ self.embed_dim = config.d_model + self.self_attn = TFPegasusAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + layer_head_mask: tf.Tensor, + training: bool | None = False, + ): + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(encoder_attention_heads,)* + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask + ) + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus +class TFPegasusDecoderLayer(keras.layers.Layer): + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFPegasusAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = 
keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFPegasusAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + cross_attn_layer_head_mask: tf.Tensor | None = None, + past_key_value: tuple[tf.Tensor] | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape *(batch, seq_len, embed_dim)* + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(decoder_attention_heads,)* + cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. 
+ *(decoder_attention_heads,)* + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFPegasusPreTrainedModel(TFPreTrainedModel): + config_class = PegasusConfig + base_model_prefix = "model" + + +PEGASUS_START_DOCSTRING = r""" + This model 
inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and + behavior. + + <Tip> + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + </Tip> + + Args: + config ([`PegasusConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PEGASUS_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration + + >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]) + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + +PEGASUS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`].
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tf.FloatTensor`, *optional*): + Sequence of hidden states at the output of the last layer of the encoder, of shape + `(batch_size, sequence_length, hidden_size)`. Used in the cross-attention of the decoder. + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFPegasusEncoder(keras.layers.Layer): + config_class = PegasusConfig + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`TFPegasusEncoderLayer`]. + + Args: + config: PegasusConfig + """ + + def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.dropout = keras.layers.Dropout(config.dropout) + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + + self.embed_tokens = embed_tokens + self.embed_positions = TFPegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ): + """ + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details.
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value + in the config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. This argument can be used only in eager mode, in graph mode the value in the config + will be used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used + in eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout(hidden_states, training=training) + + # check attention mask and invert + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask) + else: + attention_mask = None + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}."
+ ), + ) + + # encoder layers + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + ) + + if output_attentions: + all_attentions += (attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFPegasusDecoder(keras.layers.Layer): + config_class = PegasusConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`] + + Args: + config: PegasusConfig + embed_tokens: output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.embed_tokens = embed_tokens + self.layerdrop = config.decoder_layerdrop + self.embed_positions = TFPegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + past_key_values: tuple[tuple[tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ): + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value + in the config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. This argument can be used only in eager mode, in graph mode the value in the config + will be used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
This argument can be used + in eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + # embed positions + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + hidden_states = inputs_embeds + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) + + if attention_mask is not None: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) + + hidden_states = self.dropout(hidden_states + positions, training=training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." 
+ ), + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if encoder_hidden_states is not None: + all_cross_attns += (layer_cross_attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFPegasusMainLayer(keras.layers.Layer): + config_class = PegasusConfig + + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), + name="model.shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "model.shared" + + self.encoder = TFPegasusEncoder(config, self.shared, name="encoder") + self.decoder = TFPegasusDecoder(config, self.shared, name="decoder") + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + decoder_position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + decoder_head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[tf.Tensor]] | None = None, + inputs_embeds: 
tf.Tensor | None = None, + decoder_inputs_embeds: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + **kwargs, + ): + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
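+        # Note: in TF2, a name that ends with "/" makes tf.name_scope re-enter that exact scope + # instead of opening a new (possibly uniquified) child scope, so the shared embedding's + # variables always end up under "model.shared" no matter which layer triggers this build.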
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.", + PEGASUS_START_DOCSTRING, +) +class TFPegasusModel(TFPegasusPreTrainedModel): + def __init__(self, config: PegasusConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFPegasusMainLayer(config, name="model") + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + 
decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer +class BiasLayer(keras.layers.Layer): + """ + Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, + so all weights have to be registered in a layer. + """ + + def __init__(self, shape, initializer, trainable, name, **kwargs): + super().__init__(name=name, **kwargs) + # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of + # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: + # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 + self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) + + def call(self, x): + return x + self.bias + + +@add_start_docstrings( + "The PEGASUS Model with a language modeling head. Can be used for summarization.", + PEGASUS_START_DOCSTRING, +) +class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss): + _keys_to_ignore_on_load_unexpected = [ + r"model.encoder.embed_tokens.weight", + r"model.decoder.embed_tokens.weight", + ] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model = TFPegasusMainLayer(config, name="model") + self.use_cache = config.use_cache + # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False + ) + + def get_decoder(self): + return self.model.decoder + + def get_encoder(self): + return self.model.encoder + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + def get_bias(self): + return {"final_logits_bias": self.bias_layer.bias} + + def set_bias(self, value): + # Replaces the existing layers containing bias for correct (de)serialization. 
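+        # Rebuilding the layer (rather than assigning into the old weight) keeps the registered + # variable's shape in sync with the new vocab size, which matters because + # `keras.Model.save_weights` serializes weights per layer (see BiasLayer above).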
+ vocab_size = value["final_logits_bias"].shape[-1] + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False + ) + self.bias_layer.bias.assign(value["final_logits_bias"]) + + @unpack_inputs + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: TFBaseModelOutput | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: + """ + labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
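+ + For example (illustrative values), with `pad_token_id = 0` a label row `[182, 45, 0, 0]` is replaced + internally by `[182, 45, -100, -100]` before the loss is computed, so padding positions never + contribute to the loss.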
+ + Returns: + + """ + + if labels is not None: + labels = tf.where( + labels == self.config.pad_token_id, + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), + labels, + ) + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) + lm_logits = self.bias_layer(lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return TFSeq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, # index 1 of d outputs + decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs + decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs + encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs + encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out + encoder_attentions=outputs.encoder_attentions, # 2 of e out + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + if decoder_attention_mask is 
not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past_key_values is not None: # no xla + past_key_values + decoder_position_ids = past_key_values[0][0].shape[2] + else: # no xla + no past_key_values + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "bias_layer", None) is not None: + with tf.name_scope(self.bias_layer.name): + self.bias_layer.build(None) + + +__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a4a1c737d1592d9362f506e93b20854edfe7a5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +logger = logging.get_logger(__name__) + + +# TODO ArthurZ refactor this to only use the added_tokens_encoder + + +@requires(backends=("sentencepiece",)) +class PegasusTokenizer(PreTrainedTokenizer): + r""" + Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. 
+ pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + eos_token (`str`, *optional*, defaults to `"</s>"`): + The end of sequence token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + </Tip> + + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + mask_token (`str`, *optional*, defaults to `"<mask_2>"`): + The token used for masking single token values. This is the token used when training this model with masked + language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining. + It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive + Summarization](https://huggingface.co/papers/1912.08777). + mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`): + The token used for masking whole target sentences. This is the token used when training this model with gap + sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during + pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for + Abstractive Summarization](https://huggingface.co/papers/1912.08777). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. If no additional_special_tokens are provided, `<mask_2>` + and `<unk_2>, ..., <unk_102>` are used as additional special tokens corresponding to the [original PEGASUS + tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66) + that uses the tokens 2 - 104 only for pretraining. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice) + using the forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout.
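+ + Example (a minimal sketch; `google/pegasus-xsum` is one public PEGASUS checkpoint, and the exact + token ids produced depend on its vocabulary): + + ```python + >>> from transformers import PegasusTokenizer + + >>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum") + >>> ids = tokenizer("Studies have shown that owning a dog is good for you.").input_ids + >>> tokenizer.decode(ids, skip_special_tokens=True) + 'Studies have shown that owning a dog is good for you.' + ```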
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + pad_token="", + eos_token="", + unk_token="", + mask_token="", + mask_token_sent="", + additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + self.offset = offset + if additional_special_tokens is not None: + if not isinstance(additional_special_tokens, list): + raise TypeError( + f"additional_special_tokens should be of type {type(list)}, but is" + f" {type(additional_special_tokens)}" + ) + additional_special_tokens_extended = ( + ([mask_token_sent] + additional_special_tokens) + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None + else additional_special_tokens + ) + # fill additional tokens with ..., in case not all additional tokens are already taken + additional_special_tokens_extended += [ + f"" for i in range(len(additional_special_tokens_extended), self.offset - 1) + ] + + if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended): + raise ValueError( + "Please make sure that the provided additional_special_tokens do not contain an incorrectly" + f" shifted list of tokens. Found {additional_special_tokens_extended}." + ) + additional_special_tokens = additional_special_tokens_extended + else: + additional_special_tokens_extended = [] + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] + additional_special_tokens += [f"" for i in range(2, self.offset)] + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.mask_token_sent = mask_token_sent + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + _added_tokens_decoder = { + 0: AddedToken(str(pad_token), special=True), + 1: AddedToken(str(eos_token), special=True), + } + + if self.mask_token_sent is not None: + _added_tokens_decoder[2] = AddedToken(mask_token_sent, special=True) + _added_tokens_decoder[3] = AddedToken(str(mask_token), special=True) + + for i in range(2, self.offset): + _added_tokens_decoder[len(_added_tokens_decoder)] = AddedToken(f"", special=True) + + # Force update as we want to make sure vocab is enforced (same as fast) + self._added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) + self._added_tokens_decoder.update(_added_tokens_decoder) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + mask_token=mask_token, + pad_token=pad_token, + mask_token_sent=mask_token_sent, + offset=offset, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.sp_model) + self.offset + + def get_vocab(self) -> dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) 
for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id using the vocab.""" + sp_id = self.sp_model.piece_to_id(token) + return sp_id + self.offset + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the vocab.""" + if index < self.offset: + return self.sp_model.IdToPiece(index) + token = self.sp_model.IdToPiece(index - self.offset) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) into a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def num_special_tokens_to_add(self, pair=False): + """Just EOS""" + return 1 + + def _special_token_mask(self, seq): + all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp + all_special_ids.remove(self.unk_token_id) # is only sometimes special + + return [1 if x in all_special_ids else 0 for x in seq] + + def get_special_tokens_mask( + self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """Get list where entries are [1] if a token is [eos] or [pad] else 0.""" + if already_has_special_tokens: + return self._special_token_mask(token_ids_0) + elif token_ids_1 is None: + return self._special_token_mask(token_ids_0) + [1] + else: + return self._special_token_mask(token_ids_0 + token_ids_1) + [1] + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating + and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence: + + - single sequence: `X </s>` + - pair of sequences: `A B </s>` (not intended use) + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
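+ + Example (a sketch with made-up ids; only the EOS id, which is `1` for this tokenizer, is meaningful): + + ```python + >>> # A single sequence simply gains EOS at the end: + >>> tokenizer.build_inputs_with_special_tokens([182, 45]) + [182, 45, 1] + ```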
+ """ + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["PegasusTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py b/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..92a37c44ff2e302fe1e2a56849f2e91fa481ec9b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model PEGASUS.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_pegasus import PegasusTokenizer +else: + PegasusTokenizer = None + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +class PegasusTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. 
+ + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + </Tip> + + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + mask_token (`str`, *optional*, defaults to `"<mask_2>"`): + The token used for masking single token values. This is the token used when training this model with masked + language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining. + It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive + Summarization](https://huggingface.co/papers/1912.08777). + mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`): + The token used for masking whole target sentences. This is the token used when training this model with gap + sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during + pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for + Abstractive Summarization](https://huggingface.co/papers/1912.08777). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. If no additional_special_tokens are provided <mask_2> and + <unk_2, ..., unk_102> are used as additional special tokens corresponding to the [original PEGASUS + tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66) + that uses the tokens 2 - 104 only for pretraining + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = PegasusTokenizer + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + pad_token="<pad>", + eos_token="</s>", + unk_token="<unk>", + mask_token="<mask_2>", + mask_token_sent="<mask_1>", + additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining + **kwargs, + ): + self.offset = offset + + if additional_special_tokens is not None: + if not isinstance(additional_special_tokens, list): + raise TypeError( + f"additional_special_tokens should be of type {type(list)}, but is" + f" {type(additional_special_tokens)}" + ) + + additional_special_tokens_extended = ( + ([mask_token_sent] + additional_special_tokens) + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None + else additional_special_tokens + ) + # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken + additional_special_tokens_extended += [ + f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1) + ] + + if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended): + raise ValueError( + "Please make sure that the provided additional_special_tokens do not contain an incorrectly" + f" shifted list of tokens. Found {additional_special_tokens_extended}." + ) + additional_special_tokens = additional_special_tokens_extended + else: + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] + additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)] + + # pegasus was designed to support changing the index of the first tokens.
If one of the padding/eos/unk/mask tokens + # is different from the default, we must rebuild the vocab + from_slow = kwargs.pop("from_slow", None) + from_slow = from_slow or str(pad_token) != "<pad>" or str(eos_token) != "</s>" or str(unk_token) != "<unk>" + + kwargs.pop("added_tokens_decoder", {}) + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + pad_token=pad_token, + eos_token=eos_token, + unk_token=unk_token, + mask_token=mask_token, + mask_token_sent=mask_token_sent, + offset=offset, + additional_special_tokens=additional_special_tokens, + from_slow=from_slow, + **kwargs, + ) + self.vocab_file = vocab_file + + def _special_token_mask(self, seq): + all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp + all_special_ids.remove(self.unk_token_id) # is only sometimes special + + if all_special_ids != set(range(len(self.additional_special_tokens) + 3)): + raise ValueError( + "There should be 3 special tokens: mask_token, pad_token, and eos_token +" + f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}" + ) + + return [1 if x in all_special_ids else 0 for x in seq] + + def get_special_tokens_mask( + self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """Get list where entries are [1] if a token is [eos] or [pad] else 0.""" + if already_has_special_tokens: + return self._special_token_mask(token_ids_0) + elif token_ids_1 is None: + return self._special_token_mask(token_ids_0) + [1] + else: + return self._special_token_mask(token_ids_0 + token_ids_1) + [1] + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]: + """ + Build model inputs from a sequence by adding eos to the end. No bos token is added to the front. + + - single sequence: `X </s>` + - pair of sequences: `A B </s>` (not intended use) + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer."
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["PegasusTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8d30a5a9b54b27afb74fba377024c2c95436f6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_perceiver import * + from .feature_extraction_perceiver import * + from .image_processing_perceiver import * + from .image_processing_perceiver_fast import * + from .modeling_perceiver import * + from .tokenization_perceiver import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/configuration_perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/configuration_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..91e7bcd58fdc2bccdc83bf13fa34ce735311c72b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/configuration_perceiver.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright Deepmind and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
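The PEGASUS tokenizers added above encode text with SentencePiece and then shift every piece id by `offset` (103 by default), reserving ids 2-104 for the <mask_1>/<unk_2>..<unk_102> pretraining tokens, and they append a single EOS with no BOS. A minimal, hedged sanity check of that behaviour; the `google/pegasus-xsum` checkpoint name and an installed `sentencepiece` are assumptions, not part of the vendored sources:

```python
# Minimal sketch: compare the slow and fast PEGASUS tokenizers defined above.
# Assumes the public "google/pegasus-xsum" checkpoint and `sentencepiece` installed.
from transformers import PegasusTokenizer, PegasusTokenizerFast

slow = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
fast = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum")

text = "PEGASUS is pretrained with gap sentence generation."
slow_ids = slow(text)["input_ids"]
fast_ids = fast(text)["input_ids"]

# Both tokenizers shift SentencePiece ids by `offset` and append a single EOS (no BOS),
# which is exactly what _convert_token_to_id / build_inputs_with_special_tokens implement.
assert slow_ids == fast_ids
assert slow_ids[-1] == slow.eos_token_id
```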
+"""Perceiver model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import FeatureExtractionMixin +from ...onnx import OnnxConfig +from ...onnx.utils import compute_effective_axis_dimension +from ...tokenization_utils_base import PreTrainedTokenizerBase +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class PerceiverConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PerceiverModel`]. It is used to instantiate an + Perceiver model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Perceiver + [deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_latents (`int`, *optional*, defaults to 256): + The number of latents. + d_latents (`int`, *optional*, defaults to 1280): + Dimension of the latent embeddings. + d_model (`int`, *optional*, defaults to 768): + Dimension of the inputs. Should only be provided in case [*PerceiverTextPreprocessor*] is used or no + preprocessor is provided. + num_blocks (`int`, *optional*, defaults to 1): + Number of blocks in the Transformer encoder. + num_self_attends_per_block (`int`, *optional*, defaults to 26): + The number of self-attention layers per block. + num_self_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each self-attention layer in the Transformer encoder. + num_cross_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each cross-attention layer in the Transformer encoder. + qk_channels (`int`, *optional*): + Dimension to project the queries + keys before applying attention in the cross-attention and self-attention + layers of the encoder. Will default to preserving the dimension of the queries if not specified. + v_channels (`int`, *optional*): + Dimension to project the values before applying attention in the cross-attention and self-attention layers + of the encoder. Will default to preserving the dimension of the queries if not specified. + cross_attention_shape_for_attention (`str`, *optional*, defaults to `"kv"`): + Dimension to use when downsampling the queries and keys in the cross-attention layer of the encoder. + self_attention_widening_factor (`int`, *optional*, defaults to 1): + Dimension of the feed-forward layer in the cross-attention layer of the Transformer encoder. + cross_attention_widening_factor (`int`, *optional*, defaults to 1): + Dimension of the feed-forward layer in the self-attention layers of the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + use_query_residual (`float`, *optional*, defaults to `True`): + Whether to add a query residual in the cross-attention layer of the encoder. + vocab_size (`int`, *optional*, defaults to 262): + Vocabulary size for the masked language modeling model. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that the masked language modeling model might ever be used with. Typically set + this to something large just in case (e.g., 512 or 1024 or 2048). + image_size (`int`, *optional*, defaults to 56): + Size of the images after preprocessing, for [`PerceiverForImageClassificationLearned`]. + train_size (`list[int]`, *optional*, defaults to `[368, 496]`): + Training size of the images for the optical flow model. + num_frames (`int`, *optional*, defaults to 16): + Number of video frames used for the multimodal autoencoding model. + audio_samples_per_frame (`int`, *optional*, defaults to 1920): + Number of audio samples per frame for the multimodal autoencoding model. + samples_per_patch (`int`, *optional*, defaults to 16): + Number of audio samples per patch when preprocessing the audio for the multimodal autoencoding model. + output_shape (`list[int]`, *optional*, defaults to `[1, 16, 224, 224]`): + Shape of the output (batch_size, num_frames, height, width) for the video decoder queries of the multimodal + autoencoding model. This excludes the channel dimension. + output_num_channels (`int`, *optional*, defaults to 512): + Number of output channels for each modalitiy decoder. + + Example: + + ```python + >>> from transformers import PerceiverModel, PerceiverConfig + + >>> # Initializing a Perceiver deepmind/language-perceiver style configuration + >>> configuration = PerceiverConfig() + + >>> # Initializing a model from the deepmind/language-perceiver style configuration + >>> model = PerceiverModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "perceiver" + + def __init__( + self, + num_latents=256, + d_latents=1280, + d_model=768, + num_blocks=1, + num_self_attends_per_block=26, + num_self_attention_heads=8, + num_cross_attention_heads=8, + qk_channels=None, + v_channels=None, + cross_attention_shape_for_attention="kv", + self_attention_widening_factor=1, + cross_attention_widening_factor=1, + hidden_act="gelu", + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_query_residual=True, + vocab_size=262, + max_position_embeddings=2048, + image_size=56, + train_size=[368, 496], + num_frames=16, + audio_samples_per_frame=1920, + samples_per_patch=16, + output_shape=[1, 16, 224, 224], + output_num_channels=512, + _label_trainable_num_channels=1024, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_latents = num_latents + self.d_latents = d_latents + self.d_model = d_model + self.num_blocks = num_blocks + self.num_self_attends_per_block = num_self_attends_per_block + self.num_self_attention_heads = num_self_attention_heads + self.num_cross_attention_heads = num_cross_attention_heads + self.qk_channels = qk_channels + self.v_channels = v_channels + self.cross_attention_shape_for_attention = cross_attention_shape_for_attention + self.self_attention_widening_factor = self_attention_widening_factor + self.cross_attention_widening_factor = cross_attention_widening_factor + self.hidden_act = hidden_act + 
self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_query_residual = use_query_residual + # masked language modeling attributes + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + # image classification attributes + self.image_size = image_size + # flow attributes + self.train_size = train_size + # multimodal autoencoding attributes + self.num_frames = num_frames + self.audio_samples_per_frame = audio_samples_per_frame + self.samples_per_patch = samples_per_patch + self.output_shape = output_shape + self.output_num_channels = output_num_channels + self._label_trainable_num_channels = _label_trainable_num_channels + + +class PerceiverOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("inputs", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], + batch_size: int = -1, + seq_length: int = -1, + num_choices: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + num_channels: int = 3, + image_width: int = 40, + image_height: int = 40, + ) -> Mapping[str, Any]: + # copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified + + if isinstance(preprocessor, PreTrainedTokenizerBase): + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 + ) + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = preprocessor.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add + ) + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join(["a"]) * seq_length] * batch_size + inputs = dict(preprocessor(dummy_input, return_tensors=framework)) + inputs["inputs"] = inputs.pop("input_ids") + return inputs + elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values": + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) + dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) + inputs = dict(preprocessor(images=dummy_input, return_tensors=framework)) + inputs["inputs"] = inputs.pop("pixel_values") + return inputs + else: + raise ValueError( + "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor." 
+ ) + + +__all__ = ["PerceiverConfig", "PerceiverOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/feature_extraction_perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/feature_extraction_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..566cb31fea78a52c0c7cfaa3e509437c55dab578 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/feature_extraction_perceiver.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Perceiver.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_perceiver import PerceiverImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class PerceiverFeatureExtractor(PerceiverImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class PerceiverFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use PerceiverImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["PerceiverFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c66d7b51d463c250ad0922211691ca7a6d3a6e72 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver.py @@ -0,0 +1,353 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
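`PerceiverOnnxConfig.generate_dummy_inputs` above renames the tokenizer's `input_ids` (or an image processor's `pixel_values`) to `inputs`, matching the Perceiver's `main_input_name`. A hedged sketch of exercising it; the `deepmind/language-perceiver` checkpoint and the legacy `transformers.onnx` helpers being importable are assumptions:

```python
# Sketch: build ONNX dummy inputs for the text Perceiver via the config defined above.
# Assumes torch is installed and the deepmind/language-perceiver tokenizer is reachable.
from transformers import PerceiverConfig, PerceiverTokenizer
from transformers.models.perceiver.configuration_perceiver import PerceiverOnnxConfig
from transformers.utils import TensorType

onnx_config = PerceiverOnnxConfig(PerceiverConfig())
tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

# batch_size/seq_length default to -1, so small fixed dimensions are substituted
# to keep ONNX from constant-folding the dynamic axes.
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(sorted(dummy.keys()))  # ['attention_mask', 'inputs'] -- 'input_ids' was renamed
```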
+"""Image processor class for Perceiver.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import center_crop, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class PerceiverImageProcessor(BaseImageProcessor): + r""" + Constructs a Perceiver image processor. + + Args: + do_center_crop (`bool`, `optional`, defaults to `True`): + Whether or not to center crop the image. If the input size if smaller than `crop_size` along any edge, the + image will be padded with zeros and then center cropped. Can be overridden by the `do_center_crop` + parameter in the `preprocess` method. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`): + Desired output size when applying center-cropping. Can be overridden by the `crop_size` parameter in the + `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image to `(size["height"], size["width"])`. Can be overridden by the `do_resize` + parameter in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter + in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter + in the `preprocess` method. + do_normalize: + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} + crop_size = get_size_dict(crop_size, param_name="crop_size") + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def center_crop( + self, + image: np.ndarray, + crop_size: dict[str, int], + size: Optional[int] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] * + min_dim)`. Where `min_dim = min(size["height"], size["width"])`. + + If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then + center cropped. + + Args: + image (`np.ndarray`): + Image to center crop. + crop_size (`dict[str, int]`): + Desired output size after applying the center crop. + size (`dict[str, int]`, *optional*): + Size of the image after resizing. If not provided, the self.size attribute will be used. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = self.size if size is None else size + size = get_size_dict(size) + crop_size = get_size_dict(crop_size, param_name="crop_size") + + height, width = get_image_size(image, channel_dim=input_data_format) + min_dim = min(height, width) + cropped_height = (size["height"] / crop_size["height"]) * min_dim + cropped_width = (size["width"] / crop_size["width"]) * min_dim + return center_crop( + image, + size=(cropped_height, cropped_width), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. 
+ size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image to `crop_size`. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Desired output size after applying the center crop. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. 
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_center_crop: + images = [ + self.center_crop(image, crop_size, size=size, input_data_format=input_data_format) for image in images + ] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["PerceiverImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver_fast.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..72cb17cd40cdf151a8e4182fd19e7337ddbbdb3f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/image_processing_perceiver_fast.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Perceiver.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling, SizeDict +from ...utils import ( + TensorType, + auto_docstring, +) + + +@auto_docstring +class PerceiverImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 224, "width": 224} + crop_size = {"height": 256, "width": 256} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + + def center_crop( + self, + image: "torch.Tensor", + crop_size: dict[str, int], + size: dict[str, int], + **kwargs, + ) -> "torch.Tensor": + """ + Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] * + min_dim)`. Where `min_dim = min(size["height"], size["width"])`. + + If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then + center cropped. + + Args: + image (`"torch.Tensor"`): + Image to center crop. + crop_size (`dict[str, int]`): + Desired output size after applying the center crop. + size (`dict[str, int]`): + Size of the output image. 
+ + Returns: + `torch.Tensor`: The center cropped image. + """ + if size.height is None or size.width is None: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + height, width = image.shape[-2:] + min_dim = min(height, width) + cropped_height = int((size.height / crop_size.height) * min_dim) + cropped_width = int((size.width / crop_size.width) * min_dim) + return super().center_crop(image, SizeDict(height=cropped_height, width=cropped_width)) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, size=size, crop_size=crop_size) + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["PerceiverImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/modeling_perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/modeling_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e3522ed1d101c32cfb333c47a5969886cae223 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/modeling_perceiver.py @@ -0,0 +1,3348 @@ +# coding=utf-8 +# Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
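Both Perceiver image processors above implement the same crop-then-resize recipe: center-crop to `min_dim * size / crop_size` (224/256 of the shorter side with the defaults, padding with zeros if the image is smaller than the crop), then resize to 224x224, rescale and normalize. A small, hedged sketch with the slow processor; the random array is only a stand-in for a real image:

```python
# Sketch of the Perceiver crop-then-resize pipeline defined above (defaults: crop 256, size 224).
import numpy as np
from transformers import PerceiverImageProcessor

processor = PerceiverImageProcessor()
image = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)  # HWC uint8 stand-in

# min(300, 400) * 224 / 256 is roughly 262, so the image is center-cropped to ~262x262,
# then resized to 224x224, rescaled to [0, 1] and normalized with ImageNet statistics.
batch = processor(images=image, return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 224, 224)
```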
+"""PyTorch Perceiver model.""" + +import abc +import math +from collections.abc import Mapping +from dataclasses import dataclass +from functools import reduce +from operator import __add__ +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging, torch_int +from .configuration_perceiver import PerceiverConfig + + +ModalitySizeType = Mapping[str, int] +PreprocessorOutputType = tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor] +PreprocessorType = Callable[..., PreprocessorOutputType] +PostprocessorType = Callable[..., Any] + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions. + """ +) +class PerceiverModelOutput(ModelOutput): + r""" + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + """ + + logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Perceiver decoder outputs, with potential cross-attentions. + """ +) +class PerceiverDecoderOutput(ModelOutput): + r""" + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Output of the basic decoder. + """ + + logits: Optional[torch.FloatTensor] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Perceiver's masked language model outputs. + """ +) +class PerceiverMaskedLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal + autoencoding. + """ +) +class PerceiverClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + + +class PerceiverEmbeddings(nn.Module): + """Construct the latent embeddings.""" + + def __init__(self, config): + super().__init__() + self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) + + def forward(self, batch_size: int): + return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang + + +class PerceiverSelfAttention(nn.Module): + """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + ): + super().__init__() + self.num_heads = num_heads + # Q and K must have the same number of channels. + # Default to preserving Q's input's shape. + if qk_channels is None: + qk_channels = q_dim + # V's num_channels determines the shape of the output of QKV-attention. + # Default to the same number of channels used in the key-query operation. + if v_channels is None: + v_channels = qk_channels + if qk_channels % num_heads != 0: + raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).") + if v_channels % num_heads != 0: + raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).") + + self.qk_channels = qk_channels + self.v_channels = v_channels + self.qk_channels_per_head = self.qk_channels // num_heads + self.v_channels_per_head = self.v_channels // num_heads + + # Layer normalization + self.layernorm1 = nn.LayerNorm(q_dim) + self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity() + + # Projection matrices + self.query = nn.Linear(q_dim, qk_channels) + self.key = nn.Linear(kv_dim, qk_channels) + self.value = nn.Linear(kv_dim, v_channels) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, channels_per_head): + new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + hidden_states = self.layernorm1(hidden_states) + inputs = self.layernorm2(inputs) + + # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module, + # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to. + is_cross_attention = inputs is not None + queries = self.query(hidden_states) + + if is_cross_attention: + keys = self.key(inputs) + values = self.value(inputs) + attention_mask = inputs_mask + else: + keys = self.key(hidden_states) + values = self.value(hidden_states) + + # Reshape channels for multi-head attention. 
+ # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head) + queries = self.transpose_for_scores(queries, self.qk_channels_per_head) + keys = self.transpose_for_scores(keys, self.qk_channels_per_head) + values = self.transpose_for_scores(values, self.v_channels_per_head) + + # Take the dot product between the queries and keys to get the raw attention scores. + attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) + + batch_size, num_heads, seq_len, q_head_dim = queries.shape + _, _, _, v_head_dim = values.shape + hiddens = self.num_heads * v_head_dim + + attention_scores = attention_scores / math.sqrt(q_head_dim) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, values) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (hiddens,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class PerceiverSelfOutput(nn.Module): + def __init__(self, config, input_channels, output_channels): + super().__init__() + self.dense = nn.Linear(input_channels, output_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + return hidden_states + + +class PerceiverAttention(nn.Module): + """Attention module, including a dense block.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + use_query_residual=True, + ): + super().__init__() + # MultiHead attention + if is_cross_attention and qk_channels is None: + if config.cross_attention_shape_for_attention == "q": + qk_channels = q_dim + elif config.cross_attention_shape_for_attention == "kv": + qk_channels = kv_dim + else: + raise ValueError( + f"Unknown value {config.cross_attention_shape_for_attention} for " + "cross_attention_shape_for_attention." 
+ ) + else: + if qk_channels is None: + qk_channels = q_dim + if v_channels is None: + v_channels = qk_channels + self.self = PerceiverSelfAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + ) + # dense block + output_channels = None + if is_cross_attention: + output_channels = q_dim + else: + if output_channels is None: + output_channels = v_channels + self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels) + self.use_query_residual = use_query_residual + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + + # Output projection + attention_output = self.output(self_outputs[0]) + + # Optionally include a residual to the original queries. + # Consider omitting the residual if the semantics of query and output + # are different, e.g. if queries are positions and outputs are pixels. 
+ if self.use_query_residual: + attention_output = attention_output + hidden_states + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PerceiverMLP(nn.Module): + """A Transformer-style dense module to follow attention.""" + + def __init__(self, config, input_size, widening_factor): + super().__init__() + self.dense1 = nn.Linear(input_size, widening_factor * input_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(widening_factor * input_size, input_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + return hidden_states + + +class PerceiverLayer(nn.Module): + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + widening_factor=4, + use_query_residual=True, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PerceiverAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + use_query_residual=use_query_residual, + ) + self.layernorm = nn.LayerNorm(q_dim) + self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] # add attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + layer_output = layer_output + attention_output # residual connection + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + layer_output = self.layernorm(attention_output) + layer_output = self.mlp(layer_output) + return layer_output + + +class PerceiverEncoder(nn.Module): + """The Perceiver Encoder: a scalable, fully attentional encoder.""" + + def __init__(self, config, kv_dim=None): + super().__init__() + self.config = config + + # Check that we can use multihead-attention with these shapes. + if config.d_latents % config.num_self_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_self_attend_heads ({config.num_self_attention_heads})." + ) + if config.d_latents % config.num_cross_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_cross_attend_heads ({config.num_cross_attention_heads})." + ) + + # Construct the cross attention layer. 
+ self.cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_cross_attention_heads, + q_dim=config.d_latents, + kv_dim=kv_dim, + widening_factor=config.cross_attention_widening_factor, + use_query_residual=config.use_query_residual, + ) + + # Construct a single block of self-attention layers. + # We get deeper architectures by applying this block more than once. + self_attention_layers = [] + for _ in range(config.num_self_attends_per_block): + layer = PerceiverLayer( + config, + is_cross_attention=False, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_self_attention_heads, + q_dim=config.d_latents, + kv_dim=config.d_latents, + widening_factor=config.self_attention_widening_factor, + ) + self_attention_layers.append(layer) + + self.self_attends = nn.ModuleList(self_attention_layers) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, BaseModelOutputWithCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + # Apply the cross-attention between the latents (hidden_states) and inputs: + layer_outputs = self.cross_attention( + hidden_states, + attention_mask=attention_mask, + head_mask=None, + inputs=inputs, + inputs_mask=inputs_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_cross_attentions = all_cross_attentions + (layer_outputs[1],) + + # Apply the block of self-attention layers more than once: + for _ in range(self.config.num_blocks): + for i, layer_module in enumerate(self.self_attends): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class PerceiverPreTrainedModel(PreTrainedModel): + config: PerceiverConfig + base_model_prefix = "perceiver" + main_input_name = "inputs" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + 
elif hasattr(module, "latents"): + module.latents.data.normal_(mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding): + module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.ParameterDict): + for modality in module: + module[modality].data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring( + custom_intro=""" + The Perceiver: a scalable, fully attentional architecture. + + + + Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """ +) +class PerceiverModel(PerceiverPreTrainedModel): + def __init__( + self, + config, + decoder: Optional["PerceiverAbstractDecoder"] = None, + input_preprocessor: PreprocessorType = None, + output_postprocessor: PostprocessorType = None, + ): + r""" + decoder (`PerceiverDecoder`, *optional*): + Decoder module that transforms latent representations into task predictions. + input_preprocessor (`PreprocessorType`, *optional*): + Preprocessor that encodes raw inputs into tensors for the model. + output_postprocessor (`PostprocessorType`, *optional*): + Postprocessor that transforms model outputs into final predictions. + """ + super().__init__(config) + self.config = config + + self.input_preprocessor = input_preprocessor + self.output_postprocessor = output_postprocessor + self.embeddings = PerceiverEmbeddings(config) + self.encoder = PerceiverEncoder( + config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model + ) + self.decoder = decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.latents + + def set_input_embeddings(self, value): + self.embeddings.latents = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + inputs: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + subsampled_output_points: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PerceiverModelOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + subsampled_output_points (`dict[str, torch.Tensor]`, *optional*): + Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent + representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches. 
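+            For [`PerceiverForMultimodalAutoencoding`], the expected keys are `"image"`, `"audio"` and `"label"`.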
+ + Examples: + + ```python + >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverImageProcessor, PerceiverModel + >>> from transformers.models.perceiver.modeling_perceiver import ( + ... PerceiverTextPreprocessor, + ... PerceiverImagePreprocessor, + ... PerceiverClassificationDecoder, + ... ) + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> # EXAMPLE 1: using the Perceiver to classify texts + >>> # - we define a TextPreprocessor, which can be used to embed tokens + >>> # - we define a ClassificationDecoder, which can be used to decode the + >>> # final hidden states of the latents to classification logits + >>> # using trainable position embeddings + >>> config = PerceiverConfig() + >>> preprocessor = PerceiverTextPreprocessor(config) + >>> decoder = PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ) + >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder) + + >>> # you can then do a forward pass as follows: + >>> tokenizer = PerceiverTokenizer() + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + + >>> with torch.no_grad(): + ... outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + + >>> # EXAMPLE 2: using the Perceiver to classify images + >>> # - we define an ImagePreprocessor, which can be used to embed images + >>> config = PerceiverConfig(image_size=224) + >>> preprocessor = PerceiverImagePreprocessor( + ... config, + ... prep_type="conv1x1", + ... spatial_downsample=1, + ... out_channels=256, + ... position_encoding_type="trainable", + ... concat_or_add_pos="concat", + ... project_pos_dim=256, + ... trainable_position_encoding_kwargs=dict( + ... num_channels=256, + ... index_dims=config.image_size**2, + ... ), + ... ) + + >>> model = PerceiverModel( + ... config, + ... input_preprocessor=preprocessor, + ... decoder=PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ), + ... ) + + >>> # you can then do a forward pass as follows: + >>> image_processor = PerceiverImageProcessor() + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = image_processor(image, return_tensors="pt").pixel_values + + >>> with torch.no_grad(): + ... 
outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.input_preprocessor is not None: + inputs, modality_sizes, inputs_without_pos = self.input_preprocessor( + inputs, interpolate_pos_encoding=interpolate_pos_encoding + ) + else: + modality_sizes = None + inputs_without_pos = None + if inputs.size()[-1] != self.config.d_model: + raise ValueError( + f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:" + f" {self.config.d_model}. Make sure to set config.d_model appropriately." + ) + + batch_size, seq_length, _ = inputs.size() + device = inputs.device + + # If no attention mask is provided, make them all ones + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=device) + # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = self.invert_attention_mask(attention_mask) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_blocks x num_heads] + # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N] + head_mask = self.get_head_mask(head_mask, self.config.num_blocks * self.config.num_self_attends_per_block) + + embedding_output = self.embeddings(batch_size=batch_size) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=None, + head_mask=head_mask, + inputs=inputs, + inputs_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + logits = None + if self.decoder: + if subsampled_output_points is not None: + output_modality_sizes = { + "audio": subsampled_output_points["audio"].shape[0], + "image": subsampled_output_points["image"].shape[0], + "label": 1, + } + else: + output_modality_sizes = modality_sizes + decoder_query = self.decoder.decoder_query( + inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points + ) + decoder_outputs = self.decoder( + decoder_query, + z=sequence_output, + query_mask=extended_attention_mask, + output_attentions=output_attentions, + ) + logits = decoder_outputs.logits + + # add cross-attentions of decoder + if output_attentions and decoder_outputs.cross_attentions is not None: + if return_dict: + encoder_outputs.cross_attentions = ( + encoder_outputs.cross_attentions + decoder_outputs.cross_attentions + ) + else: + encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions + + if self.output_postprocessor: + logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes) + + if not return_dict: + if logits is not None: + return (logits, sequence_output) + encoder_outputs[1:] + else: + return (sequence_output,) + encoder_outputs[1:] + + return PerceiverModelOutput( + logits=logits, + 
last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Example use of Perceiver for masked language modeling. + """ +) +class PerceiverForMaskedLM(PerceiverPreTrainedModel): + def __init__(self, config: PerceiverConfig): + super().__init__(config) + + text_preprocessor = PerceiverTextPreprocessor(config) + + trainable_position_encoding_kwargs_decoder = { + "num_channels": text_preprocessor.num_channels, + "index_dims": config.max_position_embeddings, + } + + self.perceiver = PerceiverModel( + config, + input_preprocessor=text_preprocessor, + decoder=PerceiverBasicDecoder( + config, + output_num_channels=config.d_latents, + output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand + num_channels=text_preprocessor.num_channels, + qk_channels=8 * 32, + v_channels=text_preprocessor.num_channels, + num_heads=8, + use_query_residual=False, + final_project=False, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + ), + ) + self.embedding_decoder = PerceiverEmbeddingDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + input_ids: Optional[torch.Tensor] = None, + ) -> Union[tuple, PerceiverMaskedLMOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Examples: + + ```python + >>> from transformers import AutoTokenizer, PerceiverForMaskedLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver") + + >>> # training + >>> text = "This is an incomplete sentence where some words are missing." + >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt") + >>> # mask " missing." + >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id + >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids + + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> round(loss.item(), 2) + 19.87 + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2048, 262] + + >>> # inference + >>> text = "This is an incomplete sentence where some words are missing." + >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt") + + >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space. + >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id + + >>> # forward pass + >>> with torch.no_grad(): + ... 
outputs = model(**encoding) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2048, 262] + + >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist() + >>> tokenizer.decode(masked_tokens_predictions) + ' missing.' + ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.embedding_decoder( + outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings + ) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return PerceiverMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Example use of Perceiver for text classification. + """ +) +class PerceiverForSequenceClassification(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverTextPreprocessor(config), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + input_ids: Optional[torch.Tensor] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels - + 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > + 1` a classification loss is computed (Cross-Entropy). 
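+            If `config.problem_type` is not set, it is inferred from `config.num_labels` and the dtype of the labels
+            (regression, single-label or multi-label classification).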
+ + Examples: + + ```python + >>> from transformers import AutoTokenizer, PerceiverForSequenceClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver") + + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Example use of Perceiver for image classification, for tasks such as ImageNet. + + This model uses learned position embeddings. In other words, this model is not given any privileged information about + the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet. + + [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] + (with `prep_type="conv1x1"`) to preprocess the input images, and + [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of + [`PerceiverModel`] into classification logits. 
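+
+    Position embeddings can be interpolated to higher resolutions at inference time by passing
+    `interpolate_pos_encoding=True` to the forward pass.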
+ """ +) +class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size**2} + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv1x1", + spatial_downsample=1, + out_channels=256, + position_encoding_type="trainable", + concat_or_add_pos="concat", + project_pos_dim=256, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationLearned
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-learned")
+        >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
+        >>> outputs = model(inputs=inputs)
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 1000]
+
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: tabby, tabby cat
+        ```"""
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits if return_dict else outputs[0]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return PerceiverClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+    This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of
+    79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT).
+
+    [`PerceiverForImageClassificationFourier`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+    (with `prep_type="pixels"`) to preprocess the input images, and
+    [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+    [`PerceiverModel`] into classification logits.
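+
+    Unlike [`PerceiverForImageClassificationLearned`], the position encodings are fixed, so the forward pass does not
+    expose an `interpolate_pos_encoding` argument.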
+ """ +) +class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "concat_pos": True, + "max_resolution": (224, 224), + "num_bands": 64, + "sine_only": False, + } + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="pixels", + spatial_downsample=1, + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationFourier
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-fourier")
+        >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
+        >>> outputs = model(inputs=inputs)
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 1000]
+
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: tabby, tabby cat
+        ```"""
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits if return_dict else outputs[0]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return PerceiverClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+    This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
+    of 82.1 on ImageNet.
+
+    [`PerceiverForImageClassificationConvProcessing`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+    (with `prep_type="conv"`) to preprocess the input images, and
+    [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+    [`PerceiverModel`] into classification logits.
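+
+    The conv+maxpool stem downsamples the input images before the Perceiver encoder, which is why the Fourier position
+    encoding below uses `max_resolution=(56, 56)` rather than the full image resolution.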
+ """ +) +class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "concat_pos": True, + "max_resolution": (56, 56), + "num_bands": 64, + "sine_only": False, + } + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv", + spatial_downsample=1, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-conv") + >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv") + + >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 1000] + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: tabby, tabby cat + ```""" + if inputs is not None and pixel_values is not None: + raise ValueError("You cannot use both `inputs` and `pixel_values`") + elif inputs is None and pixel_values is not None: + inputs = pixel_values + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses + [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the + input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent + representation of [`PerceiverModel`]. + + As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel + (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position + of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation + using the same encoding used for the input. 
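+
+    The 54 values per pixel correspond to the `conv_after_patching_in_channels=54` argument of the image preprocessor
+    defined below.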
+ """ +) +class PerceiverForOpticalFlow(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "num_bands": 64, + "max_resolution": config.train_size, + "sine_only": False, + "concat_pos": True, + } + fourier_position_encoding_kwargs_decoder = { + "concat_pos": True, + "max_resolution": config.train_size, + "num_bands": 64, + "sine_only": False, + } + + image_preprocessor = PerceiverImagePreprocessor( + config, + prep_type="patches", + spatial_downsample=1, + conv_after_patching=True, + conv_after_patching_in_channels=54, + temporal_downsample=2, + position_encoding_type="fourier", + # position_encoding_kwargs + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=image_preprocessor, + decoder=PerceiverOpticalFlowDecoder( + config, + num_channels=image_preprocessor.num_channels, + output_image_shape=config.train_size, + rescale_factor=100.0, + # decoder kwargs + use_query_residual=False, + output_num_channels=2, + # We query the decoder using the first frame features + # rather than a standard decoder position encoding. + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
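+            Optical flow training is not supported yet, so passing `labels` raises a `NotImplementedError`.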
+ + Examples: + + ```python + >>> from transformers import PerceiverForOpticalFlow + >>> import torch + + >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver") + + >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel, + >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels) + >>> # patches have shape (batch_size, num_frames, num_channels, height, width) + >>> # the authors train on resolutions of 368 x 496 + >>> patches = torch.randn(1, 2, 27, 368, 496) + >>> outputs = model(inputs=patches) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 368, 496, 2] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError("Optical flow training is not yet supported") + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700. + + [`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to + preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to + preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad + each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies + the Perceiver encoder. + + [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of + [`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are + created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is + computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent + representation. This is determined by the subsampled indices for each modality, which can be provided as additional + input to the forward pass of [`PerceiverForMultimodalAutoencoding`]. + + [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different + modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention + is performed with the latent representation of [`PerceiverModel`]. + + Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an + actual video. It first splits up the output into the different modalities, and then applies the respective + postprocessor for each modality. + + Note that, by masking the classification label during evaluation (i.e. 
simply providing a tensor of zeros for the + "label" modality), this auto-encoding model becomes a Kinetics 700 video classifier. + """ +) +class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel): + def __init__(self, config: PerceiverConfig): + super().__init__(config) + + n_audio_samples = config.num_frames * config.audio_samples_per_frame + + input_preprocessor = PerceiverMultimodalPreprocessor( + min_padding_size=4, + modalities={ + "audio": PerceiverAudioPreprocessor( + config, + position_encoding_type="fourier", + fourier_position_encoding_kwargs={ + "num_bands": 192, + "max_resolution": (n_audio_samples,), + "sine_only": False, + "concat_pos": True, + }, + prep_type="patches", + samples_per_patch=config.samples_per_patch, + ), + "image": PerceiverImagePreprocessor( + config, + position_encoding_type="fourier", + fourier_position_encoding_kwargs={ + "num_bands": 32, + "max_resolution": (config.num_frames, config.image_size, config.image_size), + "sine_only": False, + "concat_pos": True, + }, + prep_type="patches", + spatial_downsample=4, + temporal_downsample=1, + ), + "label": PerceiverOneHotPreprocessor(config), + }, + mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0}, + ) + + image_decoder = PerceiverBasicVideoAutoencodingDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + output_shape=config.output_shape, + output_num_channels=config.output_num_channels, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="fourier", + fourier_position_encoding_kwargs={ + "num_bands": 32, + "max_resolution": (config.num_frames, config.image_size, config.image_size), + "sine_only": False, + "concat_pos": True, + }, + ) + + decoder = PerceiverMultimodalDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + # Modality specific decoders are used ONLY to generate queries. + # All modalties are decoded together using a unified decoder. + modalities={ + "audio": PerceiverBasicDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + output_index_dims=(n_audio_samples // config.samples_per_patch,), + output_num_channels=config.output_num_channels, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="fourier", + fourier_position_encoding_kwargs={ + "num_bands": 192, + "max_resolution": (n_audio_samples,), + "sine_only": False, + "concat_pos": True, + }, + ), + "image": image_decoder, + "label": PerceiverClassificationDecoder( + config, + # Autoencoding, don't pass inputs to the queries. 
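+                    # A single trainable query is used for the label modality (see the kwargs below).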
+ concat_preprocessed_input=False, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="trainable", + trainable_position_encoding_kwargs={ + "num_channels": config._label_trainable_num_channels, + "index_dims": 1, + }, + ), + }, + num_outputs=None, + output_num_channels=config.output_num_channels, + use_query_residual=False, + ) + + output_postprocessor = PerceiverMultimodalPostprocessor( + modalities={ + "audio": PerceiverAudioPostprocessor(config, in_channels=config.output_num_channels), + "image": PerceiverProjectionPostprocessor(in_channels=config.output_num_channels, out_channels=3), + "label": PerceiverClassificationPostprocessor(config, in_channels=config.output_num_channels), + } + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=input_preprocessor, + decoder=decoder, + output_postprocessor=output_postprocessor, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + subsampled_output_points: Optional[dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PerceiverClassifierOutput]: + r""" + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + subsampled_output_points (`dict[str, torch.Tensor]`, *optional*): + Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent + representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import PerceiverForMultimodalAutoencoding + >>> import torch + >>> import numpy as np + + >>> # create multimodal inputs + >>> images = torch.randn((1, 16, 3, 224, 224)) + >>> audio = torch.randn((1, 30720, 1)) + >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700))) + + >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver") + + >>> # in the Perceiver IO paper, videos are auto-encoded in chunks + >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries + >>> nchunks = 128 + >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks + >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks + >>> # process the first chunk + >>> chunk_idx = 0 + >>> subsampling = { + ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)), + ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)), + ... "label": None, + ... 
} + + >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling) + >>> logits = outputs.logits + >>> list(logits["audio"].shape) + [1, 240] + + >>> list(logits["image"].shape) + [1, 6272, 3] + + >>> list(logits["label"].shape) + [1, 700] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError("Multimodal autoencoding training is not yet supported") + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + subsampled_output_points=subsampled_output_points, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Below: position encodings + + +def build_position_encoding( + position_encoding_type, + out_channels=None, + project_pos_dim=-1, + trainable_position_encoding_kwargs=None, + fourier_position_encoding_kwargs=None, +): + """ + Builds the position encoding. + + Args: + - out_channels: refers to the number of channels of the position encodings. + - project_pos_dim: if specified, will project the position encodings to this dimension. + + """ + + if position_encoding_type == "trainable": + if not trainable_position_encoding_kwargs: + raise ValueError("Make sure to pass trainable_position_encoding_kwargs") + output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs) + elif position_encoding_type == "fourier": + # We don't use the index_dims argument, as this is only known during the forward pass + if not fourier_position_encoding_kwargs: + raise ValueError("Make sure to pass fourier_position_encoding_kwargs") + output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs) + else: + raise ValueError(f"Unknown position encoding type: {position_encoding_type}.") + + # Optionally, project the position encoding to a target dimension: + positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity() + + return output_pos_enc, positions_projection + + +# Below: Perceiver decoders + + +class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta): + """Perceiver abstract decoder.""" + + @abc.abstractmethod + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + raise NotImplementedError + + @property + @abc.abstractmethod + def num_query_channels(self): + raise NotImplementedError + + @abc.abstractmethod + def forward(self, query, z, query_mask=None): + raise NotImplementedError + + +class PerceiverProjectionDecoder(PerceiverAbstractDecoder): + """ + Baseline projection decoder (no cross-attention). + + Args: + config ([`PerceiverConfig`]): + Model configuration. 
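+
+    This decoder averages the final latents over their index dimension and applies a single linear layer to produce
+    `config.num_labels` logits; no cross-attention or decoder queries are used.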
+ """ + + def __init__(self, config): + super().__init__() + self.classifier = nn.Linear(config.d_latents, config.num_labels) + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return None + + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + # (batch_size, num_latents, d_latents) -> (batch_size, d_latents) + z = torch.mean(z, dim=1) + # (batch_size, d_latents) -> (batch_size, config.num_labels) + logits = self.classifier(z) + return logits + + +class PerceiverBasicDecoder(PerceiverAbstractDecoder): + """ + Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a + cross-attention operation, in which the latents produce keys and values. + + The shape of the output of this class depends on how one defines the output queries (also called decoder queries). + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_num_channels (`int`, *optional*): + The number of channels in the output. Will only be used in case *final_project* is set to `True`. + position_encoding_type (`str`, *optional*, defaults to "trainable"): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". + output_index_dims (`int`, *optional*): + The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'. + num_channels (`int`, *optional*, defaults to 128): + The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'. + qk_channels (`int`, *optional*): + The number of channels of the queries and keys in the cross-attention layer. + v_channels (`int`, *optional*): + The number of channels of the values in the cross-attention layer. + num_heads (`int`, *optional*, defaults to 1): + The number of attention heads in the cross-attention layer. + widening_factor (`int`, *optional*, defaults to 1): + The widening factor of the cross-attention layer. + use_query_residual (`bool`, *optional*, defaults to `False`): + Whether to use a residual connection between the query and the output of the cross-attention layer. + concat_preprocessed_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the preprocessed input to the query. + final_project (`bool`, *optional*, defaults to `True`): + Whether to project the output of the cross-attention layer to a target dimension. + position_encoding_only (`bool`, *optional*, defaults to `False`): + Whether to only use this class to define output queries. + """ + + def __init__( + self, + config: PerceiverConfig, + output_num_channels: int, + position_encoding_type: Optional[str] = "trainable", + # The following 2 arguments are ignored if position_encoding_type == 'none': + output_index_dims: Optional[int] = None, + num_channels: Optional[int] = 128, + subsampled_index_dims: Optional[int] = None, + qk_channels: Optional[int] = None, + v_channels: Optional[int] = None, + num_heads: Optional[int] = 1, + widening_factor: Optional[int] = 1, + use_query_residual: Optional[bool] = False, + concat_preprocessed_input: Optional[bool] = False, + final_project: Optional[bool] = True, + position_encoding_only: Optional[bool] = False, + **position_encoding_kwargs, + ) -> None: + super().__init__() + + self.output_num_channels = output_num_channels + # If `none`, the decoder will not construct any position encodings. + # You should construct your own when querying the decoder. 
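+        # Any additional keyword arguments are forwarded to `build_position_encoding` below.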
+ self.output_position_encodings = None + self.position_encoding_type = position_encoding_type + self.position_encoding_kwargs = position_encoding_kwargs + if position_encoding_type != "none": + self.output_position_encodings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, **position_encoding_kwargs + ) + + self.output_index_dims = output_index_dims + self.num_channels = num_channels + if subsampled_index_dims is None: + subsampled_index_dims = output_index_dims + self.subsampled_index_dims = subsampled_index_dims + self.concat_preprocessed_input = concat_preprocessed_input + self.final_project = final_project + self.position_encoding_only = position_encoding_only + + # for multimodal autoencoding, we don't need the decoder cross-attention and final layer + # so then we will set position_encoding_only to True + if not self.position_encoding_only: + self.decoding_cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=num_channels, + kv_dim=config.d_latents, + widening_factor=widening_factor, + use_query_residual=use_query_residual, + ) + self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity() + + @property + def num_query_channels(self) -> int: + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError( + "You cannot calculate number of decoder query channels when position_encoding_type is set to none" + ) + if self.position_encoding_only: + if "project_pos_dim" in self.position_encoding_kwargs: + return self.position_encoding_kwargs["project_pos_dim"] + return self.output_position_encodings.output_size() + if self.final_project: + return self.output_num_channels + return self.num_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none") + if subsampled_points is not None: + # subsampled_points are the indices if the inputs would be flattened + # however, the inputs aren't flattened, that's why we use unravel_index + # to get the indices for the unflattened array + # unravel_index returns a tuple (x_idx, y_idx, ...) + # stack to get the [n, d] tensor of coordinates + indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)] + pos = torch.stack(indices, dim=1) + batch_size = inputs.shape[0] + # Map these coordinates to [-1, 1] + pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :] + pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]]) + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_emb = self.output_position_encodings(batch_size) + elif self.position_encoding_type == "fourier": + pos_emb = self.output_position_encodings( + self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos + ) + + # Optionally project them to a target dimension. + pos_emb = self.positions_projection(pos_emb) + pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]]) + else: + batch_size = inputs.shape[0] + index_dims = inputs.shape[2:] + + # Construct the position encoding. 
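+            # Without subsampling, one decoder query is built per output position, covering the
+            # full index grid of the inputs (e.g. every pixel).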
+ if self.position_encoding_type == "trainable": + pos_emb = self.output_position_encodings(batch_size) + elif self.position_encoding_type == "fourier": + pos_emb = self.output_position_encodings( + index_dims, batch_size, device=inputs.device, dtype=inputs.dtype + ) + + # Optionally project them to a target dimension. + pos_emb = self.positions_projection(pos_emb) + + if self.concat_preprocessed_input: + if inputs_without_pos is None: + raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True") + pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1) + + return pos_emb + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + # Cross-attention decoding. + # key, value: B x N x K; query: B x M x K + # Attention maps -> B x N x M + # Output -> B x M x K + cross_attentions = () if output_attentions else None + + layer_outputs = self.decoding_cross_attention( + query, + attention_mask=query_mask, + head_mask=None, + inputs=z, + inputs_mask=None, + output_attentions=output_attentions, + ) + output = layer_outputs[0] + + if output_attentions: + cross_attentions = cross_attentions + (layer_outputs[1],) + + logits = self.final_layer(output) + + return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions) + + +class PerceiverClassificationDecoder(PerceiverAbstractDecoder): + """ + Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output. + Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of + shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels). + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config, **decoder_kwargs): + super().__init__() + + self.num_labels = config.num_labels + self.decoder = PerceiverBasicDecoder( + config, + output_num_channels=self.num_labels, + output_index_dims=1, # Predict a single logit array. 
+ **decoder_kwargs, + ) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return self.decoder.decoder_query( + inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points + ) + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + + # B x 1 x num_classes -> B x num_classes + logits = decoder_outputs.logits[:, 0, :] + + return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder): + """Cross-attention based optical flow decoder.""" + + def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs): + super().__init__() + + self.output_image_shape = output_image_shape + self.output_num_channels = output_num_channels + self.rescale_factor = rescale_factor + self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if subsampled_points is not None: + raise ValueError("FlowDecoder doesn't support subsampling yet.") + return inputs + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + preds = decoder_outputs.logits + # Output flow and rescale. + preds /= self.rescale_factor + preds = preds.reshape([preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]]) + return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): + """ + Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video + reshaping logic. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_shape (`list[int]`): + Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension. + position_encoding_type (`str`): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". 
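+
+    Example (illustrative shapes only): with `output_shape=[1, 16, 224, 224]` and `output_num_channels=3` passed via
+    `decoder_kwargs`, the logits returned by `forward` are reshaped to `(1, 16, 224, 224, 3)`.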
+ """ + + def __init__( + self, config: PerceiverConfig, output_shape: list[int], position_encoding_type: str, **decoder_kwargs + ) -> None: + super().__init__() + if len(output_shape) != 4: # B, T, H, W + raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.") + # Build the decoder components: + self.output_shape = output_shape + self.output_num_channels = decoder_kwargs["output_num_channels"] + + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=self.output_shape[1:4], # T*H*W + position_encoding_type=position_encoding_type, + **decoder_kwargs, + ) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return self.decoder.decoder_query( + inputs, + modality_sizes=modality_sizes, + inputs_without_pos=inputs_without_pos, + subsampled_points=subsampled_points, + ) + + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z) + logits = decoder_outputs.logits + + logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]]) + return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions) + + +def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]: + """ + Partitions a [B, N, C] tensor into tensors for each modality. + + Args: + modality_sizes + dict specifying the size of the modality + inputs: + input tensor + + Returns: + dict mapping name of modality to its associated tensor. + """ + outputs = {} + index = 0 + # Apply a predictable ordering to the modalities + for modality in sorted(modality_sizes.keys()): + size = modality_sizes[modality] + inp = inputs[:, index : index + size] + index += size + outputs[modality] = inp + return outputs + + +class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): + """ + Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary + mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that + modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are + concatenated along the time dimension. + + Next, there is a shared cross attention operation across all modalities. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + modalities (`dict[str, PerceiverAbstractDecoder]`): + Dictionary mapping modality name to the decoder of that modality. + num_outputs (`int`): + The number of outputs of the decoder. + output_num_channels (`int`): + The number of channels in the output. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. + subsampled_index_dims (`dict[str, PerceiverAbstractDecoder]`, *optional*): + Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that + modality. 
+ """ + + def __init__( + self, + config: PerceiverConfig, + modalities: dict[str, PerceiverAbstractDecoder], + num_outputs: int, + output_num_channels: int, + min_padding_size: Optional[int] = 2, + subsampled_index_dims: Optional[dict[str, PerceiverAbstractDecoder]] = None, + **decoder_kwargs, + ) -> None: + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.subsampled_index_dims = subsampled_index_dims + self.min_padding_size = min_padding_size + self.output_num_channels = output_num_channels + self.num_outputs = num_outputs + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=(num_outputs,), + output_num_channels=output_num_channels, + position_encoding_type="none", + num_channels=self.num_query_channels, + **decoder_kwargs, + ) + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn(1, self.num_query_channels - decoder.num_query_channels)) + for modality, decoder in modalities.items() + } + ) + + @property + def num_query_channels(self) -> int: + max_channel_size = max(decoder.num_query_channels for _, decoder in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None): + # Partition the flat inputs among the different modalities + inputs = restructure(modality_sizes, inputs) + + # Obtain modality-specific decoders' queries + subsampled_points = subsampled_points or {} + + decoder_queries = {} + for modality, decoder in self.modalities.items(): + # Get input_without_pos for this modality if it exists. + input_without_pos = None + if inputs_without_pos is not None: + input_without_pos = inputs_without_pos.get(modality, None) + query = decoder.decoder_query( + inputs=inputs[modality], + modality_sizes=None, + inputs_without_pos=input_without_pos, + subsampled_points=subsampled_points.get(modality, None), + ) + decoder_queries[modality] = query + + # Pad all queries with trainable position encodings to make them have the same channels + + def embed(modality, x): + x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]]) + pos = self.padding[modality] + pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]]) + return torch.cat([x, pos], dim=2) + + # Apply a predictable ordering to the modalities + return torch.cat( + [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1 + ) + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + # B x 1 x num_classes -> B x num_classes + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + + return decoder_outputs + + +# Below: IO pre- and post-processor classes for Perceiver. +def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor: + """ + Space to depth transform. Rearranges blocks of spatial data, into depth. + + This function assumes the channels to be first, but will place the channels last after transformation. + + Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15. 
+ """ + if len(frames.shape) == 4: + batch_size, num_channels, height, width = frames.shape + # split up dimensions (height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C) + frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous() + # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C) + frames = frames.view( + batch_size, + height // spatial_block_size, + width // spatial_block_size, + (spatial_block_size**2) * num_channels, + ) + return frames + elif len(frames.shape) == 5: + batch_size, time, num_channels, height, width = frames.shape + # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + time // temporal_block_size, + temporal_block_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C) + frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous() + # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C) + frames = frames.view( + batch_size, + time // temporal_block_size, + height // spatial_block_size, + width // spatial_block_size, + temporal_block_size * (spatial_block_size**2) * num_channels, + ) + return frames + else: + raise ValueError( + "Frames should be of rank 4 (batch, channels, height, width)" + " or rank 5 (batch, time, channels, height, width)" + ) + + +class Conv2dSamePadding(nn.Conv2d): + """ + Conv2d layer with padding="same" support. Source: + https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.zero_pad_2d = nn.ZeroPad2d( + reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]]) + ) + + def forward(self, input): + return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias) + + +class Conv2DDownsample(nn.Module): + """Downsamples 4x by applying a 2D convolution and doing max pooling.""" + + def __init__( + self, + num_layers: int = 1, + in_channels: int = 3, + out_channels: int = 64, + use_batchnorm: bool = True, + ): + """ + Constructs a Conv2DDownsample model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 64): + The number of conv output channels. + use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batchnorm. + """ + super().__init__() + + self.conv = Conv2dSamePadding( + in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False + ) + self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity() + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + out = self.conv(inputs) + out = self.batchnorm(out) + out = self.relu(out) + out = self.max_pool(out) + return out + + +def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False): + """ + Generate a Fourier frequency position encoding with linear spacing. 
+ + Args: + pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`): + The Tensor containing the position of n points in d dimensional space. + num_bands (`int`): + The number of frequency bands (K) to use. + max_resolution (`tuple[int]`, *optional*, defaults to (224, 224)): + The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension. + concat_pos (`bool`, *optional*, defaults to `True`): + Whether to concatenate the input position encoding to the Fourier features. + sine_only (`bool`, *optional*, defaults to `False`): + Whether to use a single phase (sin) or two (sin/cos) for each frequency band. + + Returns: + `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If + `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d, + sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1), + ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the + kth frequency band. + """ + + batch_size = pos.shape[0] + + min_freq = 1.0 + # Nyquist frequency at the target resolution: + freq_bands = torch.stack( + [torch.linspace(start=min_freq, end=res / 2, steps=num_bands) for res in max_resolution], dim=0 + ) + + # Get frequency bands for each spatial dimension. + # Output is size [n, d * num_bands] + per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :] + per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])]) + + if sine_only: + # Output is size [n, d * num_bands] + per_pos_features = torch.sin(np.pi * (per_pos_features)) + else: + # Output is size [n, 2 * d * num_bands] + per_pos_features = torch.cat( + [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1 + ) + # Concatenate the raw input positions. + if concat_pos: + # Adds d bands to the encoding. + per_pos_features = torch.cat([pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1) + return per_pos_features + + +def build_linear_positions(index_dims, output_range=(-1.0, 1.0)): + """ + Generate an array of position indices for an N-D input array. + + Args: + index_dims (`list[int]`): + The shape of the index dimensions of the input array. + output_range (`tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`): + The min and max values taken by each input index dimension. + + Returns: + `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`. 
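+
+    Example (shape only):
+
+    ```python
+    >>> build_linear_positions([4, 4]).shape  # a 4x4 grid of (x, y) coordinates in [-1, 1]
+    torch.Size([4, 4, 2])
+    ```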
+ """ + + def _linspace(n_xels_per_dim): + return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32) + + dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims] + array_index_grid = meshgrid(*dim_ranges, indexing="ij") + + return torch.stack(array_index_grid, dim=-1) + + +class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta): + """Perceiver abstract position encoding.""" + + @property + @abc.abstractmethod + def num_dimensions(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + def output_size(self, *args, **kwargs) -> int: + raise NotImplementedError + + @abc.abstractmethod + def forward(self, batch_size, pos): + raise NotImplementedError + + +class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): + """Trainable position encoding.""" + + def __init__(self, index_dims, num_channels=128): + super().__init__() + self._num_channels = num_channels + self._index_dims = index_dims + index_dim = np.prod(index_dims) + self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels)) + + @property + def num_dimensions(self) -> int: + if isinstance(self._index_dims, int): + return 1 + return len(self._index_dims) + + def output_size(self, *args, **kwargs) -> int: + return self._num_channels + + def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + num_positions = position_embeddings.shape[0] + new_height = new_width = torch_int(num_positions**0.5) + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and height == new_height and width == new_width: + return position_embeddings + + position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute( + 0, 3, 1, 2 + ) + + position_embeddings = nn.functional.interpolate( + position_embeddings, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0) + return position_embeddings + + def forward( + self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: Optional[torch.Size] = None + ) -> torch.Tensor: + position_embeddings = self.position_embeddings + + if interpolate_pos_encoding: + height, width = input_size + position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width) + + if batch_size is not None: + position_embeddings = position_embeddings.expand(batch_size, -1, -1) + return position_embeddings + + +def _check_or_build_spatial_positions(pos, index_dims, batch_size): + """ + Checks or builds spatial position features (x, y, ...). + + Args: + pos (`torch.FloatTensor`): + None, or an array of position features. If None, position features are built. Otherwise, their size is checked. + index_dims (`list[int]`): + An iterable giving the spatial/index size of the data to be featurized. + batch_size (`int`): + The batch size of the data to be featurized. + + Returns: + `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features. 
+ """ + if pos is None: + pos = build_linear_positions(index_dims) + # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)` + # but `torch.broadcast_to` cannot be converted to ONNX + pos = pos[None].expand((batch_size,) + pos.shape) + pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1]) + else: + # Just a warning label: you probably don't want your spatial features to + # have a different spatial layout than your pos coordinate system. + # But feel free to override if you think it'll work! + if pos.shape[-1] != len(index_dims): + raise ValueError("Spatial features have the wrong number of dimensions.") + return pos + + +class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding): + """Fourier (Sinusoidal) position encoding.""" + + def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False): + super().__init__() + self.num_bands = num_bands + self.max_resolution = max_resolution + self.concat_pos = concat_pos + self.sine_only = sine_only + + @property + def num_dimensions(self) -> int: + return len(self.max_resolution) + + def output_size(self): + """Returns size of positional encodings last dimension.""" + num_dims = len(self.max_resolution) + encoding_size = self.num_bands * num_dims + if not self.sine_only: + encoding_size *= 2 + if self.concat_pos: + encoding_size += self.num_dimensions + + return encoding_size + + def forward( + self, + index_dims: list[int], + batch_size: int, + device: torch.device, + dtype: torch.dtype, + pos: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) + fourier_pos_enc = generate_fourier_features( + pos, + num_bands=self.num_bands, + max_resolution=self.max_resolution, + concat_pos=self.concat_pos, + sine_only=self.sine_only, + ).to(device=device, dtype=dtype) + return fourier_pos_enc + + +class AbstractPreprocessor(nn.Module): + @property + def num_channels(self) -> int: + """Returns size of preprocessor output.""" + raise NotImplementedError() + + +class PerceiverTextPreprocessor(AbstractPreprocessor): + """ + Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings. + + The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration. + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config: PerceiverConfig) -> None: + super().__init__() + self.config = config + self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) + + @property + def num_channels(self) -> int: + return self.config.d_model + + def forward( + self, + inputs: torch.LongTensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + embeddings_without_pos = self.embeddings(inputs) + + seq_length = inputs.shape[1] + position_ids = torch.arange(0, seq_length, device=inputs.device) + embeddings = embeddings_without_pos + self.position_embeddings(position_ids) + + return embeddings, None, embeddings_without_pos + + +class PerceiverEmbeddingDecoder(nn.Module): + """ + Module to decode embeddings (for masked language modeling). + + Args: + config ([`PerceiverConfig`]): + Model configuration. 
+ """ + + def __init__(self, config: PerceiverConfig) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.bias = nn.Parameter(torch.zeros(self.vocab_size)) + + def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, d_model = hidden_states.shape + # Flatten batch dim + output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1)) + output = output + self.bias + + return output.reshape([batch_size, seq_len, self.vocab_size]) + + +class PerceiverMultimodalPostprocessor(nn.Module): + """ + Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single + postprocessor. + + Args: + modalities (`Mapping[str, PostprocessorType]`): + Dictionary mapping modality name to postprocessor class for that modality. + input_is_dict (`bool`, *optional*, defaults to `False`): + If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If + False, input is a tensor which is sliced up during postprocessing by *modality_sizes*. + """ + + def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.input_is_dict = input_is_dict + + def forward( + self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None + ) -> Mapping[str, torch.Tensor]: + if not self.input_is_dict: + # Slice up modalities by their sizes. + if modality_sizes is None: + raise ValueError("Modality sizes should be specified if input is not a dictionary.") + inputs = restructure(modality_sizes=modality_sizes, inputs=inputs) + + outputs = { + modality: postprocessor(inputs[modality], pos=pos, modality_sizes=None) + for modality, postprocessor in self.modalities.items() + } + return outputs + + +class PerceiverClassificationPostprocessor(nn.Module): + """ + Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + """ + + def __init__(self, config: PerceiverConfig, in_channels: int) -> None: + super().__init__() + self.classifier = nn.Linear(in_channels, config.num_labels) + + def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits[:, 0, :] + + +class PerceiverAudioPostprocessor(nn.Module): + """ + Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + postproc_type (`str`, *optional*, defaults to `"patches"`): + Postprocessor type to use. Currently, only "patches" is supported. 
+ """ + + def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None: + super().__init__() + + if postproc_type != "patches": # to be supported: 'conv', 'patches', 'pixels' + raise ValueError("Invalid postproc_type!") + + # Architecture parameters: + self.classifier = nn.Linear(in_channels, config.samples_per_patch) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return torch.reshape(logits, [inputs.shape[0], -1]) + + +class PerceiverProjectionPostprocessor(nn.Module): + """ + Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower + dimension. + + Args: + in_channels (`int`): + Number of channels in the input. + out_channels (`int`): + Number of channels in the output. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.classifier = nn.Linear(in_channels, out_channels) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits + + +class PerceiverImagePreprocessor(AbstractPreprocessor): + """ + Image preprocessing for Perceiver Encoder. + + Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to + "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the + position encoding kwargs are set equal to the *out_channels*. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + prep_type (`str`, *optional*, defaults to `"conv"`): + Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels". + spatial_downsample (`int`, *optional*, defaults to 4): + Spatial downsampling factor. + temporal_downsample (`int`, *optional*, defaults to 1): + Temporal downsampling factor (only relevant in case a time dimension is present). + position_encoding_type (`str`, *optional*, defaults to `"fourier"`): + Position encoding type. Can be "fourier" or "trainable". + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input. + out_channels (`int`, *optional*, defaults to 64): + Number of channels in the output. + conv_after_patching (`bool`, *optional*, defaults to `False`): + Whether to apply a convolutional layer after patching. + conv_after_patching_in_channels (`int`, *optional*, defaults to 54): + Number of channels in the input of the convolutional layer after patching. + conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batch normalization in the convolutional layer. + concat_or_add_pos (`str`, *optional*, defaults to `"concat"`): + How to concatenate the position encoding to the input. Can be "concat" or "add". + project_pos_dim (`int`, *optional*, defaults to -1): + Dimension of the position encoding to project to. If -1, no projection is applied. + **position_encoding_kwargs (`Dict`, *optional*): + Keyword arguments for the position encoding. 
+ """ + + def __init__( + self, + config, + prep_type="conv", + spatial_downsample: int = 4, + temporal_downsample: int = 1, + position_encoding_type: str = "fourier", + in_channels: int = 3, + out_channels: int = 64, + conv_after_patching: bool = False, + conv_after_patching_in_channels: int = 54, # only relevant when conv_after_patching = True + conv2d_use_batchnorm: bool = True, + concat_or_add_pos: str = "concat", + project_pos_dim: int = -1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type not in ("conv", "patches", "pixels", "conv1x1"): + raise ValueError(f"Prep_type {prep_type} is invalid") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.") + + self.in_channels = in_channels + self.prep_type = prep_type + self.spatial_downsample = spatial_downsample + self.temporal_downsample = temporal_downsample + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.conv_after_patching = conv_after_patching + self.out_channels = out_channels + + if self.prep_type == "conv": + # Downsampling with conv is currently restricted + convnet_num_layers = math.log(spatial_downsample, 4) + convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers) + if not convnet_num_layers_is_int or temporal_downsample != 1: + raise ValueError( + "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv." + ) + self.convnet = Conv2DDownsample( + in_channels=in_channels, + num_layers=int(convnet_num_layers), + out_channels=out_channels, + use_batchnorm=conv2d_use_batchnorm, + ) + + elif self.prep_type == "conv1x1": + if temporal_downsample != 1: + raise ValueError("Conv1x1 does not downsample in time.") + self.convnet_1x1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + # spatial_downsample is unconstrained for 1x1 convolutions. + stride=(spatial_downsample, spatial_downsample), + ) + + # Position embeddings + self.project_pos_dim = project_pos_dim + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + # Optional convolutional layer after patches. + self.conv_after_patches = ( + nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity() + ) + + @property + def num_channels(self) -> int: + # Let's assume that the number of resolutions (in the context of image preprocessing) + # of the input data is 2 or 3 depending on whether we are processing image or video respectively. + # In this case, for convenience, we will declare is_temporal variable, + # which will show whether the data has a temporal dimension or not. 
+ is_temporal = self.position_embeddings.num_dimensions > 2 + + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + + # inputs + if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"): + inp_dim = self.out_channels + elif self.prep_type == "pixels": + inp_dim = self.in_channels + if not is_temporal: + inp_dim = math.ceil(inp_dim / self.spatial_downsample) + elif self.prep_type == "patches": + if self.conv_after_patching: + inp_dim = self.out_channels + else: + inp_dim = self.in_channels * self.spatial_downsample**2 + if is_temporal: + inp_dim *= self.temporal_downsample + + return inp_dim + pos_dim + + def _build_network_inputs( + self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False + ): + """ + Construct the final input, including position encoding. + + This method expects the inputs to always have channels as last dimension. + + """ + batch_size = inputs.shape[0] + input_size = inputs.shape[1:3] + index_dims = inputs.shape[1:-1] + indices = np.prod(index_dims) + + # Flatten input features to a 1D index dimension if necessary. + if len(inputs.shape) > 3 and network_input_is_1d: + inputs = torch.reshape(inputs, [batch_size, indices, -1]) + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if not network_input_is_1d: + # Reshape pos to match the input feature shape + # if the network takes non-1D inputs + sh = inputs.shape + pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1]) + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + return inputs_with_pos, inputs + + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + if self.prep_type == "conv": + # Convnet image featurization. + # Downsamples spatially by a factor of 4 + inputs = self.convnet(inputs) + + elif self.prep_type == "conv1x1": + # map inputs to self.out_channels + inputs = self.convnet_1x1(inputs) + + elif self.prep_type == "pixels": + # if requested, downsamples in the crudest way + if inputs.ndim == 4: + inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample] + elif inputs.ndim == 5: + inputs = inputs[ + :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample + ] + else: + raise ValueError("Unsupported data format for pixels.") + + elif self.prep_type == "patches": + # Space2depth featurization. + # Video: B x T x C x H x W + inputs = space_to_depth( + inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample + ) + + if inputs.ndim == 5 and inputs.shape[1] == 1: + # for flow + inputs = inputs.squeeze(dim=1) + + # Optionally apply conv layer. 
+ inputs = self.conv_after_patches(inputs) + + if self.prep_type != "patches": + # move channels to last dimension, as the _build_network_inputs method below expects this + if inputs.ndim == 4: + inputs = inputs.permute(0, 2, 3, 1) + elif inputs.ndim == 5: + inputs = inputs.permute(0, 1, 3, 4, 2) + else: + raise ValueError("Unsupported data format for conv1x1.") + + inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding) + modality_sizes = None # Size for each modality, only needed for multimodal + + return inputs, modality_sizes, inputs_without_pos + + +class PerceiverOneHotPreprocessor(AbstractPreprocessor): + """ + One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input. + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config: PerceiverConfig) -> None: + super().__init__() + self.config: PerceiverConfig = config + + @property + def num_channels(self) -> int: + return self.config.num_labels + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + # Add a dummy index dimension. + inputs = inputs[:, None, :] + + # No position encodings, so the 1st (input) and 3rd (inputs_without_pos) + # outputs are identical. + return inputs, None, inputs + + +class PerceiverAudioPreprocessor(AbstractPreprocessor): + """ + Audio preprocessing for Perceiver Encoder. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + prep_type (`str`, *optional*, defaults to `"patches"`): + Preprocessor type to use. Only "patches" is supported. + samples_per_patch (`int`, *optional*, defaults to 96): + Number of samples per patch. + position_encoding_type (`str`, *optional*, defaults to `"fourier"`): + Type of position encoding to use. Can be "trainable" or "fourier". + concat_or_add_pos (`str`, *optional*, defaults to `"concat"`): + How to concatenate the position encoding to the input. Can be "concat" or "add". + out_channels (`int`, *optional*, defaults to 64): + Number of channels in the output. + project_pos_dim (`int`, *optional*, defaults to -1): + Dimension of the position encoding to project to. If -1, no projection is applied. + **position_encoding_kwargs (`Dict`, *optional*): + Keyword arguments for the position encoding. 
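+
+    Example (a minimal sketch; assumes `config` is a `PerceiverConfig` and the same `fourier_position_encoding_kwargs`
+    routing as for the image preprocessor):
+
+    ```python
+    >>> preprocessor = PerceiverAudioPreprocessor(
+    ...     config,
+    ...     samples_per_patch=16,
+    ...     fourier_position_encoding_kwargs={"num_bands": 192, "max_resolution": (1600,), "concat_pos": True, "sine_only": False},
+    ... )
+    >>> # raw audio of shape (batch, num_samples, 1) is reshaped to (batch, num_samples // 16, 16)
+    >>> # before the position encodings are concatenated
+    ```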
+ """ + + def __init__( + self, + config, + prep_type: str = "patches", + samples_per_patch: int = 96, + position_encoding_type: str = "fourier", + concat_or_add_pos: str = "concat", + out_channels=64, + project_pos_dim=-1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type != "patches": + raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.") + + self.samples_per_patch = samples_per_patch + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.project_pos_dim = project_pos_dim + + # Position embeddings + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + @property + def num_channels(self) -> int: + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + return self.samples_per_patch + pos_dim + + def _build_network_inputs(self, inputs): + """Construct the final input, including position encoding.""" + batch_size = inputs.shape[0] + index_dims = inputs.shape[1:-1] + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + + return inputs_with_pos, inputs + + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch]) + + inputs, inputs_without_pos = self._build_network_inputs(inputs) + modality_sizes = None # Size for each modality, only needed for multimodal + + return inputs, modality_sizes, inputs_without_pos + + +class PerceiverMultimodalPreprocessor(AbstractPreprocessor): + """ + Multimodal preprocessing for Perceiver Encoder. + + Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number + of channels. + + Args: + modalities (`Mapping[str, PreprocessorType]`): + Dict mapping modality name to preprocessor. + mask_probs (`dict[str, float]`): + Dict mapping modality name to masking probability of that modality. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. 
+ """ + + def __init__( + self, + modalities: Mapping[str, PreprocessorType], + mask_probs: Optional[Mapping[str, float]] = None, + min_padding_size: int = 2, + ): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.min_padding_size = min_padding_size + self.mask_probs = mask_probs if mask_probs is not None else {} + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels)) + for modality, preprocessor in modalities.items() + } + ) + self.mask = nn.ParameterDict( + {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()} + ) + + @property + def num_channels(self) -> int: + max_channel_size = max(processor.num_channels for _, processor in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def forward( + self, + inputs: Mapping[str, torch.Tensor], + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ) -> PreprocessorOutputType: + padded = {} + modality_sizes = {} + inputs_without_pos = {} + for modality, preprocessor in self.modalities.items(): + # preprocess each modality using the respective preprocessor. + output, _, inputs_without_pos[modality] = preprocessor( + inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d + ) + + # pad to the same common_channel_size. + batch_size, num_samples, num_channels = output.shape + pos_enc = self.padding[modality].expand(batch_size, -1, -1) + + padding = torch.broadcast_to( + pos_enc, + [batch_size, num_samples, self.num_channels - num_channels], + ) + output_padded = torch.cat([output, padding], dim=2) + + # mask if required + if modality in self.mask_probs: + mask_token = self.mask[modality].expand(batch_size, -1, -1) + mask_prob = self.mask_probs[modality] + mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob)) + mask = torch.unsqueeze(mask, dim=2).to(mask_token.device) + output_padded = (1 - mask) * output_padded + mask * mask_token + + padded[modality] = output_padded + modality_sizes[modality] = output_padded.shape[1] + + # Apply a predictable ordering to the modalities + padded_ls = [padded[k] for k in sorted(padded.keys())] + + # Finally, concatenate along the time dimension + final_inputs = torch.cat(padded_ls, dim=1) + + return final_inputs, modality_sizes, inputs_without_pos + + +__all__ = [ + "PerceiverForImageClassificationConvProcessing", + "PerceiverForImageClassificationFourier", + "PerceiverForImageClassificationLearned", + "PerceiverForMaskedLM", + "PerceiverForMultimodalAutoencoding", + "PerceiverForOpticalFlow", + "PerceiverForSequenceClassification", + "PerceiverLayer", + "PerceiverModel", + "PerceiverPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/perceiver/tokenization_perceiver.py b/venv/lib/python3.13/site-packages/transformers/models/perceiver/tokenization_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..f17e7e99ac9deb2712838efbee71bb4f21f37e65 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/perceiver/tokenization_perceiver.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Perceiver.""" + +from typing import Optional + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PerceiverTokenizer(PreTrainedTokenizer): + """ + Construct a Perceiver tokenizer. The Perceiver simply uses raw bytes utf-8 encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + bos_token (`str`, *optional*, defaults to `"[BOS]"`): + The BOS token (reserved in the vocab, but not actually used). + eos_token (`str`, *optional*, defaults to `"[EOS]"`): + The end of sequence token (reserved in the vocab, but not actually used). + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The MASK token, useful for masked language modeling. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The CLS token (reserved in the vocab, but not actually used). + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from two sequences. 
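+
+    Example (a minimal sketch of the byte-level encoding; each id is the utf-8 byte value shifted by the 6 special
+    tokens):
+
+    ```python
+    >>> tokenizer = PerceiverTokenizer()
+    >>> tokenizer("hello").input_ids  # [CLS]=4, then bytes of "hello" + 6, then [SEP]=5
+    [4, 110, 107, 114, 114, 117, 5]
+    ```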
+ + """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + pad_token="[PAD]", + bos_token="[BOS]", + eos_token="[EOS]", + mask_token="[MASK]", + cls_token="[CLS]", + sep_token="[SEP]", + model_max_length=2048, + **kwargs, + ) -> None: + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + mask_token = AddedToken(mask_token, lstrip=False, rstrip=False) if isinstance(mask_token, str) else mask_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + + self._utf_vocab_size = 2**8 # utf is 8 bits + + # Since these tokens are not part of the vocabulary, we manually add them + self._added_tokens_decoder: dict[str, int] = { + 0: pad_token, + 1: bos_token, + 2: eos_token, + 3: mask_token, + 4: cls_token, + 5: sep_token, + } + self._num_special_tokens = len(self._added_tokens_decoder) + super().__init__( + pad_token=pad_token, + bos_token=bos_token, + eos_token=eos_token, + mask_token=mask_token, + cls_token=cls_token, + sep_token=sep_token, + model_max_length=model_max_length, + **kwargs, + ) + + def get_vocab(self) -> dict[str, int]: + vocab = {} + for i in range(self._utf_vocab_size): + token = chr(i) + vocab[token] = i + self._num_special_tokens + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def vocab_size(self): + return self._utf_vocab_size + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return [1] + [0] * len(token_ids_0) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks. A sequence has the + following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + else: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id] + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + tokens = [chr(i) for i in text.encode("utf-8")] + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if len(token) != 1: + token_id = self.unk_token_id + else: + token_id = ord(token) + self._num_special_tokens + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = chr(index - self._num_special_tokens) + return token + + # TODO @ArthurZ refactor this as well.... + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + for token in tokens: + if token in self.added_tokens_encoder: + tok_string = str(token).encode("utf-8") + else: + tok_string = bytes([ord(token)]) + bstring += tok_string + string = bstring.decode("utf-8", errors="replace") + return string + + # PerceiverTokenizer has no vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + return () + + +__all__ = ["PerceiverTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/phi3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/phi3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb1e7a9cd04fb32cb8eb29516d95c0fc8e9d108 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/phi3/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phi3 import * + from .modeling_phi3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/phi3/configuration_phi3.py b/venv/lib/python3.13/site-packages/transformers/models/phi3/configuration_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..33cee6b37ba57ed4c1010f78646274a09489e33d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/phi3/configuration_phi3.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Phi-3 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Phi3Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the
+    [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32064):
+            Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Phi3Model`].
+        hidden_size (`int`, *optional*, defaults to 3072):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for mlp outputs.
+        embd_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model was trained with. This is used to determine the size of the
+            original RoPE embeddings when using long scaling.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to 1.0): + Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1.0, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + 
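+ # Illustration (comment added for clarity): with the default num_attention_heads=32,
+ # leaving num_key_value_heads unset resolves to 32 key/value heads (standard multi-head
+ # attention); setting it to 1 gives multi-query attention, and e.g. 8 gives grouped-query
+ # attention in which every 32 // 8 = 4 query heads share one key/value head.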
self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type != "longrope": + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) + if not len(rope_scaling_short_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" + ) + + +__all__ = ["Phi3Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/phi3/modeling_phi3.py b/venv/lib/python3.13/site-packages/transformers/models/phi3/modeling_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..cf803b3e2130005ef24fd2b0d87be6d60a4bf71e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/phi3/modeling_phi3.py @@ -0,0 +1,547 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi3/modular_phi3.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from transformers.utils.generic import check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from .configuration_phi3 import Phi3Config + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
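+
+ Example (an illustrative shape sketch; the dimensions below are hypothetical and cover the
+ partial-rotary case where `cos.shape[-1]` is smaller than the head dimension):
+
+ ```python
+ >>> import torch
+ >>> q = torch.randn(1, 32, 10, 96) # [batch_size, heads, seq_len, head_dim]
+ >>> k = torch.randn(1, 32, 10, 96)
+ >>> cos = torch.randn(1, 10, 64) # rotary_dim = 64, so the last 32 dims pass through unrotated
+ >>> sin = torch.randn(1, 10, 64)
+ >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
+ >>> q_embed.shape
+ torch.Size([1, 32, 10, 96])
+ ```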
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + 
Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + return hidden_states + + +@auto_docstring +class Phi3PreTrainedModel(PreTrainedModel): + config: Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Phi3DecoderLayer, + "attentions": Phi3Attention, + } + _version = "0.0.5" + + +class Phi3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Phi3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + 
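+ # For the "default" rope_type, the rope_init_fn selected just below returns the standard
+ # RoPE inverse frequencies, inv_freq[i] = 1 / (rope_theta ** (2 * i / rotary_dim)) for
+ # i in [0, rotary_dim // 2), with attention_scaling == 1.0; for "longrope" the frequencies
+ # are additionally rescaled by the configured short/long factors (a simplification -- see
+ # ROPE_INIT_FUNCTIONS for the exact formulas).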
self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Phi3Model(Phi3PreTrainedModel): + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + 
**kwargs, + ) + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the input length first crosses the long/short RoPE factor switching point, force the cache to be + # recomputed. This makes that single token position slower, but it is better than failing outright. 
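+ # Worked example (comment added for clarity) with the default
+ # original_max_position_embeddings=4096: the first call whose total input length reaches
+ # 4097 tokens while the cache still holds <= 4096 past positions drops the cache, so the
+ # whole prompt is re-encoded with the long RoPE factors instead of mixing short-factor
+ # and long-factor positions.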
+ if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +class Phi3ForSequenceClassification(GenericForSequenceClassification, Phi3PreTrainedModel): + pass + + +class Phi3ForTokenClassification(GenericForTokenClassification, Phi3PreTrainedModel): + pass + + +__all__ = [ + "Phi3PreTrainedModel", + "Phi3Model", + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/phi3/modular_phi3.py b/venv/lib/python3.13/site-packages/transformers/models/phi3/modular_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..d355c3792a6b056165d6acbd0db8677bb9784728 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/phi3/modular_phi3.py @@ -0,0 +1,272 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PyTorch Phi-3 model.""" + +from typing import Callable, Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ...utils.deprecation import deprecate_kwarg +from ..mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralPreTrainedModel, + eager_attention_forward, + rotate_half, +) +from .configuration_phi3 import Phi3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi3DecoderLayer(MistralDecoderLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + 
super().__init__(config, layer_idx) + self.config = config + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + return hidden_states + + +class Phi3PreTrainedModel(MistralPreTrainedModel): + _version = "0.0.5" + + +class Phi3ForCausalLM(MistralForCausalLM): + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the input length first crosses the long/short RoPE factor switching point, force the cache to be + # recomputed. This makes that single token position slower, but it is better than failing outright. 
+ if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = GenerationMixin.prepare_inputs_for_generation( + self, + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +class Phi3ForSequenceClassification(MistralForSequenceClassification): + pass + + +class Phi3ForTokenClassification(MistralForTokenClassification): + pass + + +__all__ = [ + "Phi3PreTrainedModel", + "Phi3Model", # noqa: F822 + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pix2struct/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa645dff0494dd08019e5a5a3e0069abc9f63c88 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pix2struct import * + from .image_processing_pix2struct import * + from .modeling_pix2struct import * + from .processing_pix2struct import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/pix2struct/configuration_pix2struct.py b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/configuration_pix2struct.py new file mode 100644 index 0000000000000000000000000000000000000000..89109350d95a0c6e8d65c2611d75b1c4c29b7211 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/configuration_pix2struct.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Pix2Struct model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Pix2StructTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate + a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by + the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50244): + Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`Pix2StructTextModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Dimensionality of the key, query, value projections in each attention head. + d_ff (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + dense_act_fn (`Union[Callable, str]`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string). + decoder_start_token_id (`int`, *optional*, defaults to 0): + The id of the `decoder_start_token_id` token. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the `padding` token. + eos_token_id (`int`, *optional*, defaults to 1): + The id of the `end-of-sequence` token. 
+ + Example: + + ```python + >>> from transformers import Pix2StructTextConfig, Pix2StructTextModel + + >>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration + >>> configuration = Pix2StructTextConfig() + + >>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pix2struct_text_model" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "hidden_size", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "decoder_attention_heads": "num_heads", + "encoder_attention_heads": "num_heads", + "encoder_layers": "num_layers", + "decoder_layers": "num_layers", + } + + def __init__( + self, + vocab_size=50244, + hidden_size=768, + d_kv=64, + d_ff=2048, + num_layers=12, + num_heads=12, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + dense_act_fn="gelu_new", + decoder_start_token_id=0, + use_cache=False, + pad_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + is_decoder=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.use_cache = use_cache + + self.eos_token_id = eos_token_id + self.decoder_start_token_id = decoder_start_token_id + + # for backwards compatibility + self.dense_act_fn = dense_act_fn + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + tie_word_embeddings=tie_word_embeddings, + is_decoder=is_decoder, + **kwargs, + ) + + +class Pix2StructVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to + instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct-base + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + patch_embed_hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the input patch_embedding layer in the Transformer encoder. + d_ff (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + d_kv (`int`, *optional*, defaults to 64): + Dimensionality of the key, query, value projections per attention head. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + dense_act_fn (`str` or `function`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + seq_len (`int`, *optional*, defaults to 4096): + Maximum sequence length (here number of patches) supported by the model. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance (in tokens) to use for each attention layer. + + Example: + + ```python + >>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel + + >>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration + >>> configuration = Pix2StructVisionConfig() + + >>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pix2struct_vision_model" + + def __init__( + self, + hidden_size=768, + patch_embed_hidden_size=768, + d_ff=2048, + d_kv=64, + num_hidden_layers=12, + num_attention_heads=12, + dense_act_fn="gelu_new", + layer_norm_eps=1e-6, + dropout_rate=0.0, + attention_dropout=0.0, + initializer_range=1e-10, + initializer_factor=1.0, + seq_len=4096, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.patch_embed_hidden_size = patch_embed_hidden_size + self.d_ff = d_ff + self.dropout_rate = dropout_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.dense_act_fn = dense_act_fn + self.seq_len = seq_len + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.d_kv = d_kv + + +class Pix2StructConfig(PretrainedConfig): + r""" + [`Pix2StructConfig`] is the configuration class to store the configuration of a + [`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified + arguments, defining the text model and vision model configs. 
Instantiating a configuration with the defaults will + yield a similar configuration to that of the Pix2Struct-base + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Pix2StructTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Pix2StructVisionConfig`]. + initializer_factor (`float`, *optional*, defaults to 1.0): + Factor to multiply the initialization range with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + is_vqa (`bool`, *optional*, defaults to `False`): + Whether the model has been fine-tuned for VQA or not. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration + + >>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration + >>> configuration = Pix2StructConfig() + + >>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig + + >>> # Initializing a Pix2Struct text and Pix2Struct vision configuration + >>> config_text = Pix2StructTextConfig() + >>> config_vision = Pix2StructVisionConfig() + + >>> config = Pix2StructConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "pix2struct" + sub_configs = {"text_config": Pix2StructTextConfig, "vision_config": Pix2StructVisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + initializer_factor=1.0, + initializer_range=0.02, + is_vqa=False, + tie_word_embeddings=False, + is_encoder_decoder=True, + **kwargs, + ): + super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the Pix2StructTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. 
Initializing the Pix2StructVisionConfig with default values.") + + text_config["is_encoder_decoder"] = is_encoder_decoder + text_config["tie_word_embeddings"] = tie_word_embeddings + self.text_config = Pix2StructTextConfig(**text_config) + self.vision_config = Pix2StructVisionConfig(**vision_config) + + self.decoder_start_token_id = self.text_config.decoder_start_token_id + self.pad_token_id = self.text_config.pad_token_id + self.eos_token_id = self.text_config.eos_token_id + + self.initializer_factor = initializer_factor + self.initializer_range = initializer_range + + self.text_config.initializer_range = self.initializer_range + self.vision_config.initializer_range = self.initializer_range + + self.is_vqa = is_vqa + + +__all__ = ["Pix2StructConfig", "Pix2StructTextConfig", "Pix2StructVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pix2struct/image_processing_pix2struct.py b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/image_processing_pix2struct.py new file mode 100644 index 0000000000000000000000000000000000000000..94ae657776926c653024911b0eee43d85c0f745a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/image_processing_pix2struct.py @@ -0,0 +1,464 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pix2Struct.""" + +import io +import math +from typing import Optional, Union + +import numpy as np +from huggingface_hub import hf_hub_download + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image +from ...image_utils import ( + ChannelDimension, + ImageInput, + get_image_size, + infer_channel_dimension_format, + make_flat_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_torch_available, is_vision_available, logging +from ...utils.import_utils import requires_backends + + +if is_vision_available(): + import textwrap + + from PIL import Image, ImageDraw, ImageFont + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) +DEFAULT_FONT_PATH = "ybelkada/fonts" + + +# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2 +def torch_extract_patches(image_tensor, patch_height, patch_width): + """ + Utility function to extract patches from a given image tensor. Returns a tensor of shape + (1, `rows`, `columns`, `num_channels`x `patch_height` x `patch_width`). + + Args: + image_tensor (torch.Tensor): + The image tensor to extract patches from. + patch_height (int): + The height of the patches to extract. + patch_width (int): + The width of the patches to extract. 
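+
+ Example (a small illustrative case): a 3 x 64 x 64 image tensor with 16 x 16 patches
+ yields a result of shape (1, 4, 4, 768), since 64 // 16 = 4 patches fit along each
+ spatial side and each flattened patch holds 3 * 16 * 16 = 768 values.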
+ """ + requires_backends(torch_extract_patches, ["torch"]) + + image_tensor = image_tensor.unsqueeze(0) + patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) + patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1) + patches = patches.permute(0, 4, 2, 3, 1).reshape( + image_tensor.size(2) // patch_height, + image_tensor.size(3) // patch_width, + image_tensor.size(1) * patch_height * patch_width, + ) + return patches.unsqueeze(0) + + +# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L106 +def render_text( + text: str, + text_size: int = 36, + text_color: str = "black", + background_color: str = "white", + left_padding: int = 5, + right_padding: int = 5, + top_padding: int = 5, + bottom_padding: int = 5, + font_bytes: Optional[bytes] = None, + font_path: Optional[str] = None, +) -> Image.Image: + """ + Render text. This script is entirely adapted from the original script that can be found here: + https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py + + Args: + text (`str`, *optional*, defaults to ): + Text to render. + text_size (`int`, *optional*, defaults to 36): + Size of the text. + text_color (`str`, *optional*, defaults to `"black"`): + Color of the text. + background_color (`str`, *optional*, defaults to `"white"`): + Color of the background. + left_padding (`int`, *optional*, defaults to 5): + Padding on the left. + right_padding (`int`, *optional*, defaults to 5): + Padding on the right. + top_padding (`int`, *optional*, defaults to 5): + Padding on the top. + bottom_padding (`int`, *optional*, defaults to 5): + Padding on the bottom. + font_bytes (`bytes`, *optional*): + Bytes of the font to use. If `None`, the default font will be used. + font_path (`str`, *optional*): + Path to the font to use. If `None`, the default font will be used. + """ + requires_backends(render_text, "vision") + # Add new lines so that each line is no more than 80 characters. + + wrapper = textwrap.TextWrapper(width=80) + lines = wrapper.wrap(text=text) + wrapped_text = "\n".join(lines) + + if font_bytes is not None and font_path is None: + font = io.BytesIO(font_bytes) + elif font_path is not None: + font = font_path + else: + font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF") + font = ImageFont.truetype(font, encoding="UTF-8", size=text_size) + + # Use a temporary canvas to determine the width and height in pixels when + # rendering the text. + temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color)) + _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font) + + # Create the actual image with a bit of padding around the text. + image_width = text_width + left_padding + right_padding + image_height = text_height + top_padding + bottom_padding + image = Image.new("RGB", (image_width, image_height), background_color) + draw = ImageDraw.Draw(image) + draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font) + return image + + +# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87 +def render_header( + image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChildProcessError]] = None, **kwargs +): + """ + Renders the input text as a header on the input image. 
+
+    Args:
+        image (`np.ndarray`):
+            The image to render the header on.
+        header (`str`):
+            The header text.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. Can be either `"channels_first"` or
+            `"channels_last"`. If unset, it is inferred from the input image.
+
+    Returns:
+        `np.ndarray`: The image with the header rendered.
+    """
+    requires_backends(render_header, "vision")
+
+    # Convert to PIL image if necessary
+    image = to_pil_image(image, input_data_format=input_data_format)
+
+    header_image = render_text(header, **kwargs)
+    new_width = max(header_image.width, image.width)
+
+    new_height = int(image.height * (new_width / image.width))
+    new_header_height = int(header_image.height * (new_width / header_image.width))
+
+    new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
+    new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
+    new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
+
+    # Convert back to the original framework if necessary
+    new_image = to_numpy_array(new_image)
+
+    if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
+        new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)
+
+    return new_image
+
+
+class Pix2StructImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Pix2Struct image processor.
+
+    Args:
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method. According to the Pix2Struct paper and code, the image is normalized with its own mean and standard
+            deviation.
+        patch_size (`dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
+            The patch size to use for the image. According to the Pix2Struct paper and code, the patch size is 16x16.
+        max_patches (`int`, *optional*, defaults to 2048):
+            The maximum number of patches to extract from the image as per the [Pix2Struct
+            paper](https://huggingface.co/papers/2210.03347).
+        is_vqa (`bool`, *optional*, defaults to `False`):
+            Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
+            rendered onto the input images.
+    """
+
+    model_input_names = ["flattened_patches", "attention_mask"]
+
+    def __init__(
+        self,
+        do_convert_rgb: bool = True,
+        do_normalize: bool = True,
+        patch_size: Optional[dict[str, int]] = None,
+        max_patches: int = 2048,
+        is_vqa: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
+        self.do_normalize = do_normalize
+        self.do_convert_rgb = do_convert_rgb
+        self.max_patches = max_patches
+        self.is_vqa = is_vqa
+
+    def extract_flattened_patches(
+        self,
+        image: np.ndarray,
+        max_patches: int,
+        patch_size: dict,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Extract flattened patches from an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to extract flattened patches from.
+            max_patches (`int`):
+                Maximum number of patches to extract.
+            patch_size (`dict`):
+                Dictionary containing the patch height and width.
+
+        Returns:
+            result (`np.ndarray`):
+                A sequence of `max_patches` flattened patches.
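+
+        Example (a minimal sketch; the image values and `max_patches` are illustrative, and the output width is
+        `2 + patch_height * patch_width * num_channels`):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import Pix2StructImageProcessor
+
+        >>> processor = Pix2StructImageProcessor()
+        >>> image = np.random.rand(64, 80, 3).astype(np.float32)  # arbitrary RGB image
+        >>> patches = processor.extract_flattened_patches(
+        ...     image, max_patches=256, patch_size={"height": 16, "width": 16}
+        ... )
+        >>> patches.shape  # 256 rows after padding, 2 id columns + 16 * 16 * 3 pixel values
+        (256, 770)
+        ```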
+ """ + requires_backends(self.extract_flattened_patches, "torch") + + # convert to torch + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) + image = torch.from_numpy(image) + + patch_height, patch_width = patch_size["height"], patch_size["width"] + image_height, image_width = get_image_size(image, ChannelDimension.FIRST) + + # maximize scale s.t. + scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width)) + num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1) + num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1) + resized_height = max(num_feasible_rows * patch_height, 1) + resized_width = max(num_feasible_cols * patch_width, 1) + + image = torch.nn.functional.interpolate( + image.unsqueeze(0), + size=(resized_height, resized_width), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze(0) + + # [1, rows, columns, patch_height * patch_width * image_channels] + patches = torch_extract_patches(image, patch_height, patch_width) + + patches_shape = patches.shape + rows = patches_shape[1] + columns = patches_shape[2] + depth = patches_shape[3] + + # [rows * columns, patch_height * patch_width * image_channels] + patches = patches.reshape([rows * columns, depth]) + + # [rows * columns, 1] + row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1]) + col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1]) + + # Offset by 1 so the ids do not contain zeros, which represent padding. + row_ids += 1 + col_ids += 1 + + # Prepare additional patch features. + # [rows * columns, 1] + row_ids = row_ids.to(torch.float32) + col_ids = col_ids.to(torch.float32) + + # [rows * columns, 2 + patch_height * patch_width * image_channels] + result = torch.cat([row_ids, col_ids, patches], -1) + + # [max_patches, 2 + patch_height * patch_width * image_channels] + result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() + + result = to_numpy_array(result) + + return result + + def normalize( + self, + image: np.ndarray, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + The image std is to mimic the tensorflow implementation of the `per_image_standardization`: + https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization + + Args: + image (`np.ndarray`): + Image to normalize. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + if image.dtype == np.uint8: + image = image.astype(np.float32) + + # take mean across the whole `image` + mean = np.mean(image) + std = np.std(image) + adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) + + return normalize( + image, + mean=mean, + std=adjusted_stddev, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + header_text: Optional[str] = None, + do_convert_rgb: Optional[bool] = None, + do_normalize: Optional[bool] = None, + max_patches: Optional[int] = None, + patch_size: Optional[dict[str, int]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> ImageInput: + """ + Preprocess an image or batch of images. The processor first computes the maximum possible number of + aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the + image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the + images are standardized following the tensorflow implementation of `per_image_standardization` + (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). + + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images. + header_text (`Union[list[str], str]`, *optional*): + Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + max_patches (`int`, *optional*, defaults to `self.max_patches`): + Maximum number of patches to extract. + patch_size (`dict`, *optional*, defaults to `self.patch_size`): + Dictionary containing the patch height and width. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ """ + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_patches = max_patches if max_patches is not None else self.max_patches + is_vqa = self.is_vqa + + if kwargs.get("data_format") is not None: + raise ValueError("data_format is not an accepted input as the outputs are ") + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if is_vqa: + if header_text is None: + raise ValueError("A header text must be provided for VQA models.") + font_bytes = kwargs.pop("font_bytes", None) + font_path = kwargs.pop("font_path", None) + + if isinstance(header_text, str): + header_text = [header_text] * len(images) + + images = [ + render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path) + for i, image in enumerate(images) + ] + + if do_normalize: + images = [self.normalize(image=image, input_data_format=input_data_format) for image in images] + + # convert to torch tensor and permute + images = [ + self.extract_flattened_patches( + image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format + ) + for image in images + ] + + # create attention mask in numpy + attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images] + + encoded_outputs = BatchFeature( + data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors + ) + + return encoded_outputs + + +__all__ = ["Pix2StructImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pix2struct/modeling_pix2struct.py b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/modeling_pix2struct.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0631611070933f54b2167de5315ad3b5ee64e8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/modeling_pix2struct.py @@ -0,0 +1,1612 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Pix2Struct modeling file""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + auto_docstring, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +# General docstring + + +# Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct +class Pix2StructLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + Pix2StructLayerNorm = FusedRMSNorm + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm") +except ImportError: + # using the normal Pix2StructLayerNorm + pass +except Exception: + logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm") + pass + + +class Pix2StructVisionEmbeddings(nn.Module): + r""" + Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models. + Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch + is represented by a vector of `hidden_size` values. 
+ """ + + def __init__(self, config: Pix2StructConfig) -> None: + super().__init__() + self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size) + + self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size) + self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size) + + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor: + # the row and column indices are stored in the first and second position of the flattened_patches + # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2 + row_indices = flattened_patches[:, :, 0].long() + col_indices = flattened_patches[:, :, 1].long() + + flattened_patches = flattened_patches[:, :, 2:] + + embeddings = self.patch_projection(flattened_patches) + row_embeddings = self.row_embedder(row_indices) + col_embeddings = self.column_embedder(col_indices) + + # sum all embeddings together + embeddings = embeddings + row_embeddings + col_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Pix2StructVisionAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + output_attentions=False, + ): + """ + Self-attention block + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_values[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + def to_projection_shape(states): + """projection""" + return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # get query states + # (batch_size, n_heads, seq_length, dim_per_head) + query_states = to_projection_shape(self.query(hidden_states)) + + # get key/value states + key_states = to_projection_shape(self.key(hidden_states)) + value_states = to_projection_shape(self.value(hidden_states)) + + # compute scores + # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + + if attention_mask.dim() == 2: + position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) + elif attention_mask is not None: + # (batch_size, n_heads, seq_length, key_length) + position_bias = position_bias + attention_mask.to(position_bias.device) + elif not is_torchdynamo_compiling(): + attention_mask = torch.ones( + (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype + ) 
+ position_bias = position_bias + attention_mask.to(position_bias.device) + + position_bias = 1 - position_bias + + position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) + scores += position_bias_masked + scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + # (batch_size, seq_length, dim) + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + attn_output = self.output(attn_output) + + outputs = (attn_output,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate +class Pix2StructVisionMlp(nn.Module): + def __init__(self, config: Pix2StructVisionConfig): + super().__init__() + self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+ # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class Pix2StructVisionLayer(GradientCheckpointingLayer): + def __init__(self, config: Pix2StructConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Pix2StructVisionAttention(config) + self.mlp = Pix2StructVisionMlp(config) + self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + residual = hidden_states + + # in Pix2StructVision, layernorm is applied before self-attention + hidden_states = self.pre_attention_layer_norm(hidden_states) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + residual + + # in Pix2StructVision, layernorm is also applied after self-attention + layer_output = self.pre_mlp_layer_norm(hidden_states) + layer_output = self.mlp(layer_output) + hidden_states # second residual connection + + outputs = (layer_output,) + outputs + + return outputs + + +class Pix2StructVisionEncoder(nn.Module): + def __init__(self, config: Pix2StructVisionConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class 
Pix2StructPreTrainedModel(PreTrainedModel): + config: Pix2StructConfig + + _can_compile_fullgraph = False + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, Pix2StructLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, Pix2StructTextDenseGatedActDense): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff + + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pix2StructTextAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + key_value_proj_dim = ( + self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size + ) + n_heads = ( + self.config.text_config.num_heads + if isinstance(self.config, Pix2StructConfig) + else self.config.num_heads + ) + + module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) + module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + elif isinstance(module, nn.Embedding): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + + module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Pix2StructTextModel): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + + module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + 
).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, Pix2StructLayerNorm): + if module.weight is not None: + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. " + "See Pix2Struct docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@auto_docstring +class Pix2StructVisionModel(Pix2StructPreTrainedModel): + config: Pix2StructVisionConfig + main_input_name = "flattened_patches" + supports_gradient_checkpointing = True + _no_split_modules = ["Pix2StructVisionLayer"] + + def __init__(self, config: Pix2StructVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Pix2StructVisionEmbeddings(config) + self.encoder = Pix2StructVisionEncoder(config) + + self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_projection + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + flattened_patches: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`): + Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See + [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original + paper](https://huggingface.co/papers/2210.03347) (figure 5) for more details. 
+ + Example: + + ```python + >>> import requests + >>> from PIL import Image + >>> from transformers import AutoProcessor, Pix2StructVisionModel + + >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") + >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 2048, 768] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if flattened_patches is None: + raise ValueError("You have to specify flattened_patches") + + if attention_mask is None: + # check where `flattened_patches` is not 0 + attention_mask = (flattened_patches.sum(dim=-1) != 0).float() + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(flattened_patches) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size +class Pix2StructTextDenseGatedActDense(nn.Module): + def __init__(self, config: Pix2StructTextConfig): + super().__init__() + self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class Pix2StructTextLayerFF(nn.Module):
+    def __init__(self, config: Pix2StructTextConfig):
+        super().__init__()
+        self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
+
+        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+class Pix2StructTextAttention(nn.Module):
+    def __init__(
+        self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
+    ):
+        super().__init__()
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.hidden_size = config.hidden_size
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    @staticmethod
+    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
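+        For example (illustrative values), with `bidirectional=False` and the defaults `num_buckets=32` and
+        `max_distance=128`, distances 0-15 each get their own bucket, distances 16-127 are spread logarithmically
+        over buckets 16-31, and anything at or beyond 128 lands in bucket 31:
+
+        ```python
+        >>> import torch
+
+        >>> relative_position = torch.tensor([[0, -1, -15, -16, -100, -300]])
+        >>> Pix2StructTextAttention._relative_position_bucket(relative_position, bidirectional=False)
+        tensor([[ 0,  1, 15, 16, 30, 31]])
+        ```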
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=False, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_values=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.query(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.key(current_states) + value_states = self.value(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + 
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.output(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size +class Pix2StructTextLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.attention = Pix2StructTextAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size +class Pix2StructTextLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = 
(layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class Pix2StructTextBlock(GradientCheckpointingLayer): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + + self.self_attention = Pix2StructTextLayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, + ) + + self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention( + config, + layer_idx=layer_idx, + ) + + self.mlp = Pix2StructTextLayerFF(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.self_attention( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.encoder_decoder_attention( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.mlp(hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + attention_outputs + + +@auto_docstring( + custom_intro=""" + The standalone text decoder of Pix2Struct + """ +) +class Pix2StructTextModel(Pix2StructPreTrainedModel): + config: Pix2StructTextConfig + _no_split_modules = ["Pix2StructTextBlock"] + _tied_weights_keys = ["lm_head.weight"] + supports_gradient_checkpointing = True + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layer = nn.ModuleList( + [ 
+                Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
+                for i in range(config.num_layers)
+            ]
+        )
+        self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.gradient_checkpointing = False
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embed_tokens = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
+            embeddings so you should be able to pad the inputs on both the right and the left.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [Pix2StructText
+            Training](./t5#training).
+        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, Pix2StructTextModel
+
+        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
+        >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")
+
+        >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> loss = outputs.loss
+        ```
+        """
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + else: + past_key_values = DynamicCache(config=self.config) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + 
encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + if encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") + + loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) + + if not return_dict: + return tuple( + v + for v in [ + loss, + logits, + past_key_values, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
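Earlier in this forward pass the label loss flattens the `(batch, seq_len, vocab)` logits against `(batch, seq_len)` labels and skips `-100` positions. A minimal, self-contained sketch of that same reduction with made-up shapes (toy values, not taken from the diff):

```python
import torch
import torch.nn as nn

# Toy shapes: batch of 2, sequence of 4, vocabulary of 10.
logits = torch.randn(2, 4, 10)
labels = torch.tensor([[1, 2, 3, -100],
                       [4, 5, -100, -100]])  # -100 marks padded / ignored targets

loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
print(loss)  # scalar: mean over the 5 non-ignored positions
```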
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
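The mask construction that follows fills a `(sequence_length, target_length)` grid with the dtype minimum, applies `torch.triu`, and then keeps the large negative value only for key slots strictly after each query's `cache_position`. A toy run of the same recipe (3 new tokens appended to a 2-slot cache; sizes are illustrative only):

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 5
cache_position = torch.arange(2, 5)  # absolute positions of the 3 new tokens

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

print((causal_mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)  # 1 = attendable key slot
```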
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring( + custom_intro=""" + A conditional generation model with a language modeling head. Can be used for sequence generation tasks. + """ +) +class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): + config: Pix2StructConfig + main_input_name = "flattened_patches" + _tied_weights_keys = ["decoder.lm_head.weight"] + + def __init__(self, config: Pix2StructConfig): + super().__init__(config) + + self.encoder = Pix2StructVisionModel(config.vision_config) + self.decoder = Pix2StructTextModel(config.text_config) + + self.is_vqa = config.is_vqa + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.decoder.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.decoder.set_output_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + flattened_patches: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`): + Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` = + `num_channels` * `patch_size` * `patch_size` + + The process of flattening the pixel patches is done by `Pix2StructProcessor`. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. + + Example: + + Inference: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") + >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> # autoregressive generation + >>> generated_ids = model.generate(**inputs, max_new_tokens=50) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_text) + A stop sign is on a street corner. + + >>> # conditional generation + >>> text = "A picture of" + >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=50) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_text) + A picture of a stop sign with a red stop sign + ``` + + Training: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base") + >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A stop sign is on the street corner." 
+ + >>> inputs = processor(images=image, return_tensors="pt") + >>> labels = processor(text=text, return_tensors="pt").input_ids + + >>> # forward pass + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> print(f"{loss.item():.5f}") + 5.94282 + ```""" + use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + flattened_patches=flattened_patches, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + decoder_attention_mask = ( + decoder_attention_mask + if decoder_attention_mask is not None + else decoder_input_ids.ne(self.config.pad_token_id).float() + ) + # Always attend to the first token + decoder_attention_mask[:, 0] = 1 + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + labels=labels, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +__all__ = [ + "Pix2StructPreTrainedModel", + "Pix2StructForConditionalGeneration", + "Pix2StructVisionModel", + "Pix2StructTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pix2struct/processing_pix2struct.py b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/processing_pix2struct.py new file mode 100644 index 0000000000000000000000000000000000000000..f21dd5d7a0025e16b6ed15edcc14b29abe404ad2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pix2struct/processing_pix2struct.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
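When only `labels` are passed, the forward pass above shifts them right into `decoder_input_ids` and builds a default `decoder_attention_mask` that ignores pad tokens but always attends to the first position. A tiny standalone illustration of that mask logic (hypothetical token ids, `pad_token_id=0` assumed):

```python
import torch

pad_token_id = 0
decoder_input_ids = torch.tensor([[0, 17, 23, 0, 0]])   # starts with the pad/start token
mask = decoder_input_ids.ne(pad_token_id).float()
mask[:, 0] = 1                                          # always attend to the first token
print(mask)  # tensor([[1., 1., 1., 0., 0.]])
```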
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Pix2Struct. +""" + +from typing import Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput +from ...utils import logging + + +class Pix2StructImagesKwargs(ImagesKwargs, total=False): + max_patches: Optional[int] + header_text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] + + +class Pix2StructProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Pix2StructImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "max_patches": 2048, + }, + } + + +logger = logging.get_logger(__name__) + + +class Pix2StructProcessor(ProcessorMixin): + r""" + Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single + processor. + + [`Pix2StructProcessor`] offers all the functionalities of [`Pix2StructImageProcessor`] and [`T5TokenizerFast`]. See + the docstring of [`~Pix2StructProcessor.__call__`] and [`~Pix2StructProcessor.decode`] for more information. + + Args: + image_processor (`Pix2StructImageProcessor`): + An instance of [`Pix2StructImageProcessor`]. The image processor is a required input. + tokenizer (Union[`T5TokenizerFast`, `T5Tokenizer`]): + An instance of ['T5TokenizerFast`] or ['T5Tokenizer`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "Pix2StructImageProcessor" + tokenizer_class = ("T5Tokenizer", "T5TokenizerFast") + + def __init__(self, image_processor, tokenizer): + tokenizer.return_token_type_ids = False + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images=None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Pix2StructProcessorKwargs], + ) -> Union[BatchEncoding, BatchFeature]: + """ + This method uses [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and + [`T5TokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. 
+ """ + if images is None and text is None: + raise ValueError("You have to specify either images or text.") + + output_kwargs = self._merge_kwargs( + Pix2StructProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None) + # Get only text + if images is None and not self.image_processor.is_vqa: + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + add_special_tokens if add_special_tokens is not None else True + ) + self.current_processor = self.tokenizer + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + return text_encoding + + if not self.image_processor.is_vqa: + # add pixel_values + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + # add pixel_values and bbox + output_kwargs["images_kwargs"].setdefault("header_text", text) + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text is not None and not self.image_processor.is_vqa: + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + add_special_tokens if add_special_tokens is not None else False + ) + text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + + if "attention_mask" in text_encoding: + text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask") + if "input_ids" in text_encoding: + text_encoding["decoder_input_ids"] = text_encoding.pop("input_ids") + else: + text_encoding = None + + if text_encoding is not None: + encoding_image_processor.update(text_encoding) + + return encoding_image_processor + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + decoder_ids = ["decoder_attention_mask", "decoder_input_ids"] + return image_processor_input_names + decoder_ids + + +__all__ = ["Pix2StructProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8afbbfb1f47f71d478a9ed475f36719d405b4124 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pixtral import * + from .image_processing_pixtral import * + from .image_processing_pixtral_fast import * + from .modeling_pixtral import * + from .processing_pixtral import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/configuration_pixtral.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/configuration_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..d4710e00e421662541e8f2c0417e2719de847ee9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/configuration_pixtral.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pixtral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PixtralVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate an + Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B. + + e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of input channels in the input images. + image_size (`int`, *optional*, defaults to 1024): + Max dimension of the input images. + patch_size (`int`, *optional*, defaults to 16): + Size of the image patches. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + Activation function used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the attention layers. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import PixtralVisionModel, PixtralVisionConfig + + >>> # Initializing a Pixtral-12B style configuration + >>> config = PixtralVisionConfig() + + >>> # Initializing a model (with randomly initialized weights) from the configuration + >>> model = PixtralVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pixtral" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + attention_dropout=0.0, + rope_theta=10000.0, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.rope_theta = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.initializer_range = initializer_range + + +__all__ = ["PixtralVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..33e7676f9de9f3d8a233d6d8c2f8294086d72874 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral.py @@ -0,0 +1,473 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pixtral.""" + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging +from ...utils.import_utils import requires_backends + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. 
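As a quick check of the defaults documented above (assuming a transformers version that exposes `PixtralVisionConfig`), `head_dim` and the maximum patch grid follow directly from them:

```python
from transformers import PixtralVisionConfig

config = PixtralVisionConfig()                 # defaults listed above
print(config.head_dim)                         # 1024 // 16 = 64
print(config.image_size // config.patch_size)  # 1024 // 16 = 64 patches per side at most
```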
Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (Image): + The image to convert. + """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + if image.mode == "RGB": + return image + + # First we convert to RGBA to set background to white. + image = image.convert("RGBA") + + # Create a new image with a white background. + new_image = PIL.Image.new("RGBA", image.size, "WHITE") + new_image.paste(image, (0, 0), image) + new_image = new_image.convert("RGB") + return new_image + + +def _num_image_tokens(image_size: tuple[int, int], patch_size: tuple[int, int]) -> int: + """ + Calculate the number of image tokens given the image size and patch size. + + Args: + image_size (`tuple[int, int]`): + The size of the image as `(height, width)`. + patch_size (`tuple[int, int]`): + The patch size as `(height, width)`. + + Returns: + `int`: The number of image tokens. + """ + height, width = image_size + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + num_width_tokens = (width - 1) // patch_width + 1 + num_height_tokens = (height - 1) // patch_height + 1 + return num_height_tokens, num_width_tokens + + +def get_resize_output_image_size( + input_image: ImageInput, + size: Union[int, tuple[int, int], list[int], tuple[int]], + patch_size: Union[int, tuple[int, int], list[int], tuple[int]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`ImageInput`): + The image to resize. + size (`int` or `tuple[int, int]`): + Max image size an input image can be. Must be a dictionary with the key "longest_edge". + patch_size (`int` or `tuple[int, int]`): + The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)` + will be used + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing. + """ + max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size) + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + height, width = get_image_size(input_image, input_data_format) + + ratio = max(height / max_height, width / max_width) + + if ratio > 1: + # Original implementation uses `round` which utilises bankers rounding, which can lead to surprising results + # Here we use floor to ensure the image is always smaller than the given "longest_edge" + height = int(math.floor(height / ratio)) + width = int(math.floor(width / ratio)) + + num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) + return num_height_tokens * patch_height, num_width_tokens * patch_width + + +class PixtralImageProcessor(BaseImageProcessor): + r""" + Constructs a Pixtral image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. 
+ size (`dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`): + Size of the maximum dimension of either the height or width dimension of the image. Used to control how + images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)` + patch_size (`dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`): + Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
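To make the resizing rule concrete, here is a worked example of the `get_resize_output_image_size` / `_num_image_tokens` arithmetic defined above, using toy numbers (a 1500x2000 image, `longest_edge=1024`, 16x16 patches):

```python
import math

height, width, longest_edge, patch = 1500, 2000, 1024, 16

ratio = max(height / longest_edge, width / longest_edge)  # 1.953125 > 1, so shrink
height = int(math.floor(height / ratio))                   # 768
width = int(math.floor(width / ratio))                     # 1024

num_h = (height - 1) // patch + 1                          # 48 patch rows
num_w = (width - 1) // patch + 1                           # 64 patch columns
print(num_h * patch, num_w * patch)                        # 768 1024 (multiples of the patch size)
```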
+ """ + + model_input_names = ["pixel_values", "image_sizes"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + patch_size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 1024} + patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + patch_size = get_size_dict(patch_size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] + self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "patch_size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + patch_size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dict containing the longest possible edge of the image. + patch_size (`dict[str, int]`): + Patch size used to calculate the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + if "longest_edge" in size: + size = (size["longest_edge"], size["longest_edge"]) + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") + + if "height" in patch_size and "width" in patch_size: + patch_size = (patch_size["height"], patch_size["width"]) + else: + raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + patch_size=patch_size, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _pad_for_batching( + self, + pixel_values: list[np.ndarray], + image_sizes: list[list[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + Args: + pixel_values (`list[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `height`, `width`, `channels`) + image_sizes (`list[list[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + Returns: + list[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max(size[0] for size in image_sizes), + max(size[1] for size in image_sizes), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + patch_size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. 
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + patch_size (`dict[str, int]`, *optional*, defaults to `self.patch_size`): + Patch size in the model. Used to calculate the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ """ + patch_size = patch_size if patch_size is not None else self.patch_size + patch_size = get_size_dict(patch_size, default_to_square=True) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + batch_images = [] + batch_image_sizes = [] + for image in images: + if do_resize: + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + batch_images.append(image) + batch_image_sizes.append(get_image_size(image, data_format)) + + pixel_values = self._pad_for_batching( + pixel_values=batch_images, + image_sizes=batch_image_sizes, + input_data_format=data_format, + data_format=data_format, + ) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors + ) + + +__all__ = ["PixtralImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral_fast.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..b31f910e481713ca916c920b95a3ecabb318de38 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pixtral.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, +) +from .image_processing_pixtral import get_resize_output_image_size + + +logger = logging.get_logger(__name__) + + +class PixtralFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + patch_size (`dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`): + Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method. + """ + + patch_size: Optional[dict[str, int]] + + +@auto_docstring +class PixtralImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = [0.48145466, 0.4578275, 0.40821073] + image_std = [0.26862954, 0.26130258, 0.27577711] + patch_size = {"height": 16, "width": 16} + size = {"longest_edge": 1024} + default_to_square = True + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = PixtralFastImageProcessorKwargs + + model_input_names = ["pixel_values", "image_sizes"] + + def __init__(self, **kwargs: Unpack[PixtralFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[PixtralFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + patch_size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dict containing the longest possible edge of the image. + patch_size (`SizeDict`): + Patch size used to calculate the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use when resiizing the image. 
+ """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.longest_edge: + size = (size.longest_edge, size.longest_edge) + elif size.height and size.width: + size = (size.height, size.width) + else: + raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") + + if patch_size.height and patch_size.width: + patch_size = (patch_size.height, patch_size.width) + else: + raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size(image, size=size, patch_size=patch_size) + return F.resize(image, size=output_size, interpolation=interpolation, **kwargs) + + # Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching + def _pad_for_batching( + self, + pixel_values: list[torch.Tensor], + image_sizes: list[list[int]], + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + Args: + pixel_values (`list[torch.Tensor]`): + An array of pixel values of each images of shape (`batch_size`, `channels`, `height`, `width`) + image_sizes (`list[list[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + Returns: + list[`torch.Tensor`]: The padded images. + """ + + max_shape = (max(size[0] for size in image_sizes), max(size[1] for size in image_sizes)) + pixel_values = [ + torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0])) + for image, size in zip(pixel_values, image_sizes) + ] + return torch.stack(pixel_values) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + patch_size: dict[str, int], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: dict[str, int], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + patch_size = get_size_dict(patch_size, default_to_square=True) + patch_size = SizeDict(**patch_size) + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, size=size, patch_size=patch_size, interpolation=interpolation + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))] + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = 
reorder_images(processed_images_grouped, grouped_images_index) + padded_images = self._pad_for_batching( + pixel_values=processed_images, + image_sizes=batch_image_sizes, + ) + + return BatchFeature( + data={"pixel_values": padded_images, "image_sizes": batch_image_sizes}, tensor_type=return_tensors + ) + + +__all__ = ["PixtralImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..79bf0ee6bbdad7c5545408992ce073a7895a60e9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py @@ -0,0 +1,516 @@ +# coding=utf-8 +# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Pixtral model.""" + +from collections.abc import Callable +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_rope_utils import dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging +from .configuration_pixtral import PixtralVisionConfig + + +logger = logging.get_logger(__name__) + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +class PixtralRotaryEmbedding(nn.Module): + """ + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels), then the frequency used for ROPE + is given by indexing the pre_computed frequency on the width and height. + + What you output is of dimension (batch, height * width, dim) with dim the embed dim. + + This simply means that for each image hidden state, you are going to add + a corresponding positional embedding, based on its index in the grid. 
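A toy run of the `position_ids_in_meshgrid` logic shown above: a 2x3 patch grid is flattened into indices of the precomputed frequency table, whose rows are laid out with stride `max_width` (4 here); sizes and values are illustrative only:

```python
import torch

height, width, max_width = 2, 3, 4
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, w_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
ids = h_grid * max_width + w_grid
print(ids[:, 0].tolist())  # [0, 1, 2, 4, 5, 6] -> row 1 skips index 3 because max_width is 4
```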
+ """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config, device=None): + super().__init__() + self.rope_type = "default" + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + + h = torch.arange(max_patches_per_side, device=freqs.device) + w = torch.arange(max_patches_per_side, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + freqs = self.inv_freq[position_ids] + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + emb = freqs + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class PixtralAttention(nn.Module):
+    """
+    Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.is_causal = False
+        self.scaling = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Input shape: Batch x Time x Channel"""
+
+        batch_size, patches, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        # Since we use packing, if flash_attention_2 is selected we rely on position_ids
+        if self.config._attn_implementation == "flash_attention_2":
+            kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+        return
attn_output, attn_weights + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral +class PixtralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral +class PixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PixtralAttentionLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.feed_forward = PixtralMLP(config) + self.attention = PixtralAttention(config) + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
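+
+        Example (sketch with random weights; `PixtralAttentionLayer` is internal, so this is
+        for illustration only and the shapes follow the definitions above):
+
+        ```python
+        >>> import torch
+        >>> from transformers import PixtralVisionConfig
+
+        >>> config = PixtralVisionConfig()
+        >>> layer = PixtralAttentionLayer(config)
+        >>> hidden = torch.randn(1, 16, config.hidden_size)
+        >>> cos = sin = torch.randn(1, 16, config.head_dim)
+        >>> out = layer(hidden, attention_mask=None, position_embeddings=(cos, sin))
+        >>> tuple(out[0].shape)  # the residual blocks keep the input shape
+        (1, 16, 1024)
+        ```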
+ """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PixtralTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralAttentionLayer(config)) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embeddings which serve as input to the Transformer. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
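+
+        Note (added for clarity): with `output_hidden_states=True` the returned
+        `hidden_states` tuple has `config.num_hidden_layers + 1` entries, i.e. the input
+        embeddings plus the output of every layer.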
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                position_embeddings=position_embeddings,
+                output_attentions=output_attentions,
+                **kwargs,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+@auto_docstring
+class PixtralPreTrainedModel(PreTrainedModel):
+    config: PixtralVisionConfig
+    base_model_prefix = "model"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["PixtralAttentionLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, PixtralRMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+def generate_block_attention_mask(patch_embeds_list, tensor):
+    dtype = tensor.dtype
+    device = tensor.device
+    seq_len = tensor.shape[1]
+    d_min = torch.finfo(dtype).min
+    causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
+
+    block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
+    block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
+    for start, end in zip(block_start_idx, block_end_idx):
+        causal_mask[start:end, start:end] = 0
+
+    causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
+    return causal_mask
+
+
+@auto_docstring
+class PixtralVisionModel(PixtralPreTrainedModel):
+    base_model_prefix = "vision_encoder"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+        self.patch_size = config.patch_size
+        self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
+        self.transformer = PixtralTransformer(config)
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.patch_conv
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args,
+
**kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_sizes is None: + batch_size, _, height, width = pixel_values.shape + image_sizes = [(height, width)] * batch_size + + # pass images through initial convolution independently + patch_embeds = self.patch_conv(pixel_values) + patch_embeds_list = [ + embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ) + kwargs["position_ids"] = position_ids + + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) + + if self.config._attn_implementation == "flash_attention_2": + # We only rely on position_ids when using flash_attention_2 + attention_mask = None + else: + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + + return self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + **kwargs, + ) + + +__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pixtral/processing_pixtral.py b/venv/lib/python3.13/site-packages/transformers/models/pixtral/processing_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..bb868156fb40b4f6874124f3c210dee48b5abc36 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pixtral/processing_pixtral.py @@ -0,0 +1,262 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Pixtral. 
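+It wraps an image processor and a tokenizer, and expands each `[IMG]` placeholder in the
+text into the full grid of image tokens ([IMG] / [IMG_BREAK] / [IMG_END]) expected by the
+model.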
+""" + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import ( + MultiModalData, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_vision_available, logging + + +if is_vision_available(): + from .image_processing_pixtral import get_resize_output_image_size + + +logger = logging.get_logger(__name__) + + +class PixtralProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +class PixtralProcessor(ProcessorMixin): + r""" + Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. + + [`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information. + + Args: + image_processor ([`PixtralImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*, defaults to 16): + Patch size from the vision tower. + spatial_merge_size (`int`, *optional*, defaults to 1): + The downsampling factor for the spatial merge operation. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `"[IMG]"`): + Special token used to denote image location. + image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`): + Special token used to denote the end of a line of pixels in an image. + image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`): + Special token used to denote the end of an image input. 
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        patch_size: int = 16,
+        spatial_merge_size: int = 1,
+        chat_template=None,
+        image_token="[IMG]",  # set the default and let users change if they have peculiar special tokens in rare cases
+        image_break_token="[IMG_BREAK]",
+        image_end_token="[IMG_END]",
+        **kwargs,
+    ):
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.image_token = image_token
+        self.image_break_token = image_break_token
+        self.image_end_token = image_end_token
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token)
+        self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
+        self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id]
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PixtralProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
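+
+        Example (hand-worked sketch of the placeholder expansion performed by this method;
+        the numbers are made up):
+
+        ```python
+        >>> # For a resized image of height 48 and width 64 with an effective patch size of
+        >>> # 16, each "[IMG]" in the prompt expands to 3 rows of 4 patch tokens, every row
+        >>> # ending in "[IMG_BREAK]" except the last, which ends in "[IMG_END]".
+        >>> (64 // 16 + 1) * (48 // 16)  # placeholder tokens per image
+        15
+        ```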
+        """
+
+        output_kwargs = self._merge_kwargs(
+            PixtralProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        patch_size = self.patch_size * self.spatial_merge_size
+
+        if images is not None:
+            image_inputs = self.image_processor(images, patch_size=patch_size, **output_kwargs["images_kwargs"])
+        else:
+            image_inputs = {}
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) and not isinstance(text[0], str):
+            raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+        # try to expand inputs in processing if we have the necessary parts
+        prompt_strings = text
+        if image_inputs.get("pixel_values") is not None:
+            # Replace the image token with the expanded image token sequence
+            image_sizes = iter(image_inputs["image_sizes"])
+            prompt_strings = []
+            replace_strings = []
+
+            for sample in text:
+                while self.image_token in sample:
+                    height, width = next(image_sizes)
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [
+                        [self.image_token] * num_width_tokens + [self.image_break_token]
+                    ] * num_height_tokens
+                    # Flatten list
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                    replace_tokens[-1] = self.image_end_token
+                    replace_str = "".join(replace_tokens)
+                    replace_strings.append(replace_str)
+                    sample = sample.replace(self.image_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    replace_str = replace_strings.pop(0)
+                    sample = sample.replace("<placeholder>", replace_str, 1)
+                prompt_strings.append(sample)
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+        self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+        if return_mm_token_type_ids:
+            array_ids = np.array(text_inputs["input_ids"])
+            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+            mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
+            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+        """
+        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+        Args:
+            image_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (height, width) per each image.
+
+        Returns:
+            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+            input modalities, along with other useful data.
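+
+        Example (sketch; assumes a `processor` built as in the class-level example and the
+        default resizing, which maps a 48 x 64 image onto a 3 x 4 patch grid):
+
+        ```python
+        >>> data = processor._get_num_multimodal_tokens(image_sizes=[(48, 64)])
+        >>> data.num_image_tokens  # (4 + 1) * 3: one [IMG_BREAK]/[IMG_END] per row
+        [15]
+        ```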
+ """ + vision_data = {} + if image_sizes is not None: + images_kwargs = PixtralProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + size = images_kwargs.get("size", None) or self.image_processor.size + patch_size = self.patch_size * self.spatial_merge_size + + num_image_tokens = [] + for height, width in image_sizes: + resized_height, resized_width = get_resize_output_image_size( + np.zeros((height, width, 3)), + size=(size["longest_edge"], size["longest_edge"]), + patch_size=(patch_size, patch_size), + ) + num_height_tokens = resized_height // patch_size + num_width_tokens = resized_width // patch_size + num_image_tokens.append((num_width_tokens + 1) * num_height_tokens) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return tokenizer_input_names + image_processor_input_names + ["image_sizes"] + + +__all__ = ["PixtralProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/plbart/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/plbart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33e5618ca7e1727c0ce28458d952ce0bfbbfee59 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/plbart/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_plbart import * + from .modeling_plbart import * + from .tokenization_plbart import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/plbart/configuration_plbart.py b/venv/lib/python3.13/site-packages/transformers/models/plbart/configuration_plbart.py new file mode 100644 index 0000000000000000000000000000000000000000..a4aaa3ff3703407bd86abe1af61d621aea326025 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/plbart/configuration_plbart.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PLBART model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PLBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate a
+    PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the PLBART
+    [uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50005):
+            Vocabulary size of the PLBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`PLBartModel`].
+        d_model (`int`, *optional*, defaults to 768):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import PLBartConfig, PLBartModel
+
+    >>> # Initializing a PLBART uclanlp/plbart-base style configuration
+    >>> configuration = PLBartConfig()
+
+    >>> # Initializing a model (with random weights) from the uclanlp/plbart-base style configuration
+    >>> model = PLBartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "plbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_attention_heads": "encoder_attention_heads",
+        "hidden_size": "d_model",
+        "initializer_range": "init_std",
+    }
+
+    def __init__(
+        self,
+        vocab_size=50005,
+        max_position_embeddings=1024,
+        encoder_layers=6,
+        encoder_ffn_dim=3072,
+        encoder_attention_heads=12,
+        decoder_layers=6,
+        decoder_ffn_dim=3072,
+        decoder_attention_heads=12,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=768,
+        dropout=0.1,
+        attention_dropout=0.1,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
+
+
+class PLBartOnnxConfig(OnnxConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1:
"sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + else: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + + +__all__ = ["PLBartConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/plbart/modeling_plbart.py b/venv/lib/python3.13/site-packages/transformers/models/plbart/modeling_plbart.py new file mode 100644 index 0000000000000000000000000000000000000000..60239bf9ac541d2446a90a87e3a90b800b7ce5bd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/plbart/modeling_plbart.py @@ -0,0 +1,1724 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/plbart/modular_plbart.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_plbart.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_plbart import PLBartConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class PLBartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+@auto_docstring
+class PLBartPreTrainedModel(PreTrainedModel):
+    config: PLBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+    ):
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            # Other attention flavors support in-built causal (when `mask is None`)
+            # while we need to create our specific block mask regardless
+            elif attention_mask is None:
+                attention_mask = make_flex_block_causal_mask(
+                    torch.ones(
+                        size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+                    )
+                )
+            return attention_mask
+
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
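+        # (Editorial summary of the flow below: the flex and flash branches have already
+        # returned above; for SDPA this method may return None so that the kernel's own
+        # `is_causal` handling is used; otherwise a full 4D additive mask is built and,
+        # on some devices, fully masked rows are re-opened for the memory-efficient path.)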
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
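+            # (e.g. an additive float mask of shape (batch, 1, query_len, key_value_length)
+            # with 0.0 at visible positions and a large negative value at masked ones.)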
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PLBartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) + + return super().forward(position_ids + self.offset) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class PLBartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PLBartConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class PLBartEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class PLBartEncoder(PLBartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PLBartEncoderLayer`]. 
+ + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
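+
+        Example (sketch with random weights; `PLBartEncoder` is an internal module, shown
+        here for illustration only):
+
+        ```python
+        >>> import torch
+        >>> from transformers import PLBartConfig
+
+        >>> config = PLBartConfig()
+        >>> encoder = PLBartEncoder(config)
+        >>> input_ids = torch.tensor([[0, 17, 23, 2]])
+        >>> outputs = encoder(input_ids=input_ids)
+        >>> tuple(outputs.last_hidden_state.shape)  # (batch, seq_len, d_model)
+        (1, 4, 768)
+        ```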
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class PLBartDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PLBartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + 
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class PLBartDecoder(PLBartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`PLBartDecoderLayer`] + + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + attention_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}."
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + """ + Shift input ids one token to the right, and wrap the last non pad token (the <LID> token). Note that PLBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models.
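When `cache_position` is not supplied, the decoder forward pass above derives it from the current cache length, so each incrementally decoded token receives the next absolute position. A sketch of that arithmetic (hypothetical `make_cache_position` helper):

```python
import torch

def make_cache_position(past_length: int, seq_length: int, device="cpu") -> torch.Tensor:
    # mirrors: torch.arange(past_kv_length, past_kv_length + seq_length, device=...)
    return torch.arange(past_length, past_length + seq_length, device=device)

print(make_cache_position(0, 5))  # prefill of 5 tokens: tensor([0, 1, 2, 3, 4])
print(make_cache_position(5, 1))  # next decoding step:  tensor([5])
```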
+ """ + prev_output_tokens = input_ids.clone() + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + + return prev_output_tokens + + +@auto_docstring +class PLBartModel(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. 
If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be + used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
+ """ +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. 
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be + used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example Mask-filling: + + ```python + >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration + + >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base") + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + + >>> # en_XX is the language symbol id <LID> for English + >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX" + >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids + + >>> logits = model(input_ids).logits + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['first', 'same', 'highest', 'result', 'number'] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels,
self.config.pad_token_id) + + +class PLBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring( + custom_intro=""" + PLBart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. + for GLUE tasks. + """ +) +class PLBartForSequenceClassification(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PLBartModel(config) + self.classification_head = PLBartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be + used by default.
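For sequence classification, PLBart pools the decoder's last hidden state at each row's final `</s>` (eos) token; the check in the forward pass that follows requires every example to contain the same number of eos tokens. A standalone sketch of that pooling (hypothetical `pool_last_eos` helper):

```python
import torch

def pool_last_eos(hidden_states: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    eos_mask = input_ids.eq(eos_token_id)
    if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
        raise ValueError("All examples must have the same number of <eos> tokens.")
    per_row = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))
    return per_row[:, -1, :]  # hidden state at the last eos token of each row

h = torch.randn(2, 5, 8)
ids = torch.tensor([[4, 5, 2, 6, 2], [7, 8, 2, 9, 2]])  # eos_token_id = 2, two per row
print(pool_last_eos(h, ids, eos_token_id=2).shape)  # torch.Size([2, 8])
```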
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of <eos> tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class PLBartDecoderWrapper(PLBartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is +
used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PLBartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@auto_docstring( + custom_intro=""" + PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings). + """ +) +class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PLBartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PLBartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/plbart/modular_plbart.py b/venv/lib/python3.13/site-packages/transformers/models/plbart/modular_plbart.py new file mode 100644 index 0000000000000000000000000000000000000000..29c253144557139a1653dafc40fafb2f51e3b343 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/plbart/modular_plbart.py @@ -0,0 +1,659 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
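To round off the causal-LM doctest above, a hedged usage sketch of generation through `GenerationMixin` (same checkpoint and flags as the doctest; the decoded continuation depends on the checkpoint and is illustrative, not a recorded output):

```python
from transformers import AutoTokenizer, PLBartForCausalLM

tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)

inputs = tokenizer("def fib(n):", return_tensors="pt")
# greedy decoding; the KV cache (past_key_values) is managed internally by generate()
generated = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```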
+"""PyTorch PLBART model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_torch_flex_attn_available +from ..bart.modeling_bart import ( + BartClassificationHead, + BartDecoder, + BartEncoder, + BartForCausalLM, + BartScaledWordEmbedding, +) +from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification +from ..mbart.modeling_mbart import shift_tokens_right +from .configuration_plbart import PLBartConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +class PLBartScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +@auto_docstring +class PLBartPreTrainedModel(PreTrainedModel): + config: PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
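`_prepare_4d_causal_attention_mask_with_cache_position` above builds an additive mask: entries whose key position lies in a query's future get the dtype minimum, everything else stays 0, and any 2D padding mask is folded in afterwards. A reduced sketch of the causal core (hypothetical `causal_4d_mask` helper; the padding-mask merge is omitted):

```python
import torch

def causal_4d_mask(sequence_length: int, target_length: int, cache_position: torch.Tensor,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # zero out keys at or before each query's absolute position
    mask = mask * (torch.arange(target_length) > cache_position.reshape(-1, 1))
    return mask[None, None, :, :]  # (1, 1, query_length, key_value_length)

# one new token at absolute position 3, cache of length 3: it may attend to keys 0..3
print(causal_4d_mask(1, 4, torch.tensor([3])))  # all zeros
```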
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PLBartEncoder(BartEncoder): + pass + + +class PLBartDecoder(BartDecoder): + pass + + +@auto_docstring +class PLBartModel(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). 
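The eager and SDPA branches above expand a 2D `(batch, src_len)` padding mask to the 4D additive form the attention layers consume. A hedged re-implementation of that expansion (hypothetical `expand_padding_mask`, mirroring what `_prepare_4d_attention_mask` produces):

```python
import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype = torch.float32,
                        tgt_len: int | None = None) -> torch.Tensor:
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    # broadcast over heads and query positions: [bsz, src_len] -> [bsz, 1, tgt_len, src_len]
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 1 -> 0.0 (attend), 0 -> dtype minimum (ignore)
    return (1.0 - expanded) * torch.finfo(dtype).min

m = torch.tensor([[1, 1, 1, 0]])  # last source position is padding
print(expand_padding_mask(m, tgt_len=2).shape)  # torch.Size([1, 1, 2, 4])
```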
+ + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code. 
+ """ +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. 
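When `labels` are given but no `decoder_input_ids`, the forward pass above falls back to `shift_tokens_right(labels, pad_token_id)`. That helper is defined earlier in this file and is not reproduced here; the sketch below is a hedged re-implementation that only assumes the MBart-style behaviour the docstring describes (the trailing language-code token is rotated to the front), with made-up token ids except for 50003, which the docstring above cites for `en_XX`.

```python
import torch

def shift_tokens_right_sketch(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Illustrative only: wrap the last non-pad token (the language code) around to
    # position 0 and shift everything else one place to the right.
    prev_output_tokens = input_ids.clone()
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze(-1)
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens

# Toy labels laid out as [tok, tok, tok, eos, lang_code], pad_token_id = 1.
batch = torch.tensor([[17, 23, 42, 2, 50003]])
print(shift_tokens_right_sketch(batch, pad_token_id=1))
# tensor([[50003,    17,    23,    42,     2]])
```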
+            decoder_attention_mask (:
+            obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior:
+            generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
+            cross_attn_head_mask (:
+            obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify
+            selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example Mask-filling:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
+
+        >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
+        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
+
+        >>> # en_XX is the language symbol id <LID> for English
+        >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? en_XX"
+        >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
+
+        >>> logits = model(input_ids).logits
+        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+        >>> probs = logits[0, masked_index].softmax(dim=0)
+        >>> values, predictions = probs.topk(5)
+
+        >>> tokenizer.decode(predictions).split()
+        ['first', 'same', 'highest', 'result', 'number']
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels,
self.config.pad_token_id) + + +class PLBartClassificationHead(BartClassificationHead): + pass + + +class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification): + def forward(**super_kwargs): + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + super().forward(**super_kwargs) + + +class PLBartForCausalLM(BartForCausalLM): + @auto_docstring + def forward(**super_kwargs): + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PLBartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+        >>> list(logits.shape) == expected_shape
+        True
+        ```"""
+        super().forward(**super_kwargs)
+
+
+__all__ = [
+    "PLBartForCausalLM",
+    "PLBartForConditionalGeneration",
+    "PLBartForSequenceClassification",
+    "PLBartModel",
+    "PLBartPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/plbart/tokenization_plbart.py b/venv/lib/python3.13/site-packages/transformers/models/plbart/tokenization_plbart.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eed6dc18bddd7b9ba0fa68dc31186fd35c03a99
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/plbart/tokenization_plbart.py
@@ -0,0 +1,432 @@
+# coding=utf-8
+# Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+FAIRSEQ_LANGUAGE_CODES = {
+    "base": ["__java__", "__python__", "__en_XX__"],
+    "multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
+}
+
+FAIRSEQ_LANGUAGE_CODES_MAP = {
+    "java": "__java__",
+    "python": "__python__",
+    "en_XX": "__en_XX__",
+    "javascript": "__javascript__",
+    "php": "__php__",
+    "ruby": "__ruby__",
+    "go": "__go__",
+}
+
+
+@requires(backends=("sentencepiece",))
+class PLBartTokenizer(PreTrainedTokenizer):
+    """
+    Construct a PLBART tokenizer.
+
+    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+    <tokens> <eos>` for target language documents.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        src_lang (`str`, *optional*):
+            A string representing the source language.
+        tgt_lang (`str`, *optional*):
+            A string representing the target language.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The start of sequence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The cls token, which is a special token used as the first token for all tasks.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masking tasks. This
+            is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
+            downstream tasks.
+        language_codes (`str`, *optional*, defaults to `"base"`):
+            What language codes to use. Should be one of `"base"` or `"multi"`.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Examples:
+
+    ```python
+    >>> from transformers import PLBartTokenizer
+
+    >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
+    >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
+    >>> expected_translation_english = "Returns the maximum value of a b c."
+    >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: list[int] = []
+    suffix_tokens: list[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        language_codes="base",
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        additional_special_tokens=None,
+        clean_up_tokenization_spaces=True,
+        **kwargs,
+    ):
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        src_lang = self._convert_lang_code_special_format(src_lang)
+        tgt_lang = self._convert_lang_code_special_format(tgt_lang)
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+        self.language_codes = language_codes
+
+        fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+
+        if self.language_codes == "base":
+            self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+        _additional_special_tokens = list(self.lang_code_to_id.keys())
+
+        if additional_special_tokens is not None:
+            # Only add those special tokens if they are not already there.
+            _additional_special_tokens.extend(
+                [t for t in additional_special_tokens if t not in _additional_special_tokens]
+            )
+
+        if self.language_codes == "base":
+            self._src_lang = src_lang
+            self.cur_lang_code_id = (
+                self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
+            )
+        else:
+            self._src_lang = src_lang if src_lang is not None else "__en_XX__"
+            self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            language_codes=language_codes,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=_additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @property
+    def vocab_size(self):
+        if self.language_codes == "base":
+            return (
+                len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1
+            )  # Plus 1 for the mask token
+        else:
+            return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        new_src_lang = self._convert_lang_code_special_format(new_src_lang)
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
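The id arithmetic in the alignment table above can be checked with plain dictionaries. The sketch below stubs out the SentencePiece model with a small dict (`SPM_PIECES` is invented for the example); the offset logic mirrors `_convert_token_to_id` further down in this file.

```python
# Stand-in for sp_model.PieceToId: a real SentencePiece model is not needed for the arithmetic.
SPM_PIECES = {"<unk>": 0, "<s>": 1, "</s>": 2, ",": 3, ".": 4, "▁": 5, "s": 6, "▁de": 7, "-": 8, "▁a": 9}

FAIRSEQ_TOKENS_TO_IDS = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
FAIRSEQ_OFFSET = 1  # "," is id 4 in fairseq but id 3 in spm
UNK_TOKEN_ID = FAIRSEQ_TOKENS_TO_IDS["<unk>"]

def convert_token_to_id(token: str) -> int:
    # Special tokens come straight from the fairseq table.
    if token in FAIRSEQ_TOKENS_TO_IDS:
        return FAIRSEQ_TOKENS_TO_IDS[token]
    spm_id = SPM_PIECES.get(token, 0)
    # SentencePiece returns 0 for unknown pieces, so fall back to the fairseq <unk> id.
    return spm_id + FAIRSEQ_OFFSET if spm_id else UNK_TOKEN_ID

print(convert_token_to_id(","))         # 4: spm id 3 + offset 1
print(convert_token_to_id("<pad>"))     # 1: taken from the fairseq table
print(convert_token_to_id("no_piece"))  # 3: unknown pieces map to <unk>
```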
+ token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An PLBART sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
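A tiny sketch of the `X [eos, src_lang_code]` layout described above. The helper name and the plain-integer token ids are placeholders for illustration; 50001 (`java`) and 50003 (`en_XX`) follow the example ids quoted in the modeling docstrings earlier in this diff, and `EOS_ID = 2` matches the `</s>` position in the vocabulary table above.

```python
EOS_ID = 2
LANG_CODE_TO_ID = {"__java__": 50001, "__en_XX__": 50003}  # illustrative ids only

def build_inputs(token_ids: list[int], lang_code: str) -> list[int]:
    prefix_tokens = []                                     # PLBart never adds a BOS prefix
    suffix_tokens = [EOS_ID, LANG_CODE_TO_ID[lang_code]]   # sequence ends with eos + language code
    return prefix_tokens + token_ids + suffix_tokens

source = build_inputs([101, 102, 103], "__java__")   # encoder side, suffixed with the source language
target = build_inputs([201, 202], "__en_XX__")       # decoder side, suffixed with the target language
print(source)  # [101, 102, 103, 2, 50001]
print(target)  # [201, 202, 2, 50003]
```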
+ """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = self._convert_lang_code_special_format(src_lang) + self.tgt_lang = self._convert_lang_code_special_format(tgt_lang) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "en_XX", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "python", + **kwargs, + ) -> BatchEncoding: + self.src_lang = self._convert_lang_code_special_format(src_lang) + self.tgt_lang = self._convert_lang_code_special_format(tgt_lang) + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. 
No prefix and suffix=[eos, src_lang_code].""" + src_lang = self._convert_lang_code_special_format(src_lang) + self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code].""" + lang = self._convert_lang_code_special_format(lang) + + self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] + + def _convert_lang_code_special_format(self, lang: str) -> str: + """Convert Language Codes to format tokenizer uses if required""" + lang = FAIRSEQ_LANGUAGE_CODES_MAP.get(lang, lang) + return lang + + +__all__ = ["PLBartTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbfcbf157b58aa582e2a8b2a99c7d561372904b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pop2piano import * + from .modeling_pop2piano import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/configuration_pop2piano.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/configuration_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc90961154bbbefe560706052b6ab492a7adaae --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/configuration_pop2piano.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Pop2Piano model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Pop2PianoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Pop2PianoForConditionalGeneration`]. It is used + to instantiate a Pop2PianoForConditionalGeneration model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + Pop2Piano [sweetcocoa/pop2piano](https://huggingface.co/sweetcocoa/pop2piano) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 2400): + Vocabulary size of the `Pop2PianoForConditionalGeneration` model. Defines the number of different tokens + that can be represented by the `inputs_ids` passed when calling [`Pop2PianoForConditionalGeneration`]. + composer_vocab_size (`int`, *optional*, defaults to 21): + Denotes the number of composers. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `Pop2PianoBlock`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + dense_act_fn (`string`, *optional*, defaults to `"relu"`): + Type of Activation Function to be used in `Pop2PianoDenseActDense` and in `Pop2PianoDenseGatedActDense`. 
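A short usage sketch for the configuration documented above. It relies only on behaviour visible in the `__init__` that follows (the `num_decoder_layers` fallback and `is_gated_act` being derived from `feed_forward_proj`) and assumes a transformers installation that exposes `Pop2PianoConfig`.

```python
from transformers import Pop2PianoConfig

# Defaults: 6 encoder layers; the decoder layer count falls back to num_layers.
config = Pop2PianoConfig()
print(config.num_layers, config.num_decoder_layers)  # 6 6

# "gated-gelu" feed-forward selects the gated dense variant.
print(config.feed_forward_proj, config.is_gated_act)  # gated-gelu True

# Overriding the feed-forward type switches to the plain dense stack.
relu_config = Pop2PianoConfig(feed_forward_proj="relu", num_decoder_layers=4)
print(relu_config.is_gated_act, relu_config.num_decoder_layers)  # False 4
```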
+ """ + + model_type = "pop2piano" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=2400, + composer_vocab_size=21, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + dense_act_fn="relu", + **kwargs, + ): + self.vocab_size = vocab_size + self.composer_vocab_size = composer_vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + self.dense_act_fn = dense_act_fn + self.is_gated_act = self.feed_forward_proj.split("-")[0] == "gated" + self.hidden_size = self.d_model + self.num_attention_heads = num_heads + self.num_hidden_layers = num_layers + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +__all__ = ["Pop2PianoConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/feature_extraction_pop2piano.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/feature_extraction_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..97255f969aa5f2b35db1bfd38001cf0bdd1e6730 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/feature_extraction_pop2piano.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Feature extractor class for Pop2Piano""" + +import warnings +from typing import Optional, Union + +import numpy +import numpy as np + +from ...audio_utils import mel_filter_bank, spectrogram +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import ( + TensorType, + is_essentia_available, + is_librosa_available, + is_scipy_available, + logging, + requires_backends, +) +from ...utils.import_utils import requires + + +if is_essentia_available(): + import essentia + import essentia.standard + +if is_librosa_available(): + import librosa + +if is_scipy_available(): + import scipy + + +logger = logging.get_logger(__name__) + + +@requires(backends=("essentia", "librosa", "scipy", "torch")) +class Pop2PianoFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a Pop2Piano feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed + to `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence as + well as tempo in bpm, then beat_times is interpolated and to get beatsteps. Later we calculate + extrapolated_beatsteps from it to be used in tokenizer. On the other hand audio is resampled to self.sampling_rate + and preprocessed and then log mel spectogram is computed from that to be used in our transformer model. + + Args: + sampling_rate (`int`, *optional*, defaults to 22050): + Target Sampling rate of audio signal. It's the sampling rate that we forward to the model. + padding_value (`int`, *optional*, defaults to 0): + Padding value used to pad the audio. Should correspond to silences. + window_size (`int`, *optional*, defaults to 4096): + Length of the window in samples to which the Fourier transform is applied. + hop_length (`int`, *optional*, defaults to 1024): + Step size between each window of the waveform, in samples. + min_frequency (`float`, *optional*, defaults to 10.0): + Lowest frequency that will be used in the log-mel spectrogram. + feature_size (`int`, *optional*, defaults to 512): + The feature dimension of the extracted features. + num_bars (`int`, *optional*, defaults to 2): + Determines interval between each sequence. 
+ """ + + model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"] + + def __init__( + self, + sampling_rate: int = 22050, + padding_value: int = 0, + window_size: int = 4096, + hop_length: int = 1024, + min_frequency: float = 10.0, + feature_size: int = 512, + num_bars: int = 2, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + **kwargs, + ) + self.sampling_rate = sampling_rate + self.padding_value = padding_value + self.window_size = window_size + self.hop_length = hop_length + self.min_frequency = min_frequency + self.feature_size = feature_size + self.num_bars = num_bars + self.mel_filters = mel_filter_bank( + num_frequency_bins=(self.window_size // 2) + 1, + num_mel_filters=self.feature_size, + min_frequency=self.min_frequency, + max_frequency=float(self.sampling_rate // 2), + sampling_rate=self.sampling_rate, + norm=None, + mel_scale="htk", + ) + + def mel_spectrogram(self, sequence: np.ndarray): + """ + Generates MelSpectrogram. + + Args: + sequence (`numpy.ndarray`): + The sequence of which the mel-spectrogram will be computed. + """ + mel_specs = [] + for seq in sequence: + window = np.hanning(self.window_size + 1)[:-1] + mel_specs.append( + spectrogram( + waveform=seq, + window=window, + frame_length=self.window_size, + hop_length=self.hop_length, + power=2.0, + mel_filters=self.mel_filters, + ) + ) + mel_specs = np.array(mel_specs) + + return mel_specs + + def extract_rhythm(self, audio: np.ndarray): + """ + This algorithm(`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as + tempo in bpm for an audio signal. For more information please visit + https://essentia.upf.edu/reference/std_RhythmExtractor2013.html . + + Args: + audio(`numpy.ndarray`): + raw audio waveform which is passed to the Rhythm Extractor. + """ + requires_backends(self, ["essentia"]) + essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature") + bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio) + + return bpm, beat_times, confidence, estimates, essentia_beat_intervals + + def interpolate_beat_times( + self, beat_times: numpy.ndarray, steps_per_beat: numpy.ndarray, n_extend: numpy.ndarray + ): + """ + This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is + then used to convert raw audio to log-mel-spectrogram. + + Args: + beat_times (`numpy.ndarray`): + beat_times is passed into `scipy.interpolate.interp1d` for processing. + steps_per_beat (`int`): + used as an parameter to control the interpolation. + n_extend (`int`): + used as an parameter to control the interpolation. + """ + + requires_backends(self, ["scipy"]) + beat_times_function = scipy.interpolate.interp1d( + np.arange(beat_times.size), + beat_times, + bounds_error=False, + fill_value="extrapolate", + ) + + ext_beats = beat_times_function( + np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend) + ) + + return ext_beats + + def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray): + """ + Preprocessing for log-mel-spectrogram + + Args: + audio (`numpy.ndarray` of shape `(audio_length, )` ): + Raw audio waveform to be processed. + beatstep (`numpy.ndarray`): + Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by + the value at beatstep[0]. 
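The window and filter-bank recipe used by `mel_spectrogram` above, written out against the public `transformers.audio_utils` helpers the extractor itself calls, with the class defaults (22050 Hz, 4096-sample Hann window, hop 1024, 512 mel bins) spelled out. The random waveform is a stand-in for a real preprocessed segment; the final `np.log(np.clip(...))` step mirrors what the `__call__` method applies later.

```python
import numpy as np
from transformers.audio_utils import mel_filter_bank, spectrogram

sampling_rate, window_size, hop_length, n_mels = 22050, 4096, 1024, 512

mel_filters = mel_filter_bank(
    num_frequency_bins=(window_size // 2) + 1,
    num_mel_filters=n_mels,
    min_frequency=10.0,
    max_frequency=float(sampling_rate // 2),
    sampling_rate=sampling_rate,
    norm=None,
    mel_scale="htk",
)

waveform = np.random.randn(sampling_rate * 2).astype(np.float32)  # 2 seconds of noise
window = np.hanning(window_size + 1)[:-1]

mel_spec = spectrogram(
    waveform=waveform,
    window=window,
    frame_length=window_size,
    hop_length=hop_length,
    power=2.0,
    mel_filters=mel_filters,
)
# The extractor later takes the log of the clipped power spectrogram.
log_mel = np.log(np.clip(mel_spec, a_min=1e-6, a_max=None))
print(log_mel.shape)  # (512, num_frames)
```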
+ """ + + if audio is not None and len(audio.shape) != 1: + raise ValueError( + f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}." + ) + if beatstep[0] > 0.0: + beatstep = beatstep - beatstep[0] + + num_steps = self.num_bars * 4 + num_target_steps = len(beatstep) + extrapolated_beatstep = self.interpolate_beat_times( + beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1 + ) + + sample_indices = [] + max_feature_length = 0 + for i in range(0, num_target_steps, num_steps): + start_idx = i + end_idx = min(i + num_steps, num_target_steps) + start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate) + end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate) + sample_indices.append((start_sample, end_sample)) + max_feature_length = max(max_feature_length, end_sample - start_sample) + padded_batch = [] + for start_sample, end_sample in sample_indices: + feature = audio[start_sample:end_sample] + padded_feature = np.pad( + feature, + ((0, max_feature_length - feature.shape[0]),), + "constant", + constant_values=0, + ) + padded_batch.append(padded_feature) + + padded_batch = np.asarray(padded_batch) + return padded_batch, extrapolated_beatstep + + def _pad(self, features: np.ndarray, add_zero_line=True): + features_shapes = [each_feature.shape for each_feature in features] + attention_masks, padded_features = [], [] + for i, each_feature in enumerate(features): + # To pad "input_features". + if len(each_feature.shape) == 3: + features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1] + attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64) + feature_padding = ((0, 0), (0, features_pad_value), (0, 0)) + attention_mask_padding = (feature_padding[0], feature_padding[1]) + + # To pad "beatsteps" and "extrapolated_beatstep". + else: + each_feature = each_feature.reshape(1, -1) + features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0] + attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1) + feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value)) + + each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value) + attention_mask = np.pad( + attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value + ) + + if add_zero_line: + # if it is batched then we separate each examples using zero array + zero_array_len = max([*zip(*features_shapes)][1]) + + # we concatenate the zero array line here + each_padded_feature = np.concatenate( + [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0 + ) + attention_mask = np.concatenate( + [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0 + ) + + padded_features.append(each_padded_feature) + attention_masks.append(attention_mask) + + padded_features = np.concatenate(padded_features, axis=0).astype(np.float32) + attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64) + + return padded_features, attention_masks + + def pad( + self, + inputs: BatchFeature, + is_batched: bool, + return_attention_mask: bool, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Pads the inputs to same length and returns attention_mask. + + Args: + inputs (`BatchFeature`): + Processed audio features. + is_batched (`bool`): + Whether inputs are batched or not. 
+ return_attention_mask (`bool`): + Whether to return attention mask or not. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + If nothing is specified, it will return list of `np.ndarray` arrays. + Return: + `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added + to it: + - **attention_mask** numpy.ndarray of shape `(batch_size, max_input_features_seq_length)` -- + Example : + 1, 1, 1, 0, 0 (audio 1, also here it is padded to max length of 5 that's why there are 2 zeros at + the end indicating they are padded) + + 0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2) + + 1, 1, 1, 1, 1 (audio 2) + + 0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3) + + 1, 1, 1, 1, 1 (audio 3) + - **attention_mask_beatsteps** numpy.ndarray of shape `(batch_size, max_beatsteps_seq_length)` + - **attention_mask_extrapolated_beatstep** numpy.ndarray of shape `(batch_size, + max_extrapolated_beatstep_seq_length)` + """ + + processed_features_dict = {} + for feature_name, feature_value in inputs.items(): + if feature_name == "input_features": + padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True) + processed_features_dict[feature_name] = padded_feature_values + if return_attention_mask: + processed_features_dict["attention_mask"] = attention_mask + else: + padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False) + processed_features_dict[feature_name] = padded_feature_values + if return_attention_mask: + processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask + + # If we are processing only one example, we should remove the zero array line since we don't need it to + # separate examples from each other. + if not is_batched and not return_attention_mask: + processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...] + + outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors) + + return outputs + + def __call__( + self, + audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + sampling_rate: Union[int, list[int]], + steps_per_beat: int = 2, + resample: Optional[bool] = True, + return_attention_mask: Optional[bool] = False, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model. + + Args: + audio (`np.ndarray`, `List`): + The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a + list of numpy arrays or a list of list of float values. + sampling_rate (`int`): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + steps_per_beat (`int`, *optional*, defaults to 2): + This is used in interpolating `beat_times`. + resample (`bool`, *optional*, defaults to `True`): + Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True + during inference. + return_attention_mask (`bool` *optional*, defaults to `False`): + Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as + output or not. Automatically set to True for batched inputs. 
+ return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + If nothing is specified, it will return list of `np.ndarray` arrays. + """ + + requires_backends(self, ["librosa"]) + is_batched = isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list)) + if is_batched: + # This enables the user to process files of different sampling_rate at same time + if not isinstance(sampling_rate, list): + raise ValueError( + "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. " + f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]." + ) + return_attention_mask = True if return_attention_mask is None else return_attention_mask + else: + audio = [audio] + sampling_rate = [sampling_rate] + return_attention_mask = False if return_attention_mask is None else return_attention_mask + + batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], [] + for single_raw_audio, single_sampling_rate in zip(audio, sampling_rate): + bpm, beat_times, confidence, estimates, essentia_beat_intervals = self.extract_rhythm( + audio=single_raw_audio + ) + beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1) + + if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None: + if resample: + # Change sampling_rate to self.sampling_rate + single_raw_audio = librosa.core.resample( + single_raw_audio, + orig_sr=single_sampling_rate, + target_sr=self.sampling_rate, + res_type="kaiser_best", + ) + else: + warnings.warn( + f"The sampling_rate of the provided audio is different from the target sampling_rate " + f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. " + f"In these cases it is recommended to use `resample=True` in the `__call__` method to " + f"get the optimal behaviour." 
+ ) + + single_sampling_rate = self.sampling_rate + start_sample = int(beatsteps[0] * single_sampling_rate) + end_sample = int(beatsteps[-1] * single_sampling_rate) + + input_features, extrapolated_beatstep = self.preprocess_mel( + single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0] + ) + + mel_specs = self.mel_spectrogram(input_features.astype(np.float32)) + + # apply np.log to get log mel-spectrograms + log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None)) + + input_features = np.transpose(log_mel_specs, (0, -1, -2)) + + batch_input_features.append(input_features) + batch_beatsteps.append(beatsteps) + batch_ext_beatstep.append(extrapolated_beatstep) + + output = BatchFeature( + { + "input_features": batch_input_features, + "beatsteps": batch_beatsteps, + "extrapolated_beatstep": batch_ext_beatstep, + } + ) + + output = self.pad( + output, + is_batched=is_batched, + return_attention_mask=return_attention_mask, + return_tensors=return_tensors, + ) + + return output + + +__all__ = ["Pop2PianoFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/modeling_pop2piano.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/modeling_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..996fc5635866f397256e6509d5d169f6ac2de3a9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/modeling_pop2piano.py @@ -0,0 +1,1331 @@ +# coding=utf-8 +# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Pop2Piano model.""" + +import copy +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.generation import GenerationConfig + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_pop2piano import Pop2PianoConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_load_pop2piano_layer_norm = True + +try: + from apex.normalization import FusedRMSNorm + + _load_pop2piano_layer_norm = False + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm") +except ImportError: + # using the normal Pop2PianoLayerNorm + pass +except Exception: + logger.warning("Discovered apex but it failed to load, falling back to Pop2PianoLayerNorm") + pass + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano +class Pop2PianoLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +if not _load_pop2piano_layer_norm: + Pop2PianoLayerNorm = FusedRMSNorm + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano,t5->pop2piano +class Pop2PianoDenseActDense(nn.Module): + def __init__(self, config: Pop2PianoConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano +class Pop2PianoDenseGatedActDense(nn.Module): + def __init__(self, config: Pop2PianoConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+ # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano +class Pop2PianoLayerFF(nn.Module): + def __init__(self, config: Pop2PianoConfig): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = Pop2PianoDenseGatedActDense(config) + else: + self.DenseReluDense = Pop2PianoDenseActDense(config) + + self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano +class Pop2PianoAttention(nn.Module): + def __init__( + self, + config: Pop2PianoConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_values=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` + is_updated = False + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + 
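+ # The returned tuple is (attn_output, position_bias), plus attn_weights when output_attentions is True;
+ # the position bias at index 1 is propagated upward so the stack can share it across layers.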
return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano +class Pop2PianoLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = Pop2PianoAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano +class Pop2PianoLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano +class Pop2PianoBlock(GradientCheckpointingLayer): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + Pop2PianoLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) + if self.is_decoder: + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) + + self.layer.append(Pop2PianoLayerFF(config)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + 
position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +@auto_docstring +class Pop2PianoPreTrainedModel(PreTrainedModel): + config: Pop2PianoConfig + base_model_prefix = "transformer" + is_parallelizable = False + supports_gradient_checkpointing = True + + _can_compile_fullgraph = False + _no_split_modules = ["Pop2PianoBlock"] + _keep_in_fp32_modules = ["wo"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, Pop2PianoLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, Pop2PianoConcatEmbeddingToMel): + module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, Pop2PianoForConditionalGeneration): + # Mesh TensorFlow embeddings initialization + # See 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, Pop2PianoDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pop2PianoDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pop2PianoAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
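+ # Instead, build the shifted sequence functionally: a column filled with decoder_start_token_id followed by
+ # input_ids[..., :-1]. For example, labels [[a, b, c]] become decoder inputs [[decoder_start_token_id, a, b]].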
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class Pop2PianoStack(Pop2PianoPreTrainedModel): + # Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + else: + past_key_values = DynamicCache(config=self.config) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + 
past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
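+ # The helper below expands the 2D padding mask to (batch_size, 1, query_length, key_value_length), filling
+ # padded and non-causal positions with torch.finfo(dtype).min so the mask can be added directly to the scores.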
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
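+ # (i.e. masked positions are expected to already hold the large negative additive value that the branch below builds)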
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Pop2PianoConcatEmbeddingToMel(nn.Module): + """Embedding Matrix for `composer` tokens.""" + + def __init__(self, config): + super().__init__() + self.embedding = nn.Embedding(num_embeddings=config.composer_vocab_size, embedding_dim=config.d_model) + + def forward(self, feature, index_value, embedding_offset): + index_shifted = index_value - embedding_offset + composer_embedding = self.embedding(index_shifted).unsqueeze(1) + inputs_embeds = torch.cat([composer_embedding, feature], dim=1) + return inputs_embeds + + +@auto_docstring( + custom_intro=""" + Pop2Piano Model with a `language modeling` head on top. + """ +) +class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: Pop2PianoConfig): + super().__init__(config) + self.config = config + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + + self.encoder = Pop2PianoStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = Pop2PianoStack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_mel_conditioner_outputs( + self, + input_features: torch.FloatTensor, + composer: str, + generation_config: GenerationConfig, + attention_mask: Optional[torch.FloatTensor] = None, + ): + """ + This method is used to concatenate mel conditioner tokens at the front of the input_features in order to + control the type of MIDI token generated by the model. + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + input features extracted from the feature extractor. + composer (`str`): + composer token which determines the type of MIDI tokens to be generated. 
+ generation_config (`~generation.GenerationConfig`): + The generation is used to get the composer-feature_token pair. + attention_mask (``, *optional*): + For batched generation `input_features` are padded to have the same shape across all examples. + `attention_mask` helps to determine which areas were padded and which were not. + - 1 for tokens that are **not padded**, + - 0 for tokens that are **padded**. + """ + composer_to_feature_token = generation_config.composer_to_feature_token + if composer not in composer_to_feature_token: + raise ValueError( + f"Please choose a composer from {list(composer_to_feature_token.keys())}. Composer received - {composer}" + ) + composer_value = composer_to_feature_token[composer] + composer_value = torch.tensor(composer_value, device=self.device) + composer_value = composer_value.repeat(input_features.shape[0]) + + embedding_offset = min(composer_to_feature_token.values()) + + input_features = self.mel_conditioner( + feature=input_features, + index_value=composer_value, + embedding_offset=embedding_offset, + ) + if attention_mask is not None: + input_features[~attention_mask[:, 0].bool()] = 0.0 + + # since self.mel_conditioner adds a new array at the front of inputs_embeds we need to do the same for attention_mask to keep the shapes same + attention_mask = torch.concatenate([attention_mask[:, 0].view(-1, 1), attention_mask], axis=1) + return input_features, attention_mask + + return input_features, None + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings + so you should be able to pad the inputs on both the right and the left. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. + [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining + take a look a [Pop2Piano Training](./Pop2Piano#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the + starting token for `decoder_input_ids` generation. 
If `past_key_values` is used, optionally only the last + `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None and input_features is not None: + raise ValueError("Both `inputs_embeds` and `input_features` received! Please provide only one of them") + elif input_features is not None and inputs_embeds is None: + inputs_embeds = input_features + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = 
self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features, + attention_mask=None, + composer="composer1", + generation_config=None, + **kwargs, + ): + """ + Generates token ids for midi outputs. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation + strategies and code examples, check out the [following guide](./generation_strategies). + + + + Parameters: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`. + attention_mask: + For batched generation `input_features` are padded to have the same shape across all examples. + `attention_mask` helps to determine which areas were padded and which were not. + - 1 for tokens that are **not padded**, + - 0 for tokens that are **padded**. + composer (`str`, *optional*, defaults to `"composer1"`): + This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each + `"composer"`. Please make sure that the composer value is present in `composer_to_feature_token` in + `generation_config`. For an example please see + https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json . + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. 
+ Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + + if generation_config is None: + generation_config = self.generation_config + generation_config.update(**kwargs) + + # check for composer_to_feature_token + if not hasattr(generation_config, "composer_to_feature_token"): + raise ValueError( + "`composer_to_feature_token` was not found! Please refer to " + "https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json" + "and parse a dict like that." + ) + + if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size: + raise ValueError( + "config.composer_vocab_size must be same as the number of keys in " + f"generation_config.composer_to_feature_token! " + f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}." + ) + + # to control the variation of generated MIDI tokens we concatenate mel-conditioner tokens(which depends on composer_token) + # at the front of input_features. + input_features, attention_mask = self.get_mel_conditioner_outputs( + input_features=input_features, + attention_mask=attention_mask, + composer=composer, + generation_config=generation_config, + ) + + return super().generate( + inputs=None, + inputs_embeds=input_features, + attention_mask=attention_mask, + generation_config=generation_config, + **kwargs, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +__all__ = ["Pop2PianoForConditionalGeneration", "Pop2PianoPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/processing_pop2piano.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/processing_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..bb914b3f87106197bccdd8c7e82828a889098e05 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/processing_pop2piano.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor class for Pop2Piano.""" + +import os +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessorMixin +from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy +from ...utils import TensorType +from ...utils.import_utils import requires + + +@requires(backends=("essentia", "librosa", "pretty_midi", "scipy", "torch")) +class Pop2PianoProcessor(ProcessorMixin): + r""" + Constructs an Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single + processor. + + [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`]. 
+ See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information. + + Args: + feature_extractor (`Pop2PianoFeatureExtractor`): + An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`Pop2PianoTokenizer`): + An instance of ['Pop2PianoTokenizer`]. The tokenizer is a required input. + """ + + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "Pop2PianoFeatureExtractor" + tokenizer_class = "Pop2PianoTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__( + self, + audio: Union[np.ndarray, list[float], list[np.ndarray]] = None, + sampling_rate: Optional[Union[int, list[int]]] = None, + steps_per_beat: int = 2, + resample: Optional[bool] = True, + notes: Union[list, TensorType] = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + verbose: bool = True, + **kwargs, + ) -> Union[BatchFeature, BatchEncoding]: + """ + This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model, + and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes. + + Please refer to the docstring of the above two methods for more information. + """ + + # Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and + # feature_extractor_output, we must check for both. + if (audio is None and sampling_rate is None) and (notes is None): + raise ValueError( + "You have to specify at least audios and sampling_rate in order to use feature extractor or " + "notes to use the tokenizer part." + ) + + if audio is not None and sampling_rate is not None: + inputs = self.feature_extractor( + audio=audio, + sampling_rate=sampling_rate, + steps_per_beat=steps_per_beat, + resample=resample, + **kwargs, + ) + if notes is not None: + encoded_token_ids = self.tokenizer( + notes=notes, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + if notes is None: + return inputs + + elif audio is None or sampling_rate is None: + return encoded_token_ids + + else: + inputs["token_ids"] = encoded_token_ids["token_ids"] + return inputs + + def batch_decode( + self, + token_ids, + feature_extractor_output: BatchFeature, + return_midi: bool = True, + ) -> BatchEncoding: + """ + This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes. + + Please refer to the docstring of the above two methods for more information. 
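+ Example (a minimal sketch added for illustration; the checkpoint id comes from the links referenced in this file,
+ while the audio file name and the `"pretty_midi_objects"` output key are assumptions based on the released model card):
+
+ ```python
+ >>> import librosa
+ >>> from transformers import Pop2PianoForConditionalGeneration, Pop2PianoProcessor
+
+ >>> model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
+ >>> processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
+
+ >>> audio, sampling_rate = librosa.load("song.wav", sr=44100)  # placeholder audio file
+ >>> inputs = processor(audio=audio, sampling_rate=sampling_rate, return_tensors="pt")
+ >>> generated_ids = model.generate(input_features=inputs["input_features"], composer="composer1")
+ >>> outputs = processor.batch_decode(token_ids=generated_ids, feature_extractor_output=inputs)
+ >>> outputs["pretty_midi_objects"][0].write("output.mid")
+ ```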
+ """ + + return self.tokenizer.batch_decode( + token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi + ) + + def save_pretrained(self, save_directory, **kwargs): + if os.path.isfile(save_directory): + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") + os.makedirs(save_directory, exist_ok=True) + return super().save_pretrained(save_directory, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) + return cls(*args) + + +__all__ = ["Pop2PianoProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/pop2piano/tokenization_pop2piano.py b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/tokenization_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..f7aea3479f6ff3a97439154b4d4851bb723922b1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/pop2piano/tokenization_pop2piano.py @@ -0,0 +1,723 @@ +# coding=utf-8 +# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Pop2Piano.""" + +import json +import os +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy +from ...utils import TensorType, is_pretty_midi_available, logging, requires_backends, to_numpy +from ...utils.import_utils import requires + + +if is_pretty_midi_available(): + import pretty_midi + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab": "vocab.json", +} + + +def token_time_to_note(number, cutoff_time_idx, current_idx): + current_idx += number + if cutoff_time_idx is not None: + current_idx = min(current_idx, cutoff_time_idx) + + return current_idx + + +def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes): + if note_onsets_ready[number] is not None: + # offset with onset + onset_idx = note_onsets_ready[number] + if onset_idx < current_idx: + # Time shift after previous note_on + offset_idx = current_idx + notes.append([onset_idx, offset_idx, number, default_velocity]) + onsets_ready = None if current_velocity == 0 else current_idx + note_onsets_ready[number] = onsets_ready + else: + note_onsets_ready[number] = current_idx + return notes + + +@requires(backends=("pretty_midi", "torch")) +class Pop2PianoTokenizer(PreTrainedTokenizer): + """ + Constructs a Pop2Piano tokenizer. This tokenizer does not require training. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab (`str`): + Path to the vocab file which contains the vocabulary. 
+ default_velocity (`int`, *optional*, defaults to 77):
+ Determines the default velocity to be used while creating midi Notes.
+ num_bars (`int`, *optional*, defaults to 2):
+ Determines the cutoff_time_idx for each token.
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"1"`):
+ The end of sequence token.
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"0"`):
+ A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
+ attention mechanisms or loss computation.
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"2"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+ """
+
+ model_input_names = ["token_ids", "attention_mask"]
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab,
+ default_velocity=77,
+ num_bars=2,
+ unk_token="-1",
+ eos_token="1",
+ pad_token="0",
+ bos_token="2",
+ **kwargs,
+ ):
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+
+ self.default_velocity = default_velocity
+ self.num_bars = num_bars
+
+ # Load the vocab
+ with open(vocab, "rb") as file:
+ self.encoder = json.load(file)
+
+ # create mappings for encoder
+ self.decoder = {v: k for k, v in self.encoder.items()}
+
+ super().__init__(
+ unk_token=unk_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ bos_token=bos_token,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ """Returns the vocabulary size of the tokenizer."""
+ return len(self.encoder)
+
+ def get_vocab(self):
+ """Returns the vocabulary of the tokenizer."""
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def _convert_id_to_token(self, token_id: int) -> list:
+ """
+ Decodes a token id generated by the transformer into a midi token.
+
+ Args:
+ token_id (`int`):
+ This denotes the id generated by the transformer to be converted to a Midi token.
+
+ Returns:
+ `List`: A list consisting of token_type (`str`) and value (`int`).
+ """
+
+ token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
+ token_type_value = token_type_value.split("_")
+ token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
+
+ return [token_type, value]
+
+ def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
+ """
+ Encodes a Midi token to a transformer token id.
+
+ Args:
+ token (`int`):
+ This denotes the token value.
+ token_type (`str`):
+ This denotes the type of the token. There are four types of midi tokens: "TOKEN_TIME",
+ "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
+
+ Returns:
+ `int`: The id of the token.
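+
+ Example (editor's sketch; the returned id depends entirely on the loaded vocab):
+
+ ```python
+ >>> # assuming `tokenizer` is an already-loaded Pop2PianoTokenizer
+ >>> tokenizer._convert_token_to_id(60, "TOKEN_NOTE")  # looks up "60_TOKEN_NOTE" in the vocab
+ ```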
+ """ + return self.encoder.get(f"{token}_{token_type}", int(self.unk_token)) + + def relative_batch_tokens_ids_to_notes( + self, + tokens: np.ndarray, + beat_offset_idx: int, + bars_per_batch: int, + cutoff_time_idx: int, + ): + """ + Converts relative tokens to notes which are then used to generate pretty midi object. + + Args: + tokens (`numpy.ndarray`): + Tokens to be converted to notes. + beat_offset_idx (`int`): + Denotes beat offset index for each note in generated Midi. + bars_per_batch (`int`): + A parameter to control the Midi output generation. + cutoff_time_idx (`int`): + Denotes the cutoff time index for each note in generated Midi. + """ + + notes = None + + for index in range(len(tokens)): + _tokens = tokens[index] + _start_idx = beat_offset_idx + index * bars_per_batch * 4 + _cutoff_time_idx = cutoff_time_idx + _start_idx + _notes = self.relative_tokens_ids_to_notes( + _tokens, + start_idx=_start_idx, + cutoff_time_idx=_cutoff_time_idx, + ) + + if len(_notes) == 0: + pass + elif notes is None: + notes = _notes + else: + notes = np.concatenate((notes, _notes), axis=0) + + if notes is None: + return [] + return notes + + def relative_batch_tokens_ids_to_midi( + self, + tokens: np.ndarray, + beatstep: np.ndarray, + beat_offset_idx: int = 0, + bars_per_batch: int = 2, + cutoff_time_idx: int = 12, + ): + """ + Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens + to notes then uses `notes_to_midi` method to convert them to Midi. + + Args: + tokens (`numpy.ndarray`): + Denotes tokens which alongside beatstep will be converted to Midi. + beatstep (`np.ndarray`): + We get beatstep from feature extractor which is also used to get Midi. + beat_offset_idx (`int`, *optional*, defaults to 0): + Denotes beat offset index for each note in generated Midi. + bars_per_batch (`int`, *optional*, defaults to 2): + A parameter to control the Midi output generation. + cutoff_time_idx (`int`, *optional*, defaults to 12): + Denotes the cutoff time index for each note in generated Midi. + """ + beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx + notes = self.relative_batch_tokens_ids_to_notes( + tokens=tokens, + beat_offset_idx=beat_offset_idx, + bars_per_batch=bars_per_batch, + cutoff_time_idx=cutoff_time_idx, + ) + midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx]) + return midi + + # Taken from the original code + # Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257 + def relative_tokens_ids_to_notes( + self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: Optional[float] = None + ): + """ + Converts relative tokens to notes which will then be used to create Pretty Midi objects. + + Args: + tokens (`numpy.ndarray`): + Relative Tokens which will be converted to notes. + start_idx (`float`): + A parameter which denotes the starting index. + cutoff_time_idx (`float`, *optional*): + A parameter used while converting tokens to notes. 
+ """ + words = [self._convert_id_to_token(token) for token in tokens] + + current_idx = start_idx + current_velocity = 0 + note_onsets_ready = [None for i in range(sum([k.endswith("NOTE") for k in self.encoder]) + 1)] + notes = [] + for token_type, number in words: + if token_type == "TOKEN_SPECIAL": + if number == 1: + break + elif token_type == "TOKEN_TIME": + current_idx = token_time_to_note( + number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx + ) + elif token_type == "TOKEN_VELOCITY": + current_velocity = number + + elif token_type == "TOKEN_NOTE": + notes = token_note_to_note( + number=number, + current_velocity=current_velocity, + default_velocity=self.default_velocity, + note_onsets_ready=note_onsets_ready, + current_idx=current_idx, + notes=notes, + ) + else: + raise ValueError("Token type not understood!") + + for pitch, note_onset in enumerate(note_onsets_ready): + # force offset if no offset for each pitch + if note_onset is not None: + if cutoff_time_idx is None: + cutoff = note_onset + 1 + else: + cutoff = max(cutoff_time_idx, note_onset + 1) + + offset_idx = max(current_idx, cutoff) + notes.append([note_onset, offset_idx, pitch, self.default_velocity]) + + if len(notes) == 0: + return [] + else: + notes = np.array(notes) + note_order = notes[:, 0] * 128 + notes[:, 1] + notes = notes[note_order.argsort()] + return notes + + def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: int = 0.0): + """ + Converts notes to Midi. + + Args: + notes (`numpy.ndarray`): + This is used to create Pretty Midi objects. + beatstep (`numpy.ndarray`): + This is the extrapolated beatstep that we get from feature extractor. + offset_sec (`int`, *optional*, defaults to 0.0): + This represents the offset seconds which is used while creating each Pretty Midi Note. + """ + + requires_backends(self, ["pretty_midi"]) + + new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0) + new_inst = pretty_midi.Instrument(program=0) + new_notes = [] + + for onset_idx, offset_idx, pitch, velocity in notes: + new_note = pretty_midi.Note( + velocity=velocity, + pitch=pitch, + start=beatstep[onset_idx] - offset_sec, + end=beatstep[offset_idx] - offset_sec, + ) + new_notes.append(new_note) + new_inst.notes = new_notes + new_pm.instruments.append(new_inst) + new_pm.remove_invalid_notes() + return new_pm + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + """ + Saves the tokenizer's vocabulary dictionary to the provided save_directory. + + Args: + save_directory (`str`): + A path to the directory where to saved. It will be created if it doesn't exist. + filename_prefix (`Optional[str]`, *optional*): + A prefix to add to the names of the files saved by the tokenizer. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + + # Save the encoder. + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"] + ) + with open(out_vocab_file, "w") as file: + file.write(json.dumps(self.encoder)) + + return (out_vocab_file,) + + def encode_plus( + self, + notes: Union[np.ndarray, list[pretty_midi.Note]], + truncation_strategy: Optional[TruncationStrategy] = None, + max_length: Optional[int] = None, + **kwargs, + ) -> BatchEncoding: + r""" + This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer + generated token ids. 
It only works on a single example; to process multiple examples, please use the
+ `batch_encode_plus` or `__call__` methods.
+
+ Args:
+ notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
+ This represents the midi notes. If `notes` is a `numpy.ndarray`:
+ - Each sequence must have 4 values: `onset idx`, `offset idx`, `pitch` and `velocity`.
+ If `notes` is a `list` containing `pretty_midi.Note` objects:
+ - Each sequence must have 4 attributes: `start`, `end`, `pitch` and `velocity`.
+ truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
+ Indicates the truncation strategy that is going to be used during truncation.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+
+ Returns:
+ `BatchEncoding` containing the token ids.
+ """
+
+ requires_backends(self, ["pretty_midi"])
+
+ # check if notes is a pretty_midi object or not, if yes then extract the attributes and put them into a numpy
+ # array.
+ if isinstance(notes[0], pretty_midi.Note):
+ notes = np.array(
+ [[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes]
+ ).reshape(-1, 4)
+
+ # round all the values to the closest integer values.
+ notes = np.round(notes).astype(np.int32)
+ max_time_idx = notes[:, :2].max()
+
+ times = [[] for i in range(max_time_idx + 1)]
+ for onset, offset, pitch, velocity in notes:
+ times[onset].append([pitch, velocity])
+ times[offset].append([pitch, 0])
+
+ tokens = []
+ current_velocity = 0
+ for i, time in enumerate(times):
+ if len(time) == 0:
+ continue
+ tokens.append(self._convert_token_to_id(i, "TOKEN_TIME"))
+ for pitch, velocity in time:
+ velocity = int(velocity > 0)
+ if current_velocity != velocity:
+ current_velocity = velocity
+ tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
+ tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))
+
+ total_len = len(tokens)
+
+ # truncation
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ tokens, _, _ = self.truncate_sequences(
+ ids=tokens,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ **kwargs,
+ )
+
+ return BatchEncoding({"token_ids": tokens})
+
+ def batch_encode_plus(
+ self,
+ notes: Union[np.ndarray, list[pretty_midi.Note]],
+ truncation_strategy: Optional[TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ r"""
+ This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
+ generated token ids. It works on multiple examples by calling `encode_plus` once per example in a loop.
+
+ Args:
+ notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
+ This represents the midi notes. If `notes` is a `numpy.ndarray`:
+ - Each sequence must have 4 values: `onset idx`, `offset idx`, `pitch` and `velocity`.
+ If `notes` is a `list` containing `pretty_midi.Note` objects:
+ - Each sequence must have 4 attributes: `start`, `end`, `pitch` and `velocity`.
+ truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
+ Indicates the truncation strategy that is going to be used during truncation.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+
+ Returns:
+ `BatchEncoding` containing the token ids.
+ """
+
+ encoded_batch_token_ids = []
+ for i in range(len(notes)):
+ encoded_batch_token_ids.append(
+ self.encode_plus(
+ notes[i],
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ **kwargs,
+ )["token_ids"]
+ )
+
+ return BatchEncoding({"token_ids": encoded_batch_token_ids})
+
+ def __call__(
+ self,
+ notes: Union[
+ np.ndarray,
+ list[pretty_midi.Note],
+ list[list[pretty_midi.Note]],
+ ],
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ r"""
+ This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
+ token ids.
+
+ Args:
+ notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
+ This represents the midi notes.
+
+ If `notes` is a `numpy.ndarray`:
+ - Each sequence must have 4 values: `onset idx`, `offset idx`, `pitch` and `velocity`.
+ If `notes` is a `list` containing `pretty_midi.Note` objects:
+ - Each sequence must have 4 attributes: `start`, `end`, `pitch` and `velocity`.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
If left unset or set to
+ `None`, this will use the predefined model maximum length if a maximum length is required by one of the
+ truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
+ truncation/padding to a maximum length will be deactivated.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of a list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+
+ Returns:
+ `BatchEncoding` containing the token_ids.
+ """
+
+ # check if it is batched or not
+ # it is batched if it's a list containing a list of `pretty_midi.Note` objects, where the outer list contains
+ # all the batches and the inner list contains all Notes for a single batch. Otherwise, if an np.ndarray is
+ # passed, it is considered batched if it has shape `[batch_size, sequence_length, 4]`, i.e. ndim=3.
+ is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)
+
+ # get the truncation and padding strategy
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ if is_batched:
+ # If the user has not explicitly set `return_attention_mask` to False, we change it to True
+ return_attention_mask = True if return_attention_mask is None else return_attention_mask
+ token_ids = self.batch_encode_plus(
+ notes=notes,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ **kwargs,
+ )
+ else:
+ token_ids = self.encode_plus(
+ notes=notes,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ **kwargs,
+ )
+
+ # since we already have truncated sequences, we are just left to do padding
+ token_ids = self.pad(
+ token_ids,
+ padding=padding_strategy,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return token_ids
+
+ def batch_decode(
+ self,
+ token_ids,
+ feature_extractor_output: BatchFeature,
+ return_midi: bool = True,
+ ):
+ r"""
+ This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
+ transformer to midi_notes and returns them.
+
+ Args:
+ token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
+ Output token_ids of the `Pop2PianoForConditionalGeneration` model.
+ feature_extractor_output (`BatchFeature`):
+ Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
+ `"extrapolated_beatstep"`.
Also `"attention_mask_beatsteps"` and
+ `"attention_mask_extrapolated_beatstep"`
+ should be present if they were returned by the feature extractor.
+ return_midi (`bool`, *optional*, defaults to `True`):
+ Whether to return the midi object or not.
+ Returns:
+ If `return_midi` is True:
+ - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
+ If `return_midi` is False:
+ - `BatchEncoding` containing `notes`.
+ """
+
+ # check if they have attention_masks (attention_mask, attention_mask_beatsteps, attention_mask_extrapolated_beatstep) or not
+ attention_masks_present = bool(
+ hasattr(feature_extractor_output, "attention_mask")
+ and hasattr(feature_extractor_output, "attention_mask_beatsteps")
+ and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
+ )
+
+ # if we are processing batched inputs then we need attention_masks
+ if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
+ raise ValueError(
+ "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
+ "for batched inputs, but at least one of them was not found."
+ )
+
+ # check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
+ if attention_masks_present:
+ # since we know about the number of examples in token_ids from attention_mask
+ if (
+ sum(feature_extractor_output["attention_mask"][:, 0] == 0)
+ != feature_extractor_output["beatsteps"].shape[0]
+ or feature_extractor_output["beatsteps"].shape[0]
+ != feature_extractor_output["extrapolated_beatstep"].shape[0]
+ ):
+ raise ValueError(
+ "Length mismatch between token_ids, beatsteps and extrapolated_beatstep! Found "
+ f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
+ f"and extrapolated_beatstep shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
+ )
+ if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
+ raise ValueError(
+ f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
+ )
+ else:
+ # if there is no attention mask present then it's surely a single example
+ if (
+ feature_extractor_output["beatsteps"].shape[0] != 1
+ or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
+ ):
+ raise ValueError(
+ "Length mismatch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
+ f"but found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatstep length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
+ ) + + if attention_masks_present: + # check for zeros(since token_ids are separated by zero arrays) + batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0] + else: + batch_idx = [token_ids.shape[0]] + + notes_list = [] + pretty_midi_objects_list = [] + start_idx = 0 + for index, end_idx in enumerate(batch_idx): + each_tokens_ids = token_ids[start_idx:end_idx] + # check where the whole example ended by searching for eos_token_id and getting the upper bound + each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1] + beatsteps = feature_extractor_output["beatsteps"][index] + extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index] + + # if attention mask is present then mask out real array/tensor + if attention_masks_present: + attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index] + attention_mask_extrapolated_beatstep = feature_extractor_output[ + "attention_mask_extrapolated_beatstep" + ][index] + beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1] + extrapolated_beatstep = extrapolated_beatstep[ + : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1 + ] + + each_tokens_ids = to_numpy(each_tokens_ids) + beatsteps = to_numpy(beatsteps) + extrapolated_beatstep = to_numpy(extrapolated_beatstep) + + pretty_midi_object = self.relative_batch_tokens_ids_to_midi( + tokens=each_tokens_ids, + beatstep=extrapolated_beatstep, + bars_per_batch=self.num_bars, + cutoff_time_idx=(self.num_bars + 1) * 4, + ) + + for note in pretty_midi_object.instruments[0].notes: + note.start += beatsteps[0] + note.end += beatsteps[0] + notes_list.append(note) + + pretty_midi_objects_list.append(pretty_midi_object) + start_idx += end_idx + 1 # 1 represents the zero array + + if return_midi: + return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_objects_list}) + + return BatchEncoding({"notes": notes_list}) + + +__all__ = ["Pop2PianoTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/prophetnet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7485dfbb530c3e11683c2b7dceaf35bf9edc9086 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
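+# Note: like other model packages in transformers, this module is resolved lazily — at import time it is
+# replaced by a `_LazyModule`, and the submodules below are only imported on first attribute access.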
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_prophetnet import *
+ from .modeling_prophetnet import *
+ from .tokenization_prophetnet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/prophetnet/configuration_prophetnet.py b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/configuration_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6da708da4c500bf6e8411031fcb3abf0537b78d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/configuration_prophetnet.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ProphetNet model configuration"""
+
+from typing import Callable, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ProphetNetConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ProphetNetModel`]. It is used to instantiate a
+ ProphetNet model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the ProphetNet
+ [microsoft/prophetnet-large-uncased](https://huggingface.co/microsoft/prophetnet-large-uncased) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the ProphetNet model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`ProphetNetModel`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ num_encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ num_encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the `intermediate` (often named feed-forward) layer in decoder.
+ num_decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ num_decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ add_cross_attention (`bool`, *optional*, defaults to `True`):
+ Whether cross-attention layers should be added to the model.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether this is an encoder/decoder model.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ ngram (`int`, *optional*, defaults to 2):
+ Number of future tokens to predict. Set to 1 to be the same as a traditional language model, predicting
+ only the next token.
+ num_buckets (`int`, *optional*, defaults to 32):
+ The number of buckets to use for each attention layer. This is for relative position calculation. See the
+ [T5 paper](https://huggingface.co/papers/1910.10683) for more details.
+ relative_max_distance (`int`, *optional*, defaults to 128):
+ Relative distances greater than this number will be put into the same last bucket. This is for relative
+ position calculation. See the [T5 paper](https://huggingface.co/papers/1910.10683) for more details.
+ disable_ngram_loss (`bool`, *optional*, defaults to `False`):
+ Whether the model should be trained to predict only the next first token.
+ eps (`float`, *optional*, defaults to 0.0):
+ Controls the `epsilon` parameter value for label smoothing in the loss calculation. If set to 0, no label
+ smoothing is performed.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
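+
+ Example (editor's sketch of the standard config workflow):
+
+ ```python
+ >>> from transformers import ProphetNetConfig, ProphetNetModel
+
+ >>> # Initializing a ProphetNet microsoft/prophetnet-large-uncased style configuration
+ >>> configuration = ProphetNetConfig()
+
+ >>> # Initializing a model from that configuration
+ >>> model = ProphetNetModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```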
+ """ + + model_type = "prophetnet" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "num_encoder_attention_heads", + } + + def __init__( + self, + activation_dropout: Optional[float] = 0.1, + activation_function: Optional[Union[str, Callable]] = "gelu", + vocab_size: Optional[int] = 30522, + hidden_size: Optional[int] = 1024, + encoder_ffn_dim: Optional[int] = 4096, + num_encoder_layers: Optional[int] = 12, + num_encoder_attention_heads: Optional[int] = 16, + decoder_ffn_dim: Optional[int] = 4096, + num_decoder_layers: Optional[int] = 12, + num_decoder_attention_heads: Optional[int] = 16, + attention_dropout: Optional[float] = 0.1, + dropout: Optional[float] = 0.1, + max_position_embeddings: Optional[int] = 512, + init_std: Optional[float] = 0.02, + is_encoder_decoder: Optional[bool] = True, + add_cross_attention: Optional[bool] = True, + decoder_start_token_id: Optional[int] = 0, + ngram: Optional[int] = 2, + num_buckets: Optional[int] = 32, + relative_max_distance: Optional[int] = 128, + disable_ngram_loss: Optional[bool] = False, + eps: Optional[float] = 0.0, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.encoder_ffn_dim = encoder_ffn_dim + self.num_encoder_layers = num_encoder_layers + self.num_encoder_attention_heads = num_encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.num_decoder_layers = num_decoder_layers + self.num_decoder_attention_heads = num_decoder_attention_heads + self.max_position_embeddings = max_position_embeddings + self.init_std = init_std # Normal(0, this parameter) + self.activation_function = activation_function + + # parameters for prophetnet + self.ngram = ngram + self.num_buckets = num_buckets + self.relative_max_distance = relative_max_distance + self.disable_ngram_loss = disable_ngram_loss + self.eps = eps + + # 3 Types of Dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.dropout = dropout + + self.use_cache = use_cache + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + add_cross_attention=add_cross_attention, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + @property + def num_hidden_layers(self) -> int: + return self.num_encoder_layers + + @num_hidden_layers.setter + def num_hidden_layers(self, value): + raise NotImplementedError( + "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and" + " `num_decoder_layers`." + ) + + +__all__ = ["ProphetNetConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/prophetnet/modeling_prophetnet.py b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/modeling_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fb78c1e505f2b25089ac1f4ce65e9be64174687f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/modeling_prophetnet.py @@ -0,0 +1,2056 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ProphetNet model, ported from the ProphetNet repo (fairseq version)."""
+
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import LayerNorm
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_prophetnet import ProphetNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def softmax(hidden_state, dim, onnx_trace=False):
+ if onnx_trace:
+ return nn.functional.softmax(hidden_state.float(), dim=dim)
+ else:
+ return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
+
+
+def ngram_attention_bias(sequence_length, ngram, device, dtype):
+ """
+ This function computes the bias for the predict stream
+ """
+ left_block = (
+ torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
+ )
+ right_block = left_block.detach().clone()
+ # create bias
+ for stream_idx in range(ngram):
+ right_block[stream_idx].fill_diagonal_(0, wrap=False)
+ left_block[stream_idx].triu_(-stream_idx + 1)
+
+ left_block[:, :, 0] = 0
+ return torch.cat([left_block, right_block], dim=2)
+
+
+def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
+ """
+ This function computes individual parts of the relative position buckets. For more details, see the paper.
+ """
+ inv_relative_positions = -relative_positions
+ rel_positions_bucket = 0
+
+ if is_bidirectional:
+ num_buckets = num_buckets // 2
+ rel_positions_bucket = (
+ rel_positions_bucket
+ + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
+ )
+ inv_relative_positions = torch.abs(inv_relative_positions)
+ else:
+ inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = torch.lt(inv_relative_positions, max_exact)
+ val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
+ max_distance / max_exact
+ ) * (num_buckets - max_exact)
+ val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
+ rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
+ return rel_positions_bucket
+
+
+def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
+ """
+ This function computes both main and predict relative position buckets. For more details, see the paper.
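+
+ Example (editor's sketch of the expected shapes; values depend on the bucketing above):
+
+ ```python
+ >>> import torch
+ >>> position_ids = torch.arange(1, 5).unsqueeze(0)  # shape [1, 4]
+ >>> main, predict = compute_all_stream_relative_buckets(32, 128, position_ids)
+ >>> main.shape, predict.shape  # predict covers both the shifted and unshifted positions
+ (torch.Size([1, 4, 4]), torch.Size([1, 4, 8]))
+ ```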
+ """ + # main stream + main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1) + main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1) + + # predicting stream + predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1) + predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1) + predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1) + + # get both position buckets + main_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False + ) + predict_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False + ) + return main_relative_position_buckets, predict_relative_position_buckets + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for sequence-to-sequence language models outputs. + """ +) +class ProphetNetSeq2SeqLMOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`): + Prediction scores of the main stream language modeling head (scores for each vocabulary token before + SoftMax). + logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Prediction scores of the predict stream language modeling head (scores for each vocabulary token before + SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + logits_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_ngram_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + decoder_ngram_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`" + " instead.", + FutureWarning, + ) + return self.cross_attentions + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + """ +) +class ProphetNetSeq2SeqModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`): + Sequence of main stream hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*): + Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ """ + + last_hidden_state: torch.FloatTensor + last_hidden_state_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_ngram_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + decoder_ngram_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`" + " instead.", + FutureWarning, + ) + return self.cross_attentions + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + """ +) +class ProphetNetDecoderModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`): + Sequence of main stream hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. 
+
+ Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+ weighted average in the self-attention heads.
+ """
+
+ last_hidden_state: torch.FloatTensor
+ last_hidden_state_ngram: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class ProphetNetDecoderLMOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
+ Prediction scores of the main stream language modeling head (scores for each vocabulary token before
+ SoftMax).
+ logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
+ Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
+ SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+ Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+ outputs.
+ ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ decoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+ weighted average in the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ logits_ngram: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class ProphetNetPreTrainedModel(PreTrainedModel):
+ config: ProphetNetConfig
+ base_model_prefix = "prophetnet"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _shift_right(self, input_ids):
+ decoder_start_token_id = self.config.decoder_start_token_id
+ pad_token_id = self.config.pad_token_id
+
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
+ " pad_token_id. See ProphetNet docs for more information"
+ )
+
+ # shift inputs to the right
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
+
+ return shifted_input_ids
+
+
+class ProphetNetPositionalEmbeddings(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
+ based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
+ the forward function.
+ """
+
+ def __init__(self, config: ProphetNetConfig) -> None:
+ self.max_length = config.max_position_embeddings
+ super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
+
+ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
+ assert (position_ids is None) or (self.padding_idx is None), (
+ "If position_ids is pre-computed then padding_idx should not be set."
+ )
+
+ if position_ids is None:
+ if past_key_values is not None and past_key_values.get_seq_length() != 0:
+ # position_ids is the same for every token when decoding a single step
+ # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+ prev_num_input_ids = past_key_values.get_seq_length()
+ num_input_ids = inputs_shape[1] + prev_num_input_ids
+ position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
+ int(self.padding_idx + num_input_ids)
+ )
+ else:
+ if attention_mask is None:
+ attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
+
+ # retrieve position_ids from input_ids / attention_mask
+ position_ids = (
+ torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
+ ).long() + self.padding_idx
+
+ # make sure position_ids are not bigger than max_length
+ position_ids = position_ids.clamp(0, self.max_length - 1)
+
+ return super().forward(position_ids), position_ids
+
+ def _forward(self, position_ids):
+ return super().forward(position_ids)
+
+
+class ProphetNetAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: Optional[int] = None):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.attention_dropout = config.attention_dropout
+ self.dropout = config.dropout
+ self.num_attn_heads = num_attn_heads
+ self.head_dim = hidden_size // num_attn_heads
+ self.layer_idx = layer_idx
+
+ assert self.head_dim * num_attn_heads == hidden_size, (
+ "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
+ " `config.num_decoder_attention_heads`"
+ )
+
+ self.key_proj = nn.Linear(hidden_size, hidden_size)
+ self.value_proj = nn.Linear(hidden_size, hidden_size)
+ self.query_proj = nn.Linear(hidden_size, hidden_size)
+
+ self.out_proj = nn.Linear(hidden_size, hidden_size)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states,
+ key_value_states: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ layer_head_mask: Optional[Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[Tensor, Optional[Tensor]]:
+ batch_size, tgt_len, hidden_size = hidden_states.size()
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ assert list(hidden_states.size()) == [
+ batch_size,
+ tgt_len,
+ hidden_size,
+ ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
+
+ # previous time steps are cached - no need to recompute key and value if they are static
+ query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and
is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.key_proj(current_states) + value_states = self.value_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2) + src_len = key_states.size(2) + + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) + if attn_weights.size() != expected_shape: + raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}") + + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. + if attention_mask is not None and attention_mask.dim() == 0: + attention_mask = None + + expected_shape = (batch_size, self.num_attn_heads, 1, src_len) + if attention_mask is not None and attention_mask.size() != expected_shape: + raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}") + if attention_mask is not None: # don't attend to padding symbols + attn_weights = attn_weights + attention_mask + if output_attentions: + attn_weights_reshaped = attn_weights + else: + attn_weights_reshaped = None + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + batch_size, self.num_attn_heads, tgt_len, src_len + ) + + # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model + attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped + + attn_probs = nn.functional.dropout( + attn_weights, + p=self.attention_dropout, + training=self.training, + ) + attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim) + if attn_output.size() != expected_shape: + raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size) + attn_output = self.out_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) + return attn_output, attn_weights_reshaped + + +class ProphetNetFeedForward(nn.Module): + """ + This is the residual two feed-forward layer block based on the 
original Transformer implementation. + """ + + def __init__(self, config: ProphetNetConfig, ffn_dim: int): + super().__init__() + self.activation_fn = ACT2FN[config.activation_function] + self.intermediate = nn.Linear(config.hidden_size, ffn_dim) + self.output = nn.Linear(ffn_dim, config.hidden_size) + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states): + hidden_states = self.intermediate(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.output(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ProphetNetNgramSelfAttention(nn.Module): + def __init__(self, config: ProphetNetConfig, layer_idx=None): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_buckets = config.num_buckets + self.relative_max_distance = config.relative_max_distance + self.num_attn_heads = config.num_decoder_attention_heads + self.dropout = config.dropout + self.attention_dropout = config.attention_dropout + self.head_dim = config.hidden_size // self.num_attn_heads + self.ngram = config.ngram + self.layer_idx = layer_idx + + assert self.head_dim * self.num_attn_heads == config.hidden_size, ( + "config.hidden_size must be divisible by num_attn_heads" + ) + # key, value, query projection + self.key_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.value_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.query_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # out projection + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # rel position embeddings + self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads) + + # for onnx runtime + self.onnx_trace = False + + def _shape(self, tensor, seq_len, batch_size): + return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + past_key_values: Optional[Cache] = None, + attention_mask=None, + layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + cache_position=None, + ): + batch_size, ngram_sequence_length, hidden_size = hidden_states.size() + assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( + f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape" + f" {hidden_states.shape}" + ) + + # project + query_states = self.query_proj(hidden_states) + key_states = self.key_proj(hidden_states) + value_states = self.value_proj(hidden_states) + + # normalize + query_states = query_states / (self.head_dim**0.5) + + # reshape + query_states = self._shape(query_states, ngram_sequence_length, batch_size) + key_states = self._shape(key_states, -1, batch_size) + value_states = self._shape(value_states, -1, batch_size) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) + + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + # chunk 
into main stream and predict stream + hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) + query_states_list = query_states.chunk(1 + self.ngram, dim=2) + key_states_list = key_states.chunk(1 + self.ngram, dim=2) + value_states_list = value_states.chunk(1 + self.ngram, dim=2) + + main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] + main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] + main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:] + main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] + + # ProphetNet has two separate attention layers, one for self and one for cross attention + # We need to obtain the self attention only for this module, if `EncoderDecoderCache` + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + main_key_states, main_value_states = curr_past_key_value.update( + main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # get seq_length of main stream only + sequence_length = ngram_sequence_length // (1 + self.ngram) + + # MAIN-STREAM + # main attn weights + # [batch_size, number_heads, sequence_length, head_dimesion] + # x [batch_size, number_heads, head_dimesion, sequence_length] + # -> [batch_size, number_heads, sequence_length, sequence_length] + main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3)) + + # retrieve relative position embeddings for each layer -> see paper for more details + main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( + main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets + ) + + main_attn_weights = main_attn_weights + main_relative_pos_embeddings + + if attention_mask is not None: + main_attn_weights = main_attn_weights + attention_mask + + main_attn_probs = softmax( + main_attn_weights, + dim=-1, + onnx_trace=self.onnx_trace, + ).type_as(main_attn_weights) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view( + batch_size, self.num_attn_heads, -1, sequence_length + ) + + main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) + # project to attn_output + # [batch_size, number_heads, sequence_length, sequence_length] + # x [batch_size, number_heads, sequence_length, head_dimesion] + # -> [batch_size, number_heads, sequence_length, head_dimesion] + main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states) + # reshape so that num_heads dim is merged into last `head_dim` axis + main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size) + main_attn_output = self.out_proj(main_attn_output) + + # PREDICT-STREAM + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_query_states = torch.stack(predict_query_states_list, 1).view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim + ) + + # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + predict_key_states = 
torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1) + + # [batch_size, sequence_length, ngram, hidden_size] + predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2) + + # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion] + predict_value_states = torch.cat( + [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2 + ) + + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states)) + + # retrieve relative position embeddings for each layer -> see paper for more details + # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings] + predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings( + predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets + ) + + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings + + if extended_predict_attention_mask is not None: + # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4) + extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype) + predict_attn_weights = predict_attn_weights + extended_predict_attention_mask + + predict_attn_probs = softmax( + predict_attn_weights, + dim=-1, + onnx_trace=self.onnx_trace, + ).type_as(predict_attn_weights) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs + + predict_attn_probs = nn.functional.dropout( + predict_attn_probs, p=self.attention_dropout, training=self.training + ) + # project to attention output + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_attn_output = torch.einsum( + "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2)) + ) + + # reshape so that num_heads dim is merged into last `head_dim` axis + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size] + predict_attn_output = predict_attn_output.transpose(2, 3) + predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size) + predict_attn_output = self.out_proj(predict_attn_output) + + # concat to single attn output + # [batch_size, (1+ngram)*sequence_length, hidden_size] + attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size) + # reshape into better form for `config.output_attentions` + main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1) + + attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) + + return attn_output, main_attn_probs, predict_attn_probs + 
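+    # The two helpers below look up bucketed relative-position embeddings that are added to
+    # the raw attention scores in `forward`: `get_main_relative_pos_embeddings` produces a
+    # [batch_size, num_heads, sequence_length, sequence_length] bias for the main stream, and
+    # `get_predict_relative_pos_embeddings` produces a
+    # [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] bias for the n-gram
+    # predict streams. If the relative position buckets are not precomputed, both helpers
+    # derive them from `position_ids` via `compute_relative_buckets`.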
+ def get_main_relative_pos_embeddings( + self, hidden_states, attn_weights, position_ids, main_relative_position_buckets + ): + # input hidden_states [batch_size, sequence_length, hidden_size] + # input attn_weights [batch_size, num_heads, sequence_length, sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape + attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len) + if main_relative_position_buckets is None: + batch_size, sequence_length = hidden_states.shape[:2] + relative_positions = ( + torch.arange(1, attn_weights.shape[-1] + 1) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, sequence_length, 1) + .to(position_ids.device) + ) + # [batch_size, sequence_length, sequence_length+1] + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) + main_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + # [batch_size, sequence_length, num_buckets * num_heads] + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + rel_pos_embeddings = rel_pos_embeddings.view( + rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2) + # [batch_size, num_heads, sequence_length, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,)) + + main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) + # [batch_size * num_heads * sequence_length, sequence_length] + main_relative_position_buckets = main_relative_position_buckets.view( + -1, main_relative_position_buckets.shape[-1] + ) + main_relative_position_buckets = main_relative_position_buckets.long() + # [batch_size * num_heads * sequence_length, sequence_length] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + + main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets) + main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1) + return main_relative_pos_embeddings + + def get_predict_relative_pos_embeddings( + self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets + ): + # input hidden_states [batch_size, sequence_length, ngram, hidden_size] + # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None + batch_size, sequence_length = hidden_states.shape[0:2] + + if predict_relative_position_buckets is None: + key_sequence_length = attn_weights.shape[-1] + assert position_ids[0][0] == key_sequence_length - 1, ( + "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... 
(key_sequence_length - 1)" + ) + relative_positions = ( + torch.arange(0, key_sequence_length) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, sequence_length, 1) + .to(position_ids.device) + ) + + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) + predict_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + # [batch_size, ngram, sequence_length, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + + # [batch_size, ngram, sequence_length, num_buckets, num_heads] + rel_pos_embeddings = rel_pos_embeddings.view( + hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads) + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3) + # [batch_size * ngram * sequence_length * num_heads, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets) + # [ngram, batch_size, num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0) + predict_relative_position_buckets = predict_relative_position_buckets.repeat( + self.ngram, 1, self.num_attn_heads, 1 + ) + # [ngram * batch_size * num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.view( + -1, predict_relative_position_buckets.size(-1) + ).long() + + predict_relative_pos_embeddings = torch.gather( + rel_pos_embeddings, dim=1, index=predict_relative_position_buckets + ) + + # [batch_size, gram, num_heads, sequence_length, -1] + predict_relative_pos_embeddings = predict_relative_pos_embeddings.view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, -1 + ) + + return predict_relative_pos_embeddings + + +class ProphetNetEncoderLayer(GradientCheckpointingLayer): + """ + Encoder block for Prophetnet + """ + + def __init__(self, config: ProphetNetConfig): + super().__init__() + # 1st residual block + self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) + + # 2nd residual block + self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_head_mask, + output_attentions: bool = False, + ): + # 1st residual block + attention_output, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) + + # 2nd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class ProphetNetDecoderLayer(GradientCheckpointingLayer): + """ + Decoder block for Prophetnet + """ + + def __init__(self, config: ProphetNetConfig, layer_idx=None): + super().__init__() + # 1st residual block + self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) + + # 2nd residual block + if config.add_cross_attention: + self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, 
layer_idx=layer_idx) + self.cross_attn_layer_norm = LayerNorm(config.hidden_size) + + # 3rd residual block + self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attn_mask=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + past_key_values=None, + use_cache: Optional[bool] = True, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ): + # 1st residual block + ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + ) + hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + + cross_attn_weights = None + if encoder_hidden_states is not None: + # 2nd residual block + attention_output, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attn_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + + # 3rd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) + + return outputs + + +@auto_docstring( + custom_intro=""" + The standalone encoder part of the ProphetNetModel. + """ +) +class ProphetNetEncoder(ProphetNetPreTrainedModel): + def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + r""" + word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): + The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word + embeddings instead of randomly initialized word embeddings. 
+ """ + super().__init__(config) + + self.word_embeddings = ( + word_embeddings + if word_embeddings is not None + else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + ) + self.position_embeddings = ProphetNetPositionalEmbeddings(config) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) + + self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetEncoder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("Either input_ids or inputs_embeds has to be passed.") + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("Make sure to only pass input_ids or inputs_embeds.") + elif input_ids is not None and inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # prepare attention mask + if attention_mask is not None: + extended_attention_mask = ( + 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype) + else: + extended_attention_mask = None + + position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device) + + hidden_states = inputs_embeds + position_embeddings + hidden_states = self.embeddings_layer_norm(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training) + + encoder_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+            )
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_hidden_states = encoder_hidden_states + (hidden_states,)
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask=extended_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                output_attentions=output_attentions,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_hidden_states = encoder_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The standalone decoder part of the ProphetNetModel.
+    """
+)
+class ProphetNetDecoder(ProphetNetPreTrainedModel):
+    def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
+        r"""
+        word_embeddings (`torch.nn.Embedding` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
+            The word embedding parameters. This can be used to initialize [`ProphetNetDecoder`] with pre-defined word
+            embeddings instead of randomly initialized word embeddings.
+        """
+        super().__init__(config)
+
+        self.ngram = config.ngram
+        self.num_buckets = config.num_buckets
+        self.relative_max_distance = config.relative_max_distance
+        self.dropout = config.dropout
+        self.max_target_positions = config.max_position_embeddings
+
+        self.word_embeddings = (
+            word_embeddings
+            if word_embeddings is not None
+            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        )
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
+
+        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
+        self.layers = nn.ModuleList(
+            [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)]
+        )
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.word_embeddings = value
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[tuple, ProphetNetDecoderModelOutput]:
+        r"""
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetDecoder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.") + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.") + elif input_ids is not None and inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + main_stream_pos_embed, position_ids = self.position_embeddings( + (batch_size, sequence_length), + device=inputs_embeds.device, + past_key_values=past_key_values, + ) + + if past_key_values_length != 0: + main_relative_position_buckets, predict_relative_position_buckets = None, None + else: + ( + main_relative_position_buckets, + predict_relative_position_buckets, + ) = self.compute_buffered_relative_buckets(position_ids) + predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1) + + # add position embeddings + hidden_states = inputs_embeds + main_stream_pos_embed + + ngram_embeddings = self.ngram_embeddings.weight + + # prepare attention mask + if past_key_values_length != 0: + assert hidden_states.size(1) == 1, ( + "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" + ) + + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1) + for ngram in range(self.ngram) + ] + extended_attention_mask = None + extended_predict_attention_mask = None + else: + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram) + ] + extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) + extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) + + # prepare encoder attention mask + if encoder_attention_mask is not None: + extended_encoder_attention_mask = ( + 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype) + else: + extended_encoder_attention_mask = None + + hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1) + + if self.embeddings_layer_norm: + hidden_states = self.embeddings_layer_norm(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # init attentions, hidden_states and cache with empty tuples + all_main_stream_hidden_states = () if output_hidden_states else None + all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None + + all_main_stream_attns = () if output_attentions else None + all_ngram_stream_attns = () if output_attentions else None + all_cross_attns = () if output_attentions and self.config.add_cross_attention else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + # grad cannot be kept because tensor is sliced + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + layer_outputs = decoder_layer( + hidden_states, + extended_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_main_stream_attns += (layer_outputs[1],) + all_ngram_stream_attns += (layer_outputs[2],) + + if self.config.add_cross_attention: + all_cross_attns += (layer_outputs[3],) + + if output_hidden_states: + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + # split last_hidden_state for return + last_hidden_state = hidden_states[:, :sequence_length] + last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + last_hidden_state_ngram, + past_key_values, + all_main_stream_hidden_states, + all_ngram_stream_hidden_states, + all_main_stream_attns, + all_ngram_stream_attns, + all_cross_attns, + ] + if v is not None + ) + return ProphetNetDecoderModelOutput( + last_hidden_state=last_hidden_state, + last_hidden_state_ngram=last_hidden_state_ngram, + past_key_values=past_key_values, + hidden_states=all_main_stream_hidden_states, + hidden_states_ngram=all_ngram_stream_hidden_states, + attentions=all_main_stream_attns, + ngram_attentions=all_ngram_stream_attns, + cross_attentions=all_cross_attns, + ) + + def compute_buffered_relative_buckets(self, position_ids): + batch_size, sequence_length = position_ids.shape + + position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1) + main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets( + self.num_buckets, self.relative_max_distance, position_ids + ) + + # buffer relative buckets + main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1) + predict_relative_buckets = torch.cat( + [ + predict_relative_buckets[:, :sequence_length, :sequence_length], + predict_relative_buckets[ + :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length + ], + ], + 2, + ).repeat(batch_size, 1, 1) + + return main_relative_buckets, predict_relative_buckets + + def prepare_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + causal_mask = torch.full( + (seq_length, seq_length), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + causal_mask = 
torch.triu(causal_mask, 1) + + extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_causal_mask + extended_attention_mask + else: + extended_attention_mask = extended_causal_mask + return extended_attention_mask.to(hidden_states.dtype) + + def prepare_predict_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + predict_causal_mask = ngram_attention_bias( + self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype + ) + predict_causal_mask = torch.cat( + [ + predict_causal_mask[:, :seq_length, :seq_length], + predict_causal_mask[ + :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length + ], + ], + dim=-1, + ) + extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.expand( + (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length) + ) + # predicted stream attention_mask should always be 0 + extended_attention_mask = torch.cat( + [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1 + ) + extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask + else: + extended_predict_attention_mask = extended_predict_causal_mask + return extended_predict_attention_mask.to(hidden_states.dtype) + + +@auto_docstring +class ProphetNetModel(ProphetNetPreTrainedModel): + _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"] + + def __init__(self, config: ProphetNetConfig): + super().__init__(config) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + self.encoder.word_embeddings = self.word_embeddings + self.decoder.word_embeddings = self.word_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings) + self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + 
head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, ProphetNetSeq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetModel + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... 
).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states + >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + return ProphetNetSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram, + decoder_attentions=decoder_outputs.attentions, + decoder_ngram_attentions=decoder_outputs.ngram_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks. 
+ """ +) +class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] + + def __init__(self, config: ProphetNetConfig): + super().__init__(config) + self.prophetnet = ProphetNetModel(config) + self.padding_idx = config.pad_token_id + self.disable_ngram_loss = config.disable_ngram_loss + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head) + + def get_input_embeddings(self): + return self.prophetnet.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, ProphetNetSeq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased") + + >>> input_ids = tokenizer( + ... 
"Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> logits_next_token = outputs.logits # logits to predict next token as usual + >>> logits_ngram_next_tokens = outputs.logits_ngram # logits to predict 2nd, 3rd, ... next tokens + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + outputs = self.prophetnet( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + batch_size, sequence_length = ( + decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2] + ) + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + # To use .view in loss computation, make sure that logits is contiguous. 
+ if not logits.is_contiguous(): + logits = logits.contiguous() + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return ProphetNetSeq2SeqLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states, + decoder_attentions=outputs.decoder_attentions, + decoder_ngram_attentions=outputs.decoder_ngram_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def get_encoder(self): + return self.prophetnet.encoder + + def get_decoder(self): + return self.prophetnet.decoder + + +@auto_docstring( + custom_intro=""" + The standalone decoder part of the ProphetNetModel with a lm head on top. 
The model can be used for causal + """ +) +class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "prophetnet.word_embeddings.weight", + "prophetnet.decoder.word_embeddings.weight", + "lm_head.weight", + ] + + def __init__(self, config: ProphetNetConfig): + # set config for CLM + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.prophetnet = ProphetNetDecoderWrapper(config) + + self.padding_idx = config.pad_token_id + self.disable_ngram_loss = config.disable_ngram_loss + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prophetnet.decoder.word_embeddings + + def set_input_embeddings(self, value): + self.prophetnet.decoder.word_embeddings = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) + + def set_decoder(self, decoder): + self.prophetnet.decoder = decoder + + def get_decoder(self): + return self.prophetnet.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ProphetNetDecoderLMOutput]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased") + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + + >>> # Model can also be used with EncoderDecoder framework + >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer + >>> import torch + + >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") + >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained( + ... 
"google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased" + ... ) + + >>> ARTICLE = ( + ... "the us state department said wednesday it had received no " + ... "formal word from bolivia that it was expelling the us ambassador there " + ... "but said the charges made against him are `` baseless ." + ... ) + >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids + >>> labels = tokenizer_dec( + ... "us rejects charges against its ambassador in bolivia", return_tensors="pt" + ... ).input_ids + >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:]) + + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + outputs = self.prophetnet.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return ProphetNetDecoderLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + hidden_states_ngram=outputs.hidden_states_ngram, + attentions=outputs.attentions, + ngram_attentions=outputs.ngram_attentions, + cross_attentions=outputs.cross_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + use_cache=None, + **kwargs, + ): + # Overwritten -- our tests complain if we use GenerationMixin.prepare_inputs_for_generation + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if 
attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values is not None and past_key_values.get_seq_length() > 0: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + model_inputs = { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "head_mask": head_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + # Prophetnet does not support cache_position + kwargs.pop("cache_position", None) + + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + return model_inputs + + +class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): + """ + This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained prophetnet + classes. + """ + + def __init__(self, config: ProphetNetConfig): + super().__init__(config) + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings()) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +__all__ = [ + "ProphetNetDecoder", + "ProphetNetEncoder", + "ProphetNetForCausalLM", + "ProphetNetForConditionalGeneration", + "ProphetNetModel", + "ProphetNetPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/prophetnet/tokenization_prophetnet.py b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/tokenization_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..24401835c7fc62d916802c90aa7d9fede3e9db42 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/prophetnet/tokenization_prophetnet.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
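+#
+# A quick orientation sketch for the tokenizer defined in this file (illustrative
+# only: the exact word pieces and ids depend on the vocabulary file shipped with a
+# checkpoint such as "microsoft/prophetnet-large-uncased"):
+#
+#   >>> from transformers import ProphetNetTokenizer
+#   >>> tok = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
+#   >>> tok.tokenize("unaffable")                # BasicTokenizer, then WordPiece
+#   ['un', '##aff', '##able']                    # hypothetical split, vocab-dependent
+#   >>> tok("unaffable")["input_ids"][-1] == tok.sep_token_id   # a single [SEP] is appended
+#   True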
+ +import collections +import os +import unicodedata +from collections.abc import Iterable +from typing import Optional + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "prophetnet.tokenizer"} + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
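+        # For instance, with the defaults (lowercasing, accent stripping, punctuation
+        # splitting), a sketch of the end-to-end behaviour of this method:
+        #   BasicTokenizer().tokenize("Héllo, World!") -> ["hello", ",", "world", "!"]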
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. 
This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +class ProphetNetTokenizer(PreTrainedTokenizer): + r""" + Construct a ProphetNetTokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + x_sep_token (`str`, *optional*, defaults to `"[X_SEP]"`): + Special second separator token, which can be generated by [`ProphetNetForConditionalGeneration`]. It is + used to separate bullet-point like sentences in summarization, *e.g.*. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). 
+ strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + # first name has to correspond to main model input name + # to make sure `tokenizer.pad(...)` works correctly + # `ProphetNet` doesn't have `token_type_ids` as argument. + model_input_names: list[str] = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file: str, + do_lower_case: Optional[bool] = True, + do_basic_tokenize: Optional[bool] = True, + never_split: Optional[Iterable] = None, + unk_token: Optional[str] = "[UNK]", + sep_token: Optional[str] = "[SEP]", + x_sep_token: Optional[str] = "[X_SEP]", + pad_token: Optional[str] = "[PAD]", + mask_token: Optional[str] = "[MASK]", + tokenize_chinese_chars: Optional[bool] = True, + strip_accents: Optional[bool] = None, + clean_up_tokenization_spaces: bool = True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + x_sep_token=x_sep_token, + pad_token=pad_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token: str): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index: int): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens: str): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def get_special_tokens_mask( + self, + token_ids_0: list[int], + token_ids_1: Optional[list[int]] = 
None,
+        already_has_special_tokens: Optional[bool] = False,
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+        return (vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+        adding special tokens. A ProphetNet sequence has the following format:
+
+        - single sequence: `X [SEP]`
+        - pair of sequences: `A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return token_ids_0 + [self.sep_token_id]
+        sep = [self.sep_token_id]
+        return token_ids_0 + sep + token_ids_1 + sep
+
+
+__all__ = ["ProphetNetTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e447a94c54c86d07eee6c534a6860a2a5a6d0c3e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2 import * + from .modeling_qwen2 import * + from .tokenization_qwen2 import * + from .tokenization_qwen2_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/configuration_qwen2.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/configuration_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..4d75e25092f4c04ef119682f61b9b32241ffb398 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2/configuration_qwen2.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. 
+ num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+            additional layer afterwards will use SWA (Sliding Window Attention).
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer, either `"full_attention"` or `"sliding_attention"`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> from transformers import Qwen2Model, Qwen2Config
+
+    >>> # Initializing a Qwen2 style configuration
+    >>> configuration = Qwen2Config()
+
+    >>> # Initializing a model from the Qwen2-7B style configuration
+    >>> model = Qwen2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen2`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        layer_types=None,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = 
rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/modeling_qwen2.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/modeling_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..262e03cb1568dad0c9b2d027fa8ebdc36301b88b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2/modeling_qwen2.py @@ -0,0 +1,498 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_qwen2 import Qwen2Config + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return 
torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, 
config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + 
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Qwen2PreTrainedModel(PreTrainedModel): + config: Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen2DecoderLayer, + "attentions": Qwen2Attention, + } + + +class Qwen2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+            }
+            # The sliding window alternating layers are not always activated depending on the config
+            if self.has_sliding_layers:
+                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values if use_cache else None,
+        )
+
+
+@auto_docstring
+class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Qwen2Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+        >>> model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
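+
+        >>> # Illustrative only: enable sampling for non-greedy continuations (output varies per run)
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30, do_sample=True, top_p=0.9)  # doctest: +SKIP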
+ ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen2ForSequenceClassification(GenericForSequenceClassification, Qwen2PreTrainedModel): + pass + + +class Qwen2ForTokenClassification(GenericForTokenClassification, Qwen2PreTrainedModel): + pass + + +class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2RMSNorm", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/modular_qwen2.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/modular_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5ede71b704894c8044e334bfafde359c07d84f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2/modular_qwen2.py @@ -0,0 +1,244 @@ +from typing import Callable, Optional + +import torch +from packaging import version +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from ...utils.import_utils import get_torch_version +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaMLP, + LlamaPreTrainedModel, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..mistral.modeling_mistral import MistralModel +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + + +class Qwen2MLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class Qwen2Attention(LlamaAttention): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__(config, layer_idx) + 
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +if version.parse(get_torch_version()) >= version.parse("2.3.0"): + + class Qwen2RMSNorm(nn.RMSNorm): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + super().__init__(normalized_shape=hidden_size, eps=eps, elementwise_affine=True) + +else: + + @use_kernel_forward_from_hub("RMSNorm") + class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__(config=config, layer_idx=layer_idx) + self.attention_type = config.layer_types[layer_idx] + + +class 
Qwen2PreTrainedModel(LlamaPreTrainedModel): + pass + + +class Qwen2Model(MistralModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class Qwen2ForCausalLM(LlamaForCausalLM): + pass + + +class Qwen2ForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen2ForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2RMSNorm", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..be121adb54423069ff7454678f71a22ca2b76bc1 --- /dev/null +++ 
b/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Qwen2."""
+
+import json
+import os
+import unicodedata
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+
+MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
+
+PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+@lru_cache
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+    whitespace/control characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Qwen2Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import Qwen2Tokenizer
+
+    >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
+    >>> tokenizer("Hello world")["input_ids"]
+    [9707, 1879]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [21927, 1879]
+    ```
+    This is expected.
+
+    You should not use GPT2Tokenizer instead, because the pretokenization rules differ.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.
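An editorial aside, not part of the vendored file: `bytes_to_unicode` above defines a bijection over all 256 byte values, which is what makes byte-level BPE lossless. A minimal round-trip sketch, assuming this module path is importable from the installed `transformers`:

```python
from transformers.models.qwen2.tokenization_qwen2 import bytes_to_unicode

byte_encoder = bytes_to_unicode()  # byte value (int) -> printable unicode char
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "héllo"  # deliberately multi-byte in UTF-8
mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))
restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
assert restored == text
assert len(byte_encoder) == 256  # every byte has a printable stand-in
```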
+    Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token. Not applicable for this tokenizer.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+            tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then
+            `tokenizer.tokenize("<|endoftext|>")` = `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`,
+            then `tokenizer.tokenize("<|endoftext|>")` will give `['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This
+            argument is only supported for `slow` tokenizers for the moment.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        unk_token="<|endoftext|>",
+        bos_token=None,
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        clean_up_tokenization_spaces=False,
+        split_special_tokens=False,
+        **kwargs,
+    ):
+        # Qwen vocab does not contain control tokens; added tokens need to be special
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for
+        # long running processes (esp. for texts in languages that do not use
+        # spaces between words, e.g. Chinese); technically not a memory leak
+        # but appears as one. GPT2Tokenizer has the same problem, so let's be
+        # consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        if kwargs.get("add_prefix_space", False):
+            logger.warning_once(
+                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+            )
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            split_special_tokens=split_special_tokens,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
+
+    def decode(
+        self,
+        token_ids,
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
+        # and cannot be configured elsewhere, but it should default to False for
Qwen2Tokenizer + return super().decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs) + + +__all__ = ["Qwen2Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..dda8123c7d6b255e94c906024e644439ef3f2aea --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2/tokenization_qwen2_fast.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Qwen2.""" + +from typing import Optional + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_qwen2 import Qwen2Tokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "tokenizer_file": "tokenizer.json", +} + + +MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768} + + +class Qwen2TokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level + Byte-Pair-Encoding. 
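Another editorial aside before the fast tokenizer's docstring continues: the merge loop in `Qwen2Tokenizer.bpe` above greedily fuses the adjacent pair with the lowest merge rank until no known merge applies. A self-contained toy version (the two merges are hypothetical; real ranks come from `merges.txt`):

```python
# Toy ranks: merge ('l', 'o') first, then ('lo', 'w'). Hypothetical values.
bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1}
word = tuple("low")  # ('l', 'o', 'w')
while len(word) > 1:
    pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
    best = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
    if best not in bpe_ranks:  # no applicable merge left
        break
    first, second = best
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
print(word)  # ('low',) after two merges
```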
+
+    As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import Qwen2TokenizerFast
+
+    >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
+    >>> tokenizer("Hello world")["input_ids"]
+    [9707, 1879]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [21927, 1879]
+    ```
+    This is expected.
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            Path to the vocabulary file.
+        merges_file (`str`, *optional*):
+            Path to the merges file.
+        tokenizer_file (`str`, *optional*):
+            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead. Not applicable to this tokenizer.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token. Not applicable for this tokenizer.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = Qwen2Tokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        unk_token="<|endoftext|>",
+        bos_token=None,
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        **kwargs,
+    ):
+        # We need to at least pass vocab_file and merges_file to base class
+        # in case a slow tokenizer needs to be initialized; others can be
+        # configured through files.
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token + + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["Qwen2TokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7ddae0da7e1b2c16b95e94e6678f9511cb716f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2_5_omni import * + from .modeling_qwen2_5_omni import * + from .processing_qwen2_5_omni import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd36b7a3c0dda2da0ce85be66fa4d3705ba5b81 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -0,0 +1,1091 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the
+# modular_qwen2_5_omni.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerVision`]. It is used to instantiate a
+    Qwen2.5-VL-style vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Qwen2.5-Omni
+    architecture.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        depth (`int`, *optional*, defaults to 32):
+            Number of layers (depth) in the model.
+        hidden_size (`int`, *optional*, defaults to 3584):
+            The size of the hidden layers.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function used in the model.
+        intermediate_size (`int`, *optional*, defaults to 3420):
+            The size of the MLP (Multi-Layer Perceptron) hidden layer.
+        num_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer.
+        in_channels (`int`, *optional*, defaults to 3):
+            Number of input channels.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size of the patches extracted from the input.
+        spatial_merge_size (`int`, *optional*, defaults to 2):
+            The size used for merging spatial dimensions.
+        temporal_patch_size (`int`, *optional*, defaults to 2):
+            The size used for patches along the temporal dimension.
+        window_size (`int`, *optional*, defaults to 112):
+            The window size used by the windowed-attention blocks.
+        out_hidden_size (`int`, *optional*, defaults to 3584):
+            The output hidden size of the vision encoder.
+        fullatt_block_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+            The indexes of the blocks that use full attention instead of windowed attention.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
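For illustration only (the interleaving rule is an assumption based on the Qwen2.5-VL-style design this encoder follows): the defaults documented above put full attention on four evenly spaced blocks and windowed attention everywhere else.

```python
# Editorial sketch: classify the 32 blocks under the default config values.
depth = 32
fullatt_block_indexes = [7, 15, 23, 31]
kinds = ["full" if i in fullatt_block_indexes else "windowed" for i in range(depth)]
print(kinds.count("full"), "full-attention blocks /", kinds.count("windowed"), "windowed")
# -> 4 full-attention blocks / 28 windowed
```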
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
+    >>> configuration = Qwen2_5OmniVisionEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
+    >>> model = Qwen2_5OmniVisionEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_vision_encoder"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.out_hidden_size = out_hidden_size
+        self.initializer_range = initializer_range
+
+
+class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+    architecture.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of mel features used per input feature. Should correspond to the value used in the
+            `Qwen2_5OmniProcessor` class.
+        encoder_layers (`int`, *optional*, defaults to 32):
+            Number of encoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 20):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        d_model (`int`, *optional*, defaults to 1280):
+            Dimensionality of the layers.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + max_source_positions (`int`, *optional*, defaults to 1500): + The maximum sequence length of log-mel filter-bank features that this model might ever be used with. + n_window (`int`, *optional*, defaults to 100): + The chunk for conv and flash attn in AudioEncoder. + output_dim (`int`, *optional*, defaults to 3584): + The output dimension of AudioEncoder. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder + + >>> # Initializing a Qwen2_5OmniAudioEncoderConfig + >>> configuration = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights) + >>> model = Qwen2_5OmniAudioEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + d_model=1280, + dropout=0, + attention_dropout=0, + activation_function="gelu", + activation_dropout=0, + scale_embedding=False, + initializer_range=0.02, + max_source_positions=1500, + n_window=100, + output_dim=3584, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.num_hidden_layers = encoder_layers + self.initializer_range = initializer_range + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.n_window = n_window + self.output_dim = output_dim + + +class Qwen2_5OmniTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + hidden_size (`int`, *optional*, defaults to 3584): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 28): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 32768):
+            Sliding window attention (SWA) window size. Only applied when `use_sliding_window` is `True`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+            additional layer afterwards will use SWA (Sliding Window Attention).
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'.
+                    The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoder config
+    >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoder config
+    >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+    >>> # Initializing a Qwen2.5OmniThinker configuration
+    >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config)
+
+    >>> # Initializing a model from the Qwen-Omni style configuration
+    >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen25OmniText`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=3584,
+        intermediate_size=18944,
+        num_hidden_layers=28,
+        num_attention_heads=28,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        rope_scaling=None,
+        use_sliding_window=False,
+        sliding_window=32768,
+        max_window_layers=28,
+        layer_types=None,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
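+        # With the defaults above this is grouped-query attention:
+        # 28 query heads share 4 key/value heads, i.e. 7 query heads per group.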
self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + if self.rope_scaling is None: + self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + +class Qwen2_5OmniThinkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_config (`dict`, *optional*): + The config dictionary of the audio backbone. + vision_config (`dict`, *optional*): + The config dictionary of the vision backbone. + text_config (`dict`, *optional*): + The config dictionary of the text backbone. + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. + image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + position_id_per_seconds (`int`, *optional*, defaults to 25): + The increment of position id per second. + seconds_per_chunk (`int`, *optional*, defaults to 2): + The duration in seconds of the chunk of audio and video data. + audio_start_token_id (`int`, *optional*, defaults to 151647): + The audio start token index to encode the audio prompt. + audio_end_token_id (`int`, *optional*, defaults to 151648): + The audio end token index to encode the audio prompt. + user_token_id (`int, *optional*, defaults to 872): + The user token index to encode the user token. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
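Editorial sketch of how the default `layer_types` is derived in `Qwen2_5OmniTextConfig.__init__` above; the window settings here are hypothetical, since the stock defaults disable sliding windows entirely (every layer then resolves to `"full_attention"`):

```python
# Hypothetical: pretend use_sliding_window=True so sliding_window is kept.
num_hidden_layers = 28
max_window_layers = 20  # hypothetical: first 20 layers keep full attention
sliding_window = 32768
layer_types = [
    "sliding_attention" if sliding_window is not None and i >= max_window_layers else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types.count("full_attention"), layer_types.count("sliding_attention"))  # 20 8
```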
+ + Example: + + ```python + >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder config + >>> vision_config = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2_5OmniTextConfig config + >>> text_config = Qwen2_5OmniTextConfig() + + >>> # Initializing a Qwen2.5OmniThinker configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config) + + >>> # Initializing a model from the Qwen-Omni style configuration + >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + sub_configs = { + "audio_config": Qwen2_5OmniAudioEncoderConfig, + "vision_config": Qwen2_5OmniVisionEncoderConfig, + "text_config": Qwen2_5OmniTextConfig, + } + + def __init__( + self, + audio_config=None, + vision_config=None, + text_config=None, + audio_token_index=151646, + image_token_index=151655, + video_token_index=151656, + position_id_per_seconds=25, + seconds_per_chunk=2, + audio_start_token_id=151647, + audio_end_token_id=151648, + user_token_id=872, + initializer_range=0.02, + **kwargs, + ): + self.audio_token_index = audio_token_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.user_token_id = user_token_id + self.position_id_per_seconds = position_id_per_seconds + self.seconds_per_chunk = seconds_per_chunk + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + self.initializer_range = initializer_range + + if isinstance(vision_config, dict): + vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config) + elif vision_config is None: + vision_config = Qwen2_5OmniVisionEncoderConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen2_5OmniAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config = Qwen2_5OmniTextConfig(**text_config) + elif text_config is None: + text_config = Qwen2_5OmniTextConfig() + self.text_config = text_config + + super().__init__(**kwargs) + + +class Qwen2_5OmniTalkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. 
+ image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + vocab_size (`int`, *optional*, defaults to 8448): + Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + tts_text_start_token_id (`int`, *optional*, defaults to 151860): + The tts text start token index to encode the start of tts text. + tts_text_end_token_id (`int`, *optional*, defaults to 151861): + The tts text end token index to encode the end of tts text. + tts_text_pad_token_id (`int`, *optional*, defaults to 151859): + The tts text pad token index to encode the pad of tts text. + tts_codec_start_token_id (`int`, *optional*, defaults to 8293): + The tts codec start token index to encode the start of tts codec. + tts_codec_end_token_id (`int`, *optional*, defaults to 8294): + The tts codec end token index to encode the end of tts codec. + tts_codec_pad_token_id (`int`, *optional*, defaults to 8292): + The tts codec pad token index to encode the pad of tts codec. + tts_codec_mask_token_id (`int`, *optional*, defaults to 8296): + The tts codec mask token index to encode the mask of tts codec. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The tts vision start token index to encode the start of vision. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The tts vision end token index to encode the end of vision. + embedding_size (`int`, *optional*, defaults to 3584): + Dimension of the embedding representations. + hidden_size (`int`, *optional*, defaults to 3584): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 28): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + head_dim (`int`, *optional*, defaults to 128): + The dimension of each attention head. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 32768):
+            Sliding window attention (SWA) window size. Only applied when `use_sliding_window` is `True`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+            additional layer afterwards will use SWA (Sliding Window Attention).
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        position_id_per_seconds (`int`, *optional*, defaults to 25):
+            The increment of position id per second.
+        seconds_per_chunk (`int`, *optional*, defaults to 2):
+            The duration in seconds of the chunk of audio and video data.
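To make the `rope_scaling` schema documented above concrete, two editorial examples: the first mirrors the m-RoPE fallback that `Qwen2_5OmniTextConfig.__init__` installs when `rope_scaling` is `None`; the second is a hypothetical YARN payload (values illustrative only).

```python
# Mirrors the fallback used when rope_scaling is None (see the text config above).
default_mrope = {"rope_type": "default", "mrope_section": [16, 24, 24]}

# Hypothetical YARN payload following the documented schema.
yarn_example = {
    "rope_type": "yarn",
    "factor": 4.0,  # handle roughly 4x the pretraining context
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
}
```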
+ audio_start_token_id (`int`, *optional*, defaults to 151647): + The audio start token index to encode the audio prompt. + audio_end_token_id (`int`, *optional*, defaults to 151648): + The audio end token index to encode the audio prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + layer_types (`list`, *optional*): + Attention pattern for each layer. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2 config + >>> text_config = Qwen2Config() + + >>> # Initializing a Qwen2_5Omni configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, text_config) + + >>> # Initializing a model from the qwen2-audio style configuration + >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_talker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + + def __init__( + self, + audio_token_index=151646, + image_token_index=151655, + video_token_index=151656, + vocab_size=8448, + tts_text_start_token_id=151860, + tts_text_end_token_id=151861, + tts_text_pad_token_id=151859, + tts_codec_start_token_id=8293, + tts_codec_end_token_id=8294, + tts_codec_pad_token_id=8292, + tts_codec_mask_token_id=8296, + vision_start_token_id=151652, + vision_end_token_id=151653, + embedding_size=3584, + hidden_size=3584, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + rms_norm_eps=1e-06, + head_dim=128, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=28, + attention_dropout=0.0, + rope_scaling=None, + position_id_per_seconds=25, + seconds_per_chunk=2, + audio_start_token_id=151647, + audio_end_token_id=151648, + initializer_range=0.02, + spatial_merge_size=2, + layer_types=None, + **kwargs, + ): + self.audio_token_index = audio_token_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + + self.tts_text_start_token_id = tts_text_start_token_id + self.tts_text_end_token_id = tts_text_end_token_id + self.tts_text_pad_token_id = tts_text_pad_token_id + self.tts_codec_start_token_id = tts_codec_start_token_id + self.tts_codec_end_token_id = tts_codec_end_token_id + self.tts_codec_pad_token_id = tts_codec_pad_token_id + + self.tts_codec_mask_token_id = tts_codec_mask_token_id + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + self.vocab_size = vocab_size + self.head_dim = head_dim + self.embedding_size = embedding_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + 
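+        # Keep the window only while sliding-window attention is enabled; once
+        # dropped, the default `layer_types` below resolves every layer to
+        # "full_attention".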
self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.position_id_per_seconds = position_id_per_seconds # zf + self.seconds_per_chunk = seconds_per_chunk # zf + self.audio_start_token_id = audio_start_token_id # zf + self.audio_end_token_id = audio_end_token_id # zf + + self.initializer_range = initializer_range + self.spatial_merge_size = spatial_merge_size + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen2_5OmniDiTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model. + It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + The dimension of the model. + num_hidden_layers (`int`, *optional*, defaults to 22): + The number of transformer blocks in the DiT model. + num_attention_heads (`int`, *optional*, defaults to 16): + The number of attention heads in each transformer block. + ff_mult (`int`, *optional*, defaults to 2): + The multiplier for the feedforward layer in each transformer block. + emb_dim (`int`, *optional*, defaults to 512): + The dimension of the embedding layer. + head_dim (`int`, *optional*, defaults to 64): + The dimension of each attention head. + repeats (`int`, *optional*, defaults to 2): + The number of times the codec embeddings are repeated. + num_embeds (`int`, *optional*, defaults to 8193): + The number of unique embeddings in the codec. + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + dropout (`float`, *optional*, defaults to 0.1): + The dropout rate for the transformer blocks. + + enc_emb_dim (`int`, *optional*, defaults to 192): + The dimension of the pre-trained speaker embedding. + enc_dim (`int`, *optional*, defaults to 128): + The dimension of the encoder output. + enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`): + A list of output channels for each TDNN/SERes2Net layer in the encoder. + enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`): + A list of kernel sizes for each layer in the encoder. + enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`): + A list of dilations for each layer in the encoder. + enc_attention_channels (`int`, *optional*, defaults to 64): + The number of attention channels in the SqueezeExcitationBlock. + enc_res2net_scale (`int`, *optional*, defaults to 2): + The scale of the Res2Net block in the encoder. + enc_se_channels (`int`, *optional*, defaults to 64): + The number of output channels after squeeze in the SqueezeExcitationBlock. 
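The `layer_types` default built in the `__init__` above assigns full attention to the first `max_window_layers` layers and sliding window attention to the rest; a small sketch with hypothetical values (not the checkpoint defaults):

```python
# Hypothetical values: 28 layers, sliding window enabled, first 20 layers use full attention.
num_hidden_layers, max_window_layers, sliding_window = 28, 20, 32768

layer_types = [
    "sliding_attention" if sliding_window is not None and i >= max_window_layers else "full_attention"
    for i in range(num_hidden_layers)
]
# -> 20 x "full_attention" followed by 8 x "sliding_attention"
```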
+ """ + + model_type = "qwen2_5_omni_dit" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=22, + num_attention_heads=16, + ff_mult=2, + emb_dim=512, + head_dim=64, + rope_theta=10000.0, + max_position_embeddings=32768, + block_size=24, + look_ahead_layers=[10], + look_backward_layers=[0, 20], + repeats=2, + num_embeds=8193, + mel_dim=80, + dropout=0.1, + enc_emb_dim=192, + enc_dim=128, + enc_channels=[256, 256, 256, 256, 768], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=64, + enc_res2net_scale=2, + enc_se_channels=64, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.ff_mult = ff_mult + self.emb_dim = emb_dim + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.look_ahead_layers = look_ahead_layers + self.look_backward_layers = look_backward_layers + self.repeats = repeats + self.num_embeds = num_embeds + self.mel_dim = mel_dim + self.dropout = dropout + self.enc_emb_dim = enc_emb_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + super().__init__(**kwargs) + + +class Qwen2_5OmniBigVGANConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model. + It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms. + + Args: + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 1536): + The number of channels in the initial upsampling layer. + resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`): + A list of kernel sizes for each residual block. + resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A list of dilation sizes for each residual block. + upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`): + A list of upsampling rates for each upsampling layer. + upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`): + A list of kernel sizes for each upsampling layer. + """ + + model_type = "qwen2_5_omni_bigvgan" + + def __init__( + self, + mel_dim=80, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + **kwargs, + ): + self.mel_dim = mel_dim + self.upsample_initial_channel = upsample_initial_channel + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + super().__init__(**kwargs) + + +class Qwen2_5OmniToken2WavConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`]. + It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. 
The configuration contains sub-configurations for both components.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        dit_config ([`Qwen2_5OmniDiTConfig`], *optional*):
+            Configuration for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
+        bigvgan_config ([`Qwen2_5OmniBigVGANConfig`], *optional*):
+            Configuration for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniToken2WavConfig, Qwen2_5OmniToken2WavModel
+
+    >>> # Initializing the DiT sub-configuration
+    >>> dit_config = {
+    ...     "hidden_size": 1024,
+    ...     "num_hidden_layers": 22,
+    ...     "num_attention_heads": 16,
+    ...     "ff_mult": 2,
+    ... }
+
+    >>> # Initializing the BigVGAN sub-configuration
+    >>> bigvgan_config = {
+    ...     "mel_dim": 80,
+    ...     "upsample_rates": [5, 3, 2, 2, 2, 2],
+    ... }
+
+    >>> # Initializing the main configuration
+    >>> config = Qwen2_5OmniToken2WavConfig(dit_config, bigvgan_config)
+
+    >>> # Initializing a model with the config
+    >>> model = Qwen2_5OmniToken2WavModel(config)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni_token2wav"
+    sub_configs = {
+        "dit_config": Qwen2_5OmniDiTConfig,
+        "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+    }
+
+    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+        if dit_config is None:
+            dit_config = {}
+        if bigvgan_config is None:
+            bigvgan_config = {}
+        self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+        self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+        super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+    model according to the specified sub-model configurations, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+        talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+        token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+        enable_audio_output (`bool`, *optional*, defaults to `True`): Whether to enable audio output and load the talker and token2wav modules.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     Qwen2_5OmniThinkerConfig,
+    ...     Qwen2_5OmniTalkerConfig,
+    ...     Qwen2_5OmniToken2WavConfig,
+    ...     Qwen2_5OmniForConditionalGeneration,
+    ...     Qwen2_5OmniConfig,
+    ... )
+
+    >>> # Initializing sub-module configurations.
+    >>> thinker_config = Qwen2_5OmniThinkerConfig()
+    >>> talker_config = Qwen2_5OmniTalkerConfig()
+    >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+    >>> # Initializing a Qwen2_5Omni configuration from the sub-model configurations
+    >>> configuration = Qwen2_5OmniConfig.from_sub_model_configs(
+    ...     thinker_config, talker_config, token2wav_config
+    ...
) + + >>> # Initializing a model (with random weights) + >>> model = Qwen2_5OmniForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen2_5_omni" + sub_configs = { + "thinker_config": Qwen2_5OmniThinkerConfig, + "talker_config": Qwen2_5OmniTalkerConfig, + "token2wav_config": Qwen2_5OmniToken2WavConfig, + } + + def __init__( + self, + thinker_config=None, + talker_config=None, + token2wav_config=None, + enable_audio_output: bool = True, + **kwargs, + ): + if thinker_config is None: + thinker_config = {} + logger.info("thinker_config is None. Initializing thinker model with default values") + + if talker_config is None: + talker_config = {} + logger.info("talker_config is None. Initializing talker model with default values") + + if token2wav_config is None: + token2wav_config = {} + logger.info("token2wav_config is None. Initializing token2wav model with default values") + + self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config) + self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config) + self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config) + self.enable_audio_output = enable_audio_output + + super().__init__(**kwargs) + + def get_text_config(self, *args, **kwargs): + """ + Returns the config that is meant to be used with text IO. On most models, it is the original config instance + itself. On specific composite models, it is under a set of valid names. + + Args: + decoder (`Optional[bool]`, *optional*, defaults to `False`): + If set to `True`, then only search for decoder config names. + """ + # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model + # except for Qwen yet. This has to be generalized if more deeply nested configs are + # added. NOTE: currently method used only by vLLM + return self.thinker_config.get_text_config(*args, **kwargs) + + +__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8910d270bb58a723cfe1766bcc94721950f4f1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -0,0 +1,4020 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_omni.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Parameter + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.hub import cached_file +from ..qwen2.modeling_qwen2 import Qwen2RMSNorm +from .configuration_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoderConfig, + Qwen2_5OmniBigVGANConfig, + Qwen2_5OmniConfig, + Qwen2_5OmniDiTConfig, + Qwen2_5OmniTalkerConfig, + Qwen2_5OmniTextConfig, + Qwen2_5OmniThinkerConfig, + Qwen2_5OmniToken2WavConfig, + Qwen2_5OmniVisionEncoderConfig, +) + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class Qwen2_5OmniPreTrainedModel(PreTrainedModel): + config: Qwen2_5OmniConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): + def _prepare_4d_causal_attention_mask_with_cache_position( + self, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+    def get_llm_pos_ids_for_vision(
+        self,
+        start_idx: int,
+        vision_idx: int,
+        spatial_merge_size: int,
+        t_index: list[int],
+        grid_hs: list[int],
+        grid_ws: list[int],
+    ):
+        llm_pos_ids_list = []
+        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+        t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+        _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+        llm_pos_ids_list.append(_llm_pos_ids + start_idx)  # + 1 ) # 12.09 by malinhan
+        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+        return llm_pos_ids
+
+    def get_chunked_index(
+        self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+    ) -> list[tuple[int, int]]:
+        """
+        Splits a token index list into chunks based on token value ranges.
+
+        Given a list of token indices, returns a list of (start, end) index tuples representing
+        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+        - the first chunk contains token values < 1000,
+        - the second chunk contains values >= 1000 and < 2000, and so on.
+
+        Parameters:
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
+            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+            remove_index (`int`): An index id to subtract from `token_indices` before chunking.
+
+        Returns:
+            `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+            and end (exclusive) indices of a chunk in `token_indices`.
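To make the chunking rule described above concrete, here is a standalone sketch (not the method itself) that reproduces the documented behaviour on a plain Python list:

```python
# Standalone sketch of the documented chunking rule; illustrative only.
def chunk_by_value(token_indices, tokens_per_chunk, remove_index=0):
    chunks, start, current_chunk = [], 0, 1
    for i, value in enumerate(token_indices):
        if value - remove_index >= current_chunk * tokens_per_chunk:
            chunks.append((start, i))
            start, current_chunk = i, current_chunk + 1
    chunks.append((start, len(token_indices)))
    return chunks

print(chunk_by_value([0, 10, 990, 1005, 1500, 2100], 1000))  # [(0, 3), (3, 5), (5, 6)]
```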
+ """ + + def _iter(): + i, start_idx = 0, 0 # skip bos token + current_chunk = 1 + while i < len(token_indices): # skip eos token + if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk: + yield (start_idx, i) + start_idx = i + current_chunk += 1 + i += 1 + yield (start_idx, len(token_indices)) + + return list(_iter()) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + audio_seqlens: Optional[torch.LongTensor] = None, + second_per_grids: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + use_audio_in_video (`bool`, *optional*): + If set to `True`, use the audio in video. + audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id + vision_start_token_id = self.config.vision_start_token_id + audio_start_token_id = self.config.audio_start_token_id + position_id_per_seconds = self.config.position_id_per_seconds + seconds_per_chunk = self.config.seconds_per_chunk + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is not None: + attention_mask = attention_mask == 1 + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i]] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = ( + image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + 
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + video_llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + sub_len = 0 + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + if video_chunk_index is not None: + sub_len += video_chunk_index[1] - video_chunk_index[0] + + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]] + ) + if audio_chunk_index is not None: + sub_len += audio_chunk_index[1] - audio_chunk_index[0] + + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + + if attention_mask is not None: + position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device) + else: + position_ids[..., i, :] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device) + + return position_ids, mrope_position_deltas + else: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + +############################ +# Start Thinker # +############################ + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs. + """ +) +class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2_5OmniAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Qwen2_5OmniAudioEncoderConfig, + ): + super().__init__() + self.embed_dim = config.d_model + self.num_heads = config.encoder_attention_heads + self.dropout = config.attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.config = config + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 + self.is_decoder = False + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = Qwen2_5OmniAudioAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
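The attention module above packs all frames of the concatenated audio sequence into a single pseudo-batch before calling the attention kernel; a quick shape walk-through with made-up sizes:

```python
import torch

# Made-up sizes for illustration: 21 packed audio frames, 8 heads, head_dim 64.
seq_length, num_heads, head_dim = 21, 8, 64
hidden_states = torch.randn(seq_length, num_heads * head_dim)    # (21, 512)

q = hidden_states.reshape(seq_length, num_heads, head_dim)       # (21, 8, 64)
q = q.transpose(0, 1).unsqueeze(0)                               # (1, 8, 21, 64): batch, heads, seq, head_dim
```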
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +@auto_docstring( + custom_intro=""" + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Qwen2_5OmniAudioEncoderLayer`]. + """ +) +class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"] + _supports_sdpa = True + + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.audio_bos_eos_token = nn.Embedding(2, config.output_dim) + self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.ln_post = nn.LayerNorm(config.d_model) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(config.d_model, config.output_dim) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not 
attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + + @auto_docstring + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + **kwargs, + ): + r""" + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( + chunk_list, chunk_lengths, padding_value=0, padding_side="right" + ) + padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) + + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) + hidden_states = padded_embed[padded_mask_after_cnn] + cu_seqlens = torch.cat( + ( + torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0), + ) + ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + token_audio = torch.cat(token_audio_list, dim=0) + return BaseModelOutput(last_hidden_state=token_audio) + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. 
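The encoder forward above keeps variable-length audio chunks packed in one long sequence and tracks their boundaries with cumulative lengths; a minimal sketch of how such `cu_seqlens` boundaries arise (the lengths are made up):

```python
import torch

# Three packed chunks of 7, 5 and 9 frames -> cumulative boundaries [0, 7, 12, 21].
lengths = torch.tensor([7, 5, 9])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lengths.cumsum(0).to(torch.int32)])
print(cu_seqlens)  # tensor([ 0,  7, 12, 21], dtype=torch.int32)

# Each (cu_seqlens[i - 1], cu_seqlens[i]) window attends only within itself, which is what the
# block-diagonal mask built in `_prepare_attention_mask` encodes for non-FlashAttention backends.
```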
+ """ + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5OmniVisionAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.q = nn.Linear(self.dim, self.dim, bias=True) + self.k = nn.Linear(self.dim, self.dim, bias=True) + self.v = nn.Linear(self.dim, self.dim, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.num_key_value_groups = 1 # needed for eager attention + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use 
cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen2_5OmniVisionAttention(config=config) + self.mlp = Qwen2_5OmniMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + 
super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2_5OmniPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniVisionEncoderConfig + _no_split_modules = ["Qwen2_5OmniVisionBlock"] + + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen2_5OmniPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - 
llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Modification here + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + **kwargs, + ) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5OmniRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, 
"rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5Omni has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class Qwen2_5OmniAttention(nn.Module):
+    """
+    Multi-headed attention from the 'Attention Is All You Need' paper, modified to use sliding window attention
+    following Longformer and "Generating Long Sequences with Sparse Transformers".
+    """
+
+    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
+                "`layer_idx` when creating this class."
+            )
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+        self.rope_scaling = config.rope_scaling
+        self.scaling = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+        self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_multimodal_rotary_pos_emb(
+            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+        )
+
+        if
past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + position_ids=position_ids, # pass positions for FA2 + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2MLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = Qwen2_5OmniAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+@auto_docstring
+class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
+    config: Qwen2_5OmniTextConfig
+    _no_split_modules = ["Qwen2_5OmniDecoderLayer"]
+
+    def __init__(self, config: Qwen2_5OmniTextConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self._attn_implementation = config._attn_implementation
+        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if
use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # torch.jit.trace() doesn't support cache objects in the output
+        if use_cache and past_key_values is None and not torch.jit.is_tracing():
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        # the hard-coded `3` is for temporal, height and width.
+        if position_ids is None:
+            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+        elif position_ids.ndim == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
+        # where each dim indicates visual spatial positions for temporal/height/width grids.
+        # There are two scenarios when FA2-like packed masking might be activated.
+        # 1. User specifically passed packed `position_ids` and no attention mask.
+        #    In this case we expect the user to create correct position ids for all 3 grids
+        #    and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
+        # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+        #    are prepared by the model (`get_rope_index`) as a `[4, bs, seq-len]` tensor. Text-only positions are
+        #    prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+        #    text-only positions will cause incorrect mask construction, do not change `prepare_inputs_for_generation`
+        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+            text_position_ids = position_ids[0]
+            position_ids = position_ids[1:]
+        else:
+            # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+            text_position_ids = None
+
+        # It may already have been prepared by e.g.
`generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": text_position_ids,
+            }
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+            }
+            # The sliding window alternating layers are not always activated depending on the config
+            if self.has_sliding_layers:
+                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=text_position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Qwen2.5OmniThinker model, which consists of an audio backbone, a vision encoder and a language model.
+    """
+)
+class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+    config: Qwen2_5OmniThinkerConfig
+    base_model_prefix = "thinker"
+    _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
+    _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]
+
+    def __init__(self, config: Qwen2_5OmniThinkerConfig):
+        super().__init__(config)
+        self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config)
+        self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config)
+        self.vocab_size = config.text_config.vocab_size
+        self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.spatial_merge_size = config.vision_config.spatial_merge_size
+        self.rope_deltas = None
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_video_features(
+        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+    ):
+        """
+        Encodes videos into continuous embeddings that can be forwarded to the language model.
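+
+        For intuition, a small illustrative sketch (the grid values below are assumed, not from a real checkpoint):
+        the encoder flattens each video into `t * h * w` patches and the patch merger then reduces them by
+        `spatial_merge_size**2`, so the number of embeddings handed to the language model per video is roughly:
+
+        ```python
+        import torch
+
+        video_grid_thw = torch.tensor([[2, 24, 32]])  # hypothetical grid: 2 temporal, 24 x 32 spatial patches
+        spatial_merge_size = 2                        # from the vision config
+        tokens_per_video = video_grid_thw.prod(dim=-1) // spatial_merge_size**2
+        print(tokens_per_video)                       # tensor([384])
+        ```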
+ + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + ): + """ + Encodes audios into continuous embeddings that can be forwarded to the language model. + + Args: + input_features (`torch.FloatTensor`): + The tensors corresponding to the input audios. + feature_attention_mask (`torch.LongTensor`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + """ + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + + return audio_features + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
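+
+        A minimal sketch of how the returned masks are consumed (the token id and shapes below are assumed for
+        illustration only):
+
+        ```python
+        import torch
+
+        IMAGE_TOKEN_ID = 151655                     # hypothetical placeholder id
+        input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
+        inputs_embeds = torch.zeros(1, 4, 8)
+        image_features = torch.ones(2, 8)           # one row per image placeholder token
+        mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
+        inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)
+        ```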
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask, special_video_mask, special_audio_mask + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + feature_attention_mask: Optional[torch.Tensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_audio_in_video: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): + Mask to avoid performing attention on padding feature indices. 
Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+            The length of feature shape of each audio in LLM.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_audio_in_video (`bool`, *optional*):
+            Whether or not to use the audio track of the video inputs; should be the same as the parameter used in
+            `process_audio_info`.
+        video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+            Number of seconds per grid for each video, used for temporal feature mapping.
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from qwen_vl_utils import process_vision_info
+        >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+        >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+        >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+        >>> conversations = [
+        ...     {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+        ...     {"role": "user", "content": [
+        ...         {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+        ...         {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+        ...     ]},
+        ... ]
+
+        >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+        >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+        >>> images, videos = process_vision_info(conversations)
+        >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+        >>> # Generate
+        >>> inputs['use_audio_in_video'] = False  # or True, matching how the inputs were processed
+        >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+        >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+        >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # 2. Merge the text, audio, image and video features
+            if input_features is not None:
+                audio_features = self.get_audio_features(
+                    input_features,
+                    feature_attention_mask=feature_attention_mask,
+                    audio_feature_lengths=audio_feature_lengths,
+                )
+                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+                inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+            if pixel_values is not None:
+                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                image_mask, _, _ = self.get_placeholder_mask(
+                    input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+                )
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                _, video_mask, _ = self.get_placeholder_mask(
+                    input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+                )
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+        else:
+            audio_feature_lengths = None
+
+        if attention_mask is not None and position_ids is None:
+            if (
+                cache_position is None
+                or (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+            ):
+                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask,
+                    use_audio_in_video,
+                    audio_feature_lengths,
+                    video_second_per_grid,
+                )
+                rope_deltas = rope_deltas - delta0
+                self.rope_deltas = rope_deltas
+            else:
+                batch_size, seq_length = input_ids.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=input_ids.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        outputs = self.model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs
+            return (loss,) + output if loss is not None else output
+
+        return Qwen2_5OmniThinkerCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        input_features=None,
+        feature_attention_mask=None,
+        use_audio_in_video=False,
video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + model_inputs["input_features"] = None + + return model_inputs + + +############################ +# Start Talker # +############################ + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs. + """ +) +class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states from the thinker model that are used as input for the talker model. These represent the encoded + response that the talker model will use to generate speech tokens. 
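+
+    A sketch of how the talker consumes `thinker_reply_part` one frame per decoding step (shapes in the comments
+    are assumed for illustration; this mirrors the addition done in the talker forward pass below):
+
+    ```python
+    # codec_embeds:        (batch, 1, hidden)   - embedding of the last generated codec token
+    # thinker_reply_part:  (batch, remaining, hidden)
+    inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :]
+    thinker_reply_part = thinker_reply_part[:, 1:, :]  # one frame consumed
+    ```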
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + thinker_reply_part: Optional[torch.FloatTensor] = None + + +@auto_docstring +class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniTalkerConfig + _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # NOTE: we need to pass text position ids for packing. 
Qwen2-VL uses 3D positions
+        # where each dim indicates visual spatial positions for temporal/height/width grids.
+        # There are two scenarios when FA2-like packed masking might be activated.
+        # 1. User specifically passed packed `position_ids` and no attention mask.
+        #    In this case we expect the user to create correct position ids for all 3 grids
+        #    and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
+        # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+        #    are prepared by the model (`get_rope_index`) as a `[4, bs, seq-len]` tensor. Text-only positions are
+        #    prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+        #    text-only positions will cause incorrect mask construction, do not change `prepare_inputs_for_generation`
+        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+            text_position_ids = position_ids[0]
+            position_ids = position_ids[1:]
+        else:
+            # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+            text_position_ids = None
+
+        # It may already have been prepared by e.g. `generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": text_position_ids,
+            }
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+            }
+            # The sliding window alternating layers are not always activated depending on the config
+            if self.has_sliding_layers:
+                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=text_position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+    config: Qwen2_5OmniTalkerConfig
+    base_model_prefix = "talker"
+
+    def __init__(self, config: Qwen2_5OmniTalkerConfig):
+        super().__init__(config)
+
+        self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)
+
+        self.model =
Qwen2_5OmniTalkerModel(config) + self.codebook_size = config.vocab_size + self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False) + + self.codec_bos_token = config.tts_codec_start_token_id + self.codec_eos_token = config.tts_codec_end_token_id + self.codec_pad_token = config.tts_codec_pad_token_id + self.codec_mask_token = config.tts_codec_mask_token_id + + self.text_bos_token = config.tts_text_start_token_id + self.text_eos_token = config.tts_text_end_token_id + self.text_pad_token = config.tts_text_pad_token_id + + self.spatial_merge_size = self.config.spatial_merge_size + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + thinker_reply_part: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + input_text_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_audio_in_video: Optional[bool] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]: + r""" + thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Hidden states from the thinker model's output that represent the text reply part to be processed. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + input_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Input token IDs for text-only content, used for position calculation in multimodal contexts. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + use_audio_in_video (`bool`, *optional*): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*): + Number of seconds per grid for each video, used for temporal feature mapping. 
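+
+        For intuition, an illustrative sketch of how `rope_deltas` is used during cached decoding (the shapes in
+        the comments are assumed):
+
+        ```python
+        # Decode step: text positions continue from the multimodal prompt positions.
+        # delta = cache_position[0] + rope_deltas           # per-sample offset, shape (batch,)
+        # position_ids = torch.arange(seq_length) + delta   # then expanded to (3, batch, seq_length)
+        ```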
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration
+
+        >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+        >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+        >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+        >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+        >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Generate the caption in English: Glass is breaking."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None and position_ids is None:
+            if (
+                cache_position is None
+                or (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+            ):
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_text_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask,
+                    use_audio_in_video,
+                    audio_feature_lengths,
+                    video_second_per_grid,
+                )
+
+                inputs_embeds[:, -1, :] += self.get_input_embeddings()(
+                    torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
+                )
+                inputs_embeds[:, -2, :] += self.get_input_embeddings()(
+                    torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
+                )
+                self.rope_deltas = rope_deltas
+
+            else:
+                batch_size, seq_length = input_ids.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=input_ids.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        if inputs_embeds is None:
+            # 1.
Inference tokens after second token + codec_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :] + if thinker_reply_part.shape[1] > 1: + thinker_reply_part = thinker_reply_part[:, 1:, :] + + talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=talker_lm_input, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.codec_head(hidden_states) + logits = logits.float() + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniTalkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + thinker_reply_part=thinker_reply_part, + ) + + def _get_initial_cache_position(self, seq_length, device, model_kwargs): + # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily + inputs_embeds = model_kwargs.pop("inputs_embeds") + model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs) + model_kwargs["inputs_embeds"] = inputs_embeds + return model_kwargs + + # prepare inputs for talker lm generation + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values, + attention_mask, + inputs_embeds, + cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + if getattr(outputs, "thinker_reply_part", None) is not None: + model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part + + return model_kwargs + + +############################ +# Start Token2Wav # +############################ + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", 
inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. 
+ max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). 
+ """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: Optional[bool] = False, + code_embed_uncond: Optional[bool] = None, + apply_cfg: Optional[bool] = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + 
self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +# Modified from Llama with a different rotate function, will fixed in next release +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... 
d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, 
hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate") + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + 
channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise TypeError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram as input and predict waveform. 
+ """ +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third) + k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + k4 = function(time_end, value_start + 
time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return self._rk4_step( + function, time_start, time_step, time_end, value_start, function_value_start=function_value_start + ), function_value_start + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, time_end, current_value, next_value, time_points[current_index] + ) + current_index += 1 + + current_value = next_value + + return solution + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavDiT model. Which take speech tokens as input and predict mel spectrogram. + """ +) +class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if 
time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. 
+ """ +) +class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniToken2WavConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"] + + def __init__(self, config: Qwen2_5OmniToken2WavConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa." + ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( + config.dit_config, attn_implementation=attn_impl + ) + self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.code2wav_dit_model.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + ) + + waveform = self.code2wav_bigvgan_model(mel_spectrogram) + + return waveform + + +############################ +# Start Qwen2.5Omni # +############################ + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models: + - [`Qwen2_5OmniThinkerForConditionalGeneration`]: + a causal auto-regressive transformer takes text, audio, image, video as input and predict text tokens. + - [`Qwen2_5OmniTalkerForConditionalGeneration`]: + a causal auto-regressive transformer takes thinker hidden states and response as input and predict speech tokens. + - [`Qwen2_5OmniToken2WavModel`]: + a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. 
+ """ +) +class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin): + config: Qwen2_5OmniConfig + _no_split_modules = [ + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavModel", + ] + + def __init__(self, config): + super().__init__(config) + + self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config) + + self.has_talker = config.enable_audio_output + self.speaker_map = {} + if config.enable_audio_output: + self.enable_talker() + self.post_init() + + def enable_talker(self): + self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config) + self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config) + self.token2wav.float() + self.has_talker = True + + def load_speakers(self, path): + check_torch_load_is_safe() + for key, value in torch.load(path, weights_only=True).items(): + self.speaker_map[key] = value + logger.info(f"Speaker {list(self.speaker_map.keys())} loaded") + + def disable_talker(self): + if hasattr(self, "talker"): + del self.talker + if hasattr(self, "token2wav"): + del self.token2wav + self.has_talker = False + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + config=None, + cache_dir=None, + ignore_mismatched_sizes=False, + force_download=False, + local_files_only=False, + token=None, + revision="main", + use_safetensors=None, + weights_only=True, + **kwargs, + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + spk_path = cached_file( + pretrained_model_name_or_path, + "spk_dict.pt", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + if spk_path is None: + raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""") + model.load_speakers(spk_path) + + return model + + @torch.no_grad() + # TODO: raushan, defaults should be saved in generation config + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + speaker: str = "Chelsie", + use_audio_in_video: bool = False, + return_audio: Optional[bool] = None, + thinker_max_new_tokens: int = 1024, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 40, + talker_top_p: float = 0.8, + talker_temperature: float = 0.9, + talker_eos_token_id: list[int] = [8292, 8294], + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + r""" + Generate text response and audio from input. + + Args: + input_ids (`Optional[torch.Tensor]`, *optional*): + Input ids, should obtain from processor. + speaker (`str` , defaults to "Chelsie"): + Which speaker should be used in audio response. + use_audio_in_video (`bool`, defaults to False): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + return_audio (`Optional[bool]`, *optional*): + Whether or not return response in audio format. 
When `return_audio=None`, this parameter is same as `config.enable_audio_output`. + kwargs (*optional*): + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model. + - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the + thinker, talker and token2wav respectively. It has the priority over the keywords without a prefix. + Returns: + When `return_audio=False`: + - **Text** (`torch.Tensor`): Generated text token sequence. + When `return_audio=True`: + - **Text** (`torch.Tensor`): Generated text token sequence. + - **Audio waveform** (`torch.Tensor`): Generated audio waveform. + """ + if speaker not in self.speaker_map: + raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}") + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_talker in config to enable talker." + ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device) + if thinker_kwargs.get("input_features") is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values") is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos") is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device) + thinker_token_embeds = [ + token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids, + torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like(input_ids, fill_value=self.talker.codec_mask_token), + torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device), + torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_token = torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device) + eos_embedding = thinker_embed_tokens(eos_token).to(input_ids.device) + + pad_token = torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device) + pad_embedding = thinker_embed_tokens(pad_token).to(input_ids.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + + talker_attention_mask = None + if 
"attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(input_ids.device) + + talker_result = self.talker.generate( + input_ids=talker_input_ids, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + ) + talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # 3. Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + wav = self.token2wav( + talker_generate_codes.to(input_ids.device), + conditioning=speaker_params["cond"].to(input_ids.device).float(), + reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + +__all__ = [ + "Qwen2_5OmniForConditionalGeneration", + "Qwen2_5OmniThinkerTextModel", + "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2_5OmniTalkerModel", + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + "Qwen2_5OmniToken2WavModel", + "Qwen2_5OmniPreTrainedModel", + "Qwen2_5OmniPreTrainedModelForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..b63c301f36c3f1c9327ed8e73dcc946113a163bb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -0,0 +1,4322 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2.5Omni model (Audio, Image, Video).""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Parameter + +from transformers.models.llama.modeling_llama import rotate_half +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLMLP, + Qwen2_5_VLPreTrainedModel, + Qwen2_5_VLTextModel, + Qwen2_5_VLVisionBlock, + eager_attention_forward, +) +from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig +from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding + +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + check_torch_load_is_safe, + logging, +) +from ...utils.hub import cached_file + + +logger = logging.get_logger(__name__) + + +class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerVision`]. It is used to instantiate a + Qwen2.5-VL vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2.5-VL + architecture. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + depth (`int`, *optional*, defaults to 32): + Number of layers (depth) in the model. + hidden_size (`int`, *optional*, defaults to 3584): + The size of the hidden layers. + hidden_act (`str`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function used in the model. Supported options include `"quick_gelu"` and others as applicable. + mlp_ratio (`float`, *optional*, defaults to 4): + The ratio used to determine the size of the MLP (Multi-Layer Perceptron) hidden layer. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches extracted from the input. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. 
+ + Example: + + ```python + >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder + + >>> # Initializing a Qwen2_5OmniVisionEncoderConfig + >>> configuration = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights) + >>> model = Qwen2_5OmniVisionEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_vision_encoder" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, + **kwargs, + ): + super().__init__( + depth, + hidden_size, + hidden_act, + intermediate_size, + num_heads, + in_channels, + patch_size, + spatial_merge_size, + temporal_patch_size, + window_size, + out_hidden_size, + fullatt_block_indexes, + initializer_range=initializer_range, + **kwargs, + ) + del self.tokens_per_second + + +class Qwen2_5OmniAudioEncoderConfig(Qwen2AudioEncoderConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a + Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio + architecture. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_mel_bins (`int`, *optional*, defaults to 128): + Number of mel features used per input features. Should correspond to the value used in the + `Qwen2_5OmniProcessor` class. + encoder_layers (`int`, *optional*, defaults to 32): + Number of encoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the Transformer encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 5120): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + d_model (`int`, *optional*, defaults to 1280): + Dimensionality of the layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by diving by sqrt(d_model). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + max_source_positions (`int`, *optional*, defaults to 1500): + The maximum sequence length of log-mel filter-bank features that this model might ever be used with. 
+ n_window (`int`, *optional*, defaults to 100): + The chunk for conv and flash attn in AudioEncoder. + output_dim (`int`, *optional*, defaults to 3584): + The output dimension of AudioEncoder. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder + + >>> # Initializing a Qwen2_5OmniAudioEncoderConfig + >>> configuration = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights) + >>> model = Qwen2_5OmniAudioEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + d_model=1280, + dropout=0, + attention_dropout=0, + activation_function="gelu", + activation_dropout=0, + scale_embedding=False, + initializer_range=0.02, + max_source_positions=1500, + n_window=100, + output_dim=3584, + **kwargs, + ): + super().__init__( + num_mel_bins, + encoder_layers, + encoder_attention_heads, + encoder_ffn_dim, + d_model, + dropout, + attention_dropout, + activation_function, + activation_dropout, + scale_embedding, + initializer_range, + max_source_positions, + **kwargs, + ) + self.n_window = n_window + self.output_dim = output_dim + del self.encoder_layerdrop + + +class Qwen2_5OmniTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + hidden_size (`int`, *optional*, defaults to 3584): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 28): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 32768):
+            Sliding window attention (SWA) window size. Only used when `use_sliding_window` is `True`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+            additional layer afterwards will use SWA (Sliding Window Attention).
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniTextConfig
+
+    >>> # Initializing a Qwen2_5OmniTextConfig with default values
+    >>> configuration = Qwen2_5OmniTextConfig()
+
+    >>> # Accessing the configuration
+    >>> hidden_size = configuration.hidden_size
+    ```"""
+
+    model_type = "qwen2_5_omni_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen2_5OmniText`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=3584,
+        intermediate_size=18944,
+        num_hidden_layers=28,
+        num_attention_heads=28,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        rope_scaling=None,
+        use_sliding_window=False,
+        sliding_window=32768,
+        max_window_layers=28,
+        layer_types=None,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
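+        # For instance, a legacy dict such as {"type": "linear", "factor": 2.0} is
+        # normalized to carry {"rope_type": "linear"} before validation runs below.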
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        if self.rope_scaling is None:
+            self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}
+
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Qwen2_5OmniThinkerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        audio_config (`dict`, *optional*):
+            The config dictionary of the audio backbone.
+        vision_config (`dict`, *optional*):
+            The config dictionary of the vision backbone.
+        text_config (`dict`, *optional*):
+            The config dictionary of the text backbone.
+        audio_token_index (`int`, *optional*, defaults to 151646):
+            The audio token index to encode the audio prompt.
+        image_token_index (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_index (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+        position_id_per_seconds (`int`, *optional*, defaults to 25):
+            The increment of position id per second.
+        seconds_per_chunk (`int`, *optional*, defaults to 2):
+            The duration in seconds of the chunk of audio and video data.
+        audio_start_token_id (`int`, *optional*, defaults to 151647):
+            The audio start token index to encode the audio prompt.
+        audio_end_token_id (`int`, *optional*, defaults to 151648):
+            The audio end token index to encode the audio prompt.
+        user_token_id (`int`, *optional*, defaults to 872):
+            The user token index to encode the user token.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoder config
+    >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoder config
+    >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniTextConfig config
+    >>> text_config = Qwen2_5OmniTextConfig()
+
+    >>> # Initializing a Qwen2.5OmniThinker configuration
+    >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config)
+
+    >>> # Initializing a model from the Qwen-Omni style configuration
+    >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_thinker"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "video_token_id": "video_token_index",
+        "audio_token_id": "audio_token_index",
+    }
+    sub_configs = {
+        "audio_config": Qwen2_5OmniAudioEncoderConfig,
+        "vision_config": Qwen2_5OmniVisionEncoderConfig,
+        "text_config": Qwen2_5OmniTextConfig,
+    }
+
+    def __init__(
+        self,
+        audio_config=None,
+        vision_config=None,
+        text_config=None,
+        audio_token_index=151646,
+        image_token_index=151655,
+        video_token_index=151656,
+        position_id_per_seconds=25,
+        seconds_per_chunk=2,
+        audio_start_token_id=151647,
+        audio_end_token_id=151648,
+        user_token_id=872,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        self.audio_token_index = audio_token_index
+        self.image_token_index = image_token_index
+        self.video_token_index = video_token_index
+        self.user_token_id = user_token_id
+        self.position_id_per_seconds = position_id_per_seconds
+        self.seconds_per_chunk = seconds_per_chunk
+        self.audio_start_token_id = audio_start_token_id
+        self.audio_end_token_id = audio_end_token_id
+        self.initializer_range = initializer_range
+
+        if isinstance(vision_config, dict):
+            vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = Qwen2_5OmniVisionEncoderConfig()
+        self.vision_config = vision_config
+
+        if isinstance(audio_config, dict):
+            audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = Qwen2_5OmniAudioEncoderConfig()
+        self.audio_config = audio_config
+
+        if isinstance(text_config, dict):
+            text_config = Qwen2_5OmniTextConfig(**text_config)
+        elif text_config is None:
+            text_config = Qwen2_5OmniTextConfig()
+        self.text_config = text_config
+
+        super().__init__(**kwargs)
+
+
+class Qwen2_5OmniTalkerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate a
+    Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Talker.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        audio_token_index (`int`, *optional*, defaults to 151646):
+            The audio token index to encode the audio prompt.
+        image_token_index (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_index (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+        vocab_size (`int`, *optional*, defaults to 8448):
+            Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling the model.
+        tts_text_start_token_id (`int`, *optional*, defaults to 151860):
+            The tts text start token index to encode the start of tts text.
+        tts_text_end_token_id (`int`, *optional*, defaults to 151861):
+            The tts text end token index to encode the end of tts text.
+        tts_text_pad_token_id (`int`, *optional*, defaults to 151859):
+            The tts text pad token index to encode the pad of tts text.
+        tts_codec_start_token_id (`int`, *optional*, defaults to 8293):
+            The tts codec start token index to encode the start of tts codec.
+        tts_codec_end_token_id (`int`, *optional*, defaults to 8294):
+            The tts codec end token index to encode the end of tts codec.
+        tts_codec_pad_token_id (`int`, *optional*, defaults to 8292):
+            The tts codec pad token index to encode the pad of tts codec.
+        tts_codec_mask_token_id (`int`, *optional*, defaults to 8296):
+            The tts codec mask token index to encode the mask of tts codec.
+        vision_start_token_id (`int`, *optional*, defaults to 151652):
+            The vision start token index to encode the start of vision.
+        vision_end_token_id (`int`, *optional*, defaults to 151653):
+            The vision end token index to encode the end of vision.
+        embedding_size (`int`, *optional*, defaults to 3584):
+            Dimension of the embedding representations.
+        hidden_size (`int`, *optional*, defaults to 3584):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 18944):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 28):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 28):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        head_dim (`int`, *optional*, defaults to 128):
+            The dimension of each attention head.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 32768):
+            Sliding window attention (SWA) window size. Only used when `use_sliding_window` is `True`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+            additional layer afterwards will use SWA (Sliding Window Attention).
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        position_id_per_seconds (`int`, *optional*, defaults to 25):
+            The increment of position id per second.
+        seconds_per_chunk (`int`, *optional*, defaults to 2):
+            The duration in seconds of the chunk of audio and video data.
+        audio_start_token_id (`int`, *optional*, defaults to 151647):
+            The audio start token index to encode the audio prompt.
+        audio_end_token_id (`int`, *optional*, defaults to 151648):
+            The audio end token index to encode the audio prompt.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        spatial_merge_size (`int`, *optional*, defaults to 2):
+            The size used for merging spatial dimensions.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniTalkerConfig
+
+    >>> # Initializing a Qwen2_5OmniTalker configuration
+    >>> configuration = Qwen2_5OmniTalkerConfig()
+
+    >>> # Initializing a model from the configuration (with random weights)
+    >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_talker"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "video_token_id": "video_token_index",
+        "audio_token_id": "audio_token_index",
+    }
+
+    def __init__(
+        self,
+        audio_token_index=151646,
+        image_token_index=151655,
+        video_token_index=151656,
+        vocab_size=8448,
+        tts_text_start_token_id=151860,
+        tts_text_end_token_id=151861,
+        tts_text_pad_token_id=151859,
+        tts_codec_start_token_id=8293,
+        tts_codec_end_token_id=8294,
+        tts_codec_pad_token_id=8292,
+        tts_codec_mask_token_id=8296,
+        vision_start_token_id=151652,
+        vision_end_token_id=151653,
+        embedding_size=3584,
+        hidden_size=3584,
+        intermediate_size=18944,
+        num_hidden_layers=28,
+        num_attention_heads=28,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        rms_norm_eps=1e-06,
+        head_dim=128,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=32768,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        rope_scaling=None,
+        position_id_per_seconds=25,
+        seconds_per_chunk=2,
+        audio_start_token_id=151647,
+        audio_end_token_id=151648,
+        initializer_range=0.02,
+        spatial_merge_size=2,
+        layer_types=None,
+        **kwargs,
+    ):
+        self.audio_token_index = audio_token_index
+        self.image_token_index = image_token_index
+        self.video_token_index = video_token_index
+
+        self.tts_text_start_token_id = tts_text_start_token_id
+        self.tts_text_end_token_id = tts_text_end_token_id
+        self.tts_text_pad_token_id = tts_text_pad_token_id
+        self.tts_codec_start_token_id = tts_codec_start_token_id
+        self.tts_codec_end_token_id = tts_codec_end_token_id
+        self.tts_codec_pad_token_id = tts_codec_pad_token_id
+
+        self.tts_codec_mask_token_id = tts_codec_mask_token_id
+
+        self.vision_start_token_id = vision_start_token_id
+        self.vision_end_token_id = vision_end_token_id
+
+        self.vocab_size = vocab_size
+        self.head_dim = head_dim
+        self.embedding_size = embedding_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+        self.position_id_per_seconds = position_id_per_seconds
+        self.seconds_per_chunk = seconds_per_chunk
+        self.audio_start_token_id = audio_start_token_id
+        self.audio_end_token_id = audio_end_token_id
+
+        self.initializer_range = initializer_range
+        self.spatial_merge_size = spatial_merge_size
+
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5OmniDiTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model.
+    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            The dimension of the model.
+        num_hidden_layers (`int`, *optional*, defaults to 22):
+            The number of transformer blocks in the DiT model.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of attention heads in each transformer block.
+        ff_mult (`int`, *optional*, defaults to 2):
+            The multiplier for the feedforward layer in each transformer block.
+        emb_dim (`int`, *optional*, defaults to 512):
+            The dimension of the embedding layer.
+        head_dim (`int`, *optional*, defaults to 64):
+            The dimension of each attention head.
+        repeats (`int`, *optional*, defaults to 2):
+            The number of times the codec embeddings are repeated.
+        num_embeds (`int`, *optional*, defaults to 8193):
+            The number of unique embeddings in the codec.
+        mel_dim (`int`, *optional*, defaults to 80):
+            The dimension of the mel-spectrogram.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout rate for the transformer blocks.
+        enc_emb_dim (`int`, *optional*, defaults to 192):
+            The dimension of the pre-trained speaker embedding.
+        enc_dim (`int`, *optional*, defaults to 128):
+            The dimension of the encoder output.
+        enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
+            A list of output channels for each TDNN/SERes2Net layer in the encoder.
+        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
+            A list of kernel sizes for each layer in the encoder.
+        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
+            A list of dilations for each layer in the encoder.
+        enc_attention_channels (`int`, *optional*, defaults to 64):
+            The number of attention channels in the SqueezeExcitationBlock.
+        enc_res2net_scale (`int`, *optional*, defaults to 2):
+            The scale of the Res2Net block in the encoder.
+        enc_se_channels (`int`, *optional*, defaults to 64):
+            The number of output channels after squeeze in the SqueezeExcitationBlock.
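+
+    Example (a minimal, illustrative sketch; the config is constructed directly with its defaults):
+
+    ```python
+    >>> configuration = Qwen2_5OmniDiTConfig()
+
+    >>> # Accessing a field
+    >>> configuration.hidden_size
+    1024
+    ```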
+ """ + + model_type = "qwen2_5_omni_dit" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=22, + num_attention_heads=16, + ff_mult=2, + emb_dim=512, + head_dim=64, + rope_theta=10000.0, + max_position_embeddings=32768, + block_size=24, + look_ahead_layers=[10], + look_backward_layers=[0, 20], + repeats=2, + num_embeds=8193, + mel_dim=80, + dropout=0.1, + enc_emb_dim=192, + enc_dim=128, + enc_channels=[256, 256, 256, 256, 768], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=64, + enc_res2net_scale=2, + enc_se_channels=64, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.ff_mult = ff_mult + self.emb_dim = emb_dim + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.look_ahead_layers = look_ahead_layers + self.look_backward_layers = look_backward_layers + self.repeats = repeats + self.num_embeds = num_embeds + self.mel_dim = mel_dim + self.dropout = dropout + self.enc_emb_dim = enc_emb_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + super().__init__(**kwargs) + + +class Qwen2_5OmniBigVGANConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model. + It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms. + + Args: + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 1536): + The number of channels in the initial upsampling layer. + resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`): + A list of kernel sizes for each residual block. + resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A list of dilation sizes for each residual block. + upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`): + A list of upsampling rates for each upsampling layer. + upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`): + A list of kernel sizes for each upsampling layer. + """ + + model_type = "qwen2_5_omni_bigvgan" + + def __init__( + self, + mel_dim=80, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + **kwargs, + ): + self.mel_dim = mel_dim + self.upsample_initial_channel = upsample_initial_channel + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + super().__init__(**kwargs) + + +class Qwen2_5OmniToken2WavConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`]. + It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. 
The configuration contains sub-configurations for both components.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        dit_config ([`Qwen2_5OmniDiTConfig`], *optional*):
+            Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
+        bigvgan_config ([`Qwen2_5OmniBigVGANConfig`], *optional*):
+            Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniToken2WavModel, Qwen2_5OmniToken2WavConfig, Qwen2_5OmniDiTConfig, Qwen2_5OmniBigVGANConfig
+
+    >>> # Initialize DiT configuration
+    >>> dit_config = Qwen2_5OmniDiTConfig(
+    ...     hidden_size=1024,
+    ...     num_hidden_layers=22,
+    ...     num_attention_heads=16,
+    ...     ff_mult=2
+    ... )
+
+    >>> # Initialize BigVGAN configuration
+    >>> bigvgan_config = Qwen2_5OmniBigVGANConfig(
+    ...     mel_dim=80,
+    ...     upsample_rates=[5, 3, 2, 2, 2, 2]
+    ... )
+
+    >>> # Initialize main configuration (sub-configs are passed as dicts)
+    >>> config = Qwen2_5OmniToken2WavConfig(dit_config.to_dict(), bigvgan_config.to_dict())
+
+    >>> # Initialize model with config
+    >>> model = Qwen2_5OmniToken2WavModel(config)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni_token2wav"
+    sub_configs = {
+        "dit_config": Qwen2_5OmniDiTConfig,
+        "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+    }
+
+    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+        if dit_config is None:
+            dit_config = {}
+        if bigvgan_config is None:
+            bigvgan_config = {}
+        self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+        self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+        super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+    model according to the specified sub-model configurations, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+        talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+        token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+        enable_audio_output (`bool`, *optional*, defaults to `True`): Whether to enable audio output and load the talker and token2wav modules.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     Qwen2_5OmniThinkerConfig,
+    ...     Qwen2_5OmniTalkerConfig,
+    ...     Qwen2_5OmniToken2WavConfig,
+    ...     Qwen2_5OmniForConditionalGeneration,
+    ...     Qwen2_5OmniConfig,
+    ... )
+
+    >>> # Initializing sub-modules configurations.
+    >>> thinker_config = Qwen2_5OmniThinkerConfig()
+    >>> talker_config = Qwen2_5OmniTalkerConfig()
+    >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+    >>> # Initializing a module style configuration
+    >>> configuration = Qwen2_5OmniConfig.from_sub_model_configs(
+    ...     thinker_config, talker_config, token2wav_config
+    ... )
+
+    >>> # Initializing a model (with random weights)
+    >>> model = Qwen2_5OmniForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni"
+    sub_configs = {
+        "thinker_config": Qwen2_5OmniThinkerConfig,
+        "talker_config": Qwen2_5OmniTalkerConfig,
+        "token2wav_config": Qwen2_5OmniToken2WavConfig,
+    }
+
+    def __init__(
+        self,
+        thinker_config=None,
+        talker_config=None,
+        token2wav_config=None,
+        enable_audio_output: bool = True,
+        **kwargs,
+    ):
+        if thinker_config is None:
+            thinker_config = {}
+            logger.info("thinker_config is None. Initializing thinker model with default values")
+
+        if talker_config is None:
+            talker_config = {}
+            logger.info("talker_config is None. Initializing talker model with default values")
+
+        if token2wav_config is None:
+            token2wav_config = {}
+            logger.info("token2wav_config is None. Initializing token2wav model with default values")
+
+        self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config)
+        self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config)
+        self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config)
+        self.enable_audio_output = enable_audio_output
+
+        super().__init__(**kwargs)
+
+    def get_text_config(self, *args, **kwargs):
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: currently method used only by vLLM
+        return self.thinker_config.get_text_config(*args, **kwargs)
+
+
+class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
+    config: Qwen2_5OmniConfig
+    _can_compile_fullgraph = False
+
+
+class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        self,
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        min_dtype: float,
+        cache_position: torch.Tensor,
+        batch_size: int,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, returns it unchanged.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            min_dtype (`float`):
+                The minimum value representable with the dtype `dtype`.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
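+
+        For example (illustrative): decoding a single new token (`sequence_length=1`) against a static cache of
+        length `target_length=8` yields a mask of shape `(batch_size, 1, 1, 8)`.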
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+    def get_llm_pos_ids_for_vision(
+        self,
+        start_idx: int,
+        vision_idx: int,
+        spatial_merge_size: int,
+        t_index: list[int],
+        grid_hs: list[int],
+        grid_ws: list[int],
+    ):
+        llm_pos_ids_list = []
+        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+        t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+        _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+        return llm_pos_ids
+
+    def get_chunked_index(
+        self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+    ) -> list[tuple[int, int]]:
+        """
+        Splits a token index list into chunks based on token value ranges.
+
+        Given a list of token indices, returns a list of (start, end) index tuples representing
+        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+        - the first chunk contains token values < 1000,
+        - the second chunk contains values >= 1000 and < 2000, and so on.
+
+        Parameters:
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
+            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+            remove_index (`int`): An index id to subtract from `token_indices` before chunking.
+
+        Returns:
+            `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+            and end (exclusive) indices of a chunk in `token_indices`.
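+
+        Example (a minimal sketch of the chunking behaviour; `model` stands for any instance of this class):
+
+        ```python
+        >>> import torch
+
+        >>> token_indices = torch.tensor([0, 10, 20, 30, 40, 50])
+        >>> model.get_chunked_index(token_indices, tokens_per_chunk=25, remove_index=0)
+        [(0, 3), (3, 5), (5, 6)]
+        ```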
+        """
+
+        def _iter():
+            i, start_idx = 0, 0  # skip bos token
+            current_chunk = 1
+            while i < len(token_indices):  # skip eos token
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
+                    yield (start_idx, i)
+                    start_idx = i
+                    current_chunk += 1
+                i += 1
+            yield (start_idx, len(token_indices))
+
+        return list(_iter())
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+        audio_seqlens: Optional[torch.LongTensor] = None,
+        second_per_grids: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+        Explanation:
+            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+            and 1D rotary position embedding for text part.
+            Examples:
+                Temporal (Time): 3 patches, representing different segments of the video in time.
+                Height: 2 patches, dividing each frame vertically.
+                Width: 2 patches, dividing each frame horizontally.
+                We also have some important parameters:
+                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [101, 102, 103, 104, 105]
+                text height position_ids: [101, 102, 103, 104, 105]
+                text width position_ids: [101, 102, 103, 104, 105]
+                Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + use_audio_in_video (`bool`, *optional*): + If set to `True`, use the audio in video. + audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id + vision_start_token_id = self.config.vision_start_token_id + audio_start_token_id = self.config.audio_start_token_id + position_id_per_seconds = self.config.position_id_per_seconds + seconds_per_chunk = self.config.seconds_per_chunk + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is not None: + attention_mask = attention_mask == 1 + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i]] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = ( + image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + 
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + video_llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + sub_len = 0 + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + if video_chunk_index is not None: + sub_len += video_chunk_index[1] - video_chunk_index[0] + + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]] + ) + if audio_chunk_index is not None: + sub_len += audio_chunk_index[1] - audio_chunk_index[0] + + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + + if attention_mask is not None: + position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device) + else: + position_ids[..., i, :] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device) + + return position_ids, mrope_position_deltas + else: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + +############################ +# Start Thinker # +############################ + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs. + """ +) +class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5OmniAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Qwen2_5OmniAudioEncoderConfig, + ): + super().__init__() + self.embed_dim = config.d_model + self.num_heads = config.encoder_attention_heads + self.dropout = config.attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.config = config + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 + self.is_decoder = False + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + + 
attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__(config) + self.self_attn = Qwen2_5OmniAudioAttention(config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +@auto_docstring( + custom_intro=""" + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Qwen2_5OmniAudioEncoderLayer`]. 
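+    The encoder processes mel features in windows of `config.n_window * 2` frames, restricts attention to tokens
+    within the same window, average-pools the encoder outputs by a factor of two along time and projects them to
+    `config.output_dim`.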
+ """ +) +class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"] + _supports_sdpa = True + + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.audio_bos_eos_token = nn.Embedding(2, config.output_dim) + self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.ln_post = nn.LayerNorm(config.d_model) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(config.d_model, config.output_dim) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + + @auto_docstring + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + **kwargs, + ): + r""" + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( + chunk_list, chunk_lengths, padding_value=0, padding_side="right" + ) + padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) + + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) + hidden_states = padded_embed[padded_mask_after_cnn] + cu_seqlens = torch.cat( + ( + torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0), + ) + ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + token_audio = torch.cat(token_audio_list, dim=0) + return BaseModelOutput(last_hidden_state=token_audio) + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. 
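+        For example, two chunks of lengths `[100, 60]` with `dim = 128` mel bins yield a padded tensor of shape
+        `(2, 128, 100)`, a padding mask of shape `(2, 1, 100)` and a post-CNN mask of shape `(2, 50)`, since the
+        stride-2 convolution shortens each length to `(length - 1) // 2 + 1`.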
+ """ + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5OmniVisionAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.q = nn.Linear(self.dim, self.dim, bias=True) + self.k = nn.Linear(self.dim, self.dim, bias=True) + self.v = nn.Linear(self.dim, self.dim, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.num_key_value_groups = 1 # needed for eager attention + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + 
value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: + super().__init__(config, config._attn_implementation) + self.attn = Qwen2_5OmniVisionAttention(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): + config: Qwen2_5OmniVisionEncoderConfig + _no_split_modules = ["Qwen2_5OmniVisionBlock"] + + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
+ """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Modification here + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + **kwargs, + ) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5OmniRotaryEmbedding(Qwen2VLRotaryEmbedding): + def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None): + super().__init__(config, device) + + +# It's same as `Qwen2_5_VLAttention`, but talker model's hidden_size isn't divisible by num_heads. +# Removes the value error as a workaround. +class Qwen2_5OmniAttention(Qwen2_5_VLAttention): + def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + +class Qwen2MLP(Qwen2_5_VLMLP): + pass + + +class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLTextModel): + config: Qwen2_5OmniTextConfig + _no_split_modules = ["Qwen2_5OmniDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTextConfig): + super().__init__(config) + + +@auto_docstring( + custom_intro=""" + The Qwen2.5OmniThinker model which consists of a audio backbone and a language model. + """ +) +class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config: Qwen2_5OmniThinkerConfig + base_model_prefix = "thinker" + _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] + + def __init__(self, config: Qwen2_5OmniThinkerConfig): + super().__init__(config) + self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config) + self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config) + self.vocab_size = config.text_config.vocab_size + self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.rope_deltas = None + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. 
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + ): + """ + Encodes audios into continuous embeddings that can be forwarded to the language model. + + Args: + input_features (`torch.FloatTensor`): + The tensors corresponding to the input audios. + feature_attention_mask (`torch.LongTensor`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + """ + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + + return audio_features + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
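+        Returns the image, video and audio placeholder masks as boolean tensors expanded to the shape of
+        `inputs_embeds`, so they can be used directly with `masked_scatter`.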
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask, special_video_mask, special_audio_mask + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + feature_attention_mask: Optional[torch.Tensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_audio_in_video: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): + Mask to avoid performing attention on padding feature indices. 
Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+            The length of feature shape of each audio in LLM.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_audio_in_video (`bool`, *optional*):
+            Whether or not to use the audio track in the video; should be the same as the parameter passed to
+            `process_audio_info`.
+        video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+            Number of seconds per grid for each video, used for temporal feature mapping.
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from qwen_vl_utils import process_vision_info
+        >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+        >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+        >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+        >>> conversations = [
+        >>>     {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+        >>>     {"role": "user", "content": [
+        >>>         {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+        >>>         {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+        >>>     ]},
+        >>> ]
+
+        >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+        >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+        >>> images, videos = process_vision_info(conversations)
+        >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+        >>> # Generate
+        >>> inputs['use_audio_in_video'] = True  # or False
+        >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+        >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+        >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # 2.
Merge text , audios , image and video + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + else: + audio_feature_lengths = None + + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + rope_deltas = rope_deltas - delta0 + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size + ) + + if not return_dict: + output = (logits,) + outputs + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniThinkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_features=None, + feature_attention_mask=None, + use_audio_in_video=False, + 
video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + model_inputs["input_features"] = None + + return model_inputs + + +############################ +# Start Talker # +############################ + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs. + """ +) +class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states from the thinker model that are used as input for the talker model. These represent the encoded + response that the talker model will use to generate speech tokens. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + thinker_reply_part: Optional[torch.FloatTensor] = None + + +class Qwen2_5OmniTalkerModel(Qwen2_5_VLTextModel): + config: Qwen2_5OmniTalkerConfig + _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx) + + +class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config: Qwen2_5OmniTalkerConfig + base_model_prefix = "talker" + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + + self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size) + + self.model = Qwen2_5OmniTalkerModel(config) + self.codebook_size = config.vocab_size + self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False) + + self.codec_bos_token = config.tts_codec_start_token_id + self.codec_eos_token = config.tts_codec_end_token_id + self.codec_pad_token = config.tts_codec_pad_token_id + self.codec_mask_token = config.tts_codec_mask_token_id + + self.text_bos_token = config.tts_text_start_token_id + self.text_eos_token = config.tts_text_end_token_id + self.text_pad_token = config.tts_text_pad_token_id + + self.spatial_merge_size = self.config.spatial_merge_size + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + thinker_reply_part: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + input_text_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_audio_in_video: Optional[bool] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]: + r""" + thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Hidden states from the thinker model's output that represent the text reply part to be processed. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + input_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Input token IDs for text-only content, used for position calculation in multimodal contexts. 
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + use_audio_in_video (`bool`, *optional*): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*): + Number of seconds per grid for each video, used for temporal feature mapping. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration + + >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") + + >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" + >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) + + >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Generate the caption in English: Glass is breaking." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + position_ids, rope_deltas = self.get_rope_index( + input_text_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + + inputs_embeds[:, -1, :] += self.get_input_embeddings()( + torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device) + ) + inputs_embeds[:, -2, :] += self.get_input_embeddings()( + torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device) + ) + self.rope_deltas = rope_deltas + + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if inputs_embeds is None: + # 1. 
Inference tokens after second token + codec_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :] + if thinker_reply_part.shape[1] > 1: + thinker_reply_part = thinker_reply_part[:, 1:, :] + + talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=talker_lm_input, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.codec_head(hidden_states) + logits = logits.float() + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniTalkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + thinker_reply_part=thinker_reply_part, + ) + + def _get_initial_cache_position(self, seq_length, device, model_kwargs): + # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily + inputs_embeds = model_kwargs.pop("inputs_embeds") + model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs) + model_kwargs["inputs_embeds"] = inputs_embeds + return model_kwargs + + # prepare inputs for talker lm generation + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values, + attention_mask, + inputs_embeds, + cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + if getattr(outputs, "thinker_reply_part", None) is not None: + model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part + + return model_kwargs + + +############################ +# Start Token2Wav # +############################ + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", 
inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Modified from Llama with a different rotate function, will fixed in next release +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... 
d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. 
+ + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). 
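+    Mel features are processed by an initial TDNN block, a stack of SE-Res2Net blocks, multi-layer feature
+    aggregation, attentive statistical pooling and a final 1x1 convolution that projects to `config.enc_dim`.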
+ """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: Optional[bool] = False, + code_embed_uncond: Optional[bool] = None, + apply_cfg: Optional[bool] = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + 
self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. 
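+        # project to `num_attention_heads * head_dim` features; reshaped into per-head form below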
+ query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = 
hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + + hidden_states = F.pad(hidden_states, (self.pad, 
self.pad), mode="replicate") + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise TypeError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + 
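+# Illustrative usage sketch (not from the original file; assumes the default 2x up/down
+# ratios and an arbitrary (batch, channels, frames) shape). AMPBlock above wraps each
+# SnakeBeta nonlinearity in TorchActivation1d, which upsamples, applies the activation,
+# and then low-pass downsamples, so the temporal length is preserved while the aliasing a
+# raw sin^2 nonlinearity would introduce is suppressed:
+#
+#     >>> import torch
+#     >>> act = TorchActivation1d(activation=SnakeBeta(8))  # 8 channels, default ratios
+#     >>> x = torch.randn(2, 8, 256)                        # (batch, channels, frames)
+#     >>> act(x).shape
+#     torch.Size([2, 8, 256])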
+ +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third) + k3 = 
function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return self._rk4_step( + function, time_start, time_step, time_end, value_start, function_value_start=function_value_start + ), function_value_start + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, time_end, current_value, next_value, time_points[current_index] + ) + current_index += 1 + + current_value = next_value + + return solution + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavDiT model. Which take speech tokens as input and predict mel spectrogram. 
+ """ +) +class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + if batch_size != 1: + raise 
ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniToken2WavConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"] + + def __init__(self, config: Qwen2_5OmniToken2WavConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa." 
+ ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( + config.dit_config, attn_implementation=attn_impl + ) + self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.code2wav_dit_model.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + ) + + waveform = self.code2wav_bigvgan_model(mel_spectrogram) + + return waveform + + +############################ +# Start Qwen2.5Omni # +############################ + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models: + - [`Qwen2_5OmniThinkerForConditionalGeneration`]: + a causal auto-regressive transformer takes text, audio, image, video as input and predict text tokens. + - [`Qwen2_5OmniTalkerForConditionalGeneration`]: + a causal auto-regressive transformer takes thinker hidden states and response as input and predict speech tokens. + - [`Qwen2_5OmniToken2WavModel`]: + a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin): + config: Qwen2_5OmniConfig + _no_split_modules = [ + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavModel", + ] + + def __init__(self, config): + super().__init__(config) + + self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config) + + self.has_talker = config.enable_audio_output + self.speaker_map = {} + if config.enable_audio_output: + self.enable_talker() + self.post_init() + + def enable_talker(self): + self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config) + self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config) + self.token2wav.float() + self.has_talker = True + + def load_speakers(self, path): + check_torch_load_is_safe() + for key, value in torch.load(path, weights_only=True).items(): + self.speaker_map[key] = value + logger.info(f"Speaker {list(self.speaker_map.keys())} loaded") + + def disable_talker(self): + if hasattr(self, "talker"): + del self.talker + if hasattr(self, "token2wav"): + del self.token2wav + self.has_talker = False + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + config=None, + cache_dir=None, + ignore_mismatched_sizes=False, + force_download=False, + local_files_only=False, + token=None, + revision="main", + use_safetensors=None, + weights_only=True, + **kwargs, + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + spk_path = cached_file( + pretrained_model_name_or_path, + 
"spk_dict.pt", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + if spk_path is None: + raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""") + model.load_speakers(spk_path) + + return model + + @torch.no_grad() + # TODO: raushan, defaults should be saved in generation config + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + speaker: str = "Chelsie", + use_audio_in_video: bool = False, + return_audio: Optional[bool] = None, + thinker_max_new_tokens: int = 1024, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 40, + talker_top_p: float = 0.8, + talker_temperature: float = 0.9, + talker_eos_token_id: list[int] = [8292, 8294], + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + r""" + Generate text response and audio from input. + + Args: + input_ids (`Optional[torch.Tensor]`, *optional*): + Input ids, should obtain from processor. + speaker (`str` , defaults to "Chelsie"): + Which speaker should be used in audio response. + use_audio_in_video (`bool`, defaults to False): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + return_audio (`Optional[bool]`, *optional*): + Whether or not return response in audio format. When `return_audio=None`, this parameter is same as `config.enable_audio_output`. + kwargs (*optional*): + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model. + - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the + thinker, talker and token2wav respectively. It has the priority over the keywords without a prefix. + Returns: + When `return_audio=False`: + - **Text** (`torch.Tensor`): Generated text token sequence. + When `return_audio=True`: + - **Text** (`torch.Tensor`): Generated text token sequence. + - **Audio waveform** (`torch.Tensor`): Generated audio waveform. + """ + if speaker not in self.speaker_map: + raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}") + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_talker in config to enable talker." 
+ ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device) + if thinker_kwargs.get("input_features") is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values") is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos") is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=input_ids.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device) + thinker_token_embeds = [ + token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids, + torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like(input_ids, fill_value=self.talker.codec_mask_token), + torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device), + torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_token = torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device) + eos_embedding = thinker_embed_tokens(eos_token).to(input_ids.device) + + pad_token = torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device) + pad_embedding = thinker_embed_tokens(pad_token).to(input_ids.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + + talker_attention_mask = None + if 
"attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(input_ids.device) + + talker_result = self.talker.generate( + input_ids=talker_input_ids, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + ) + talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # 3. Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + wav = self.token2wav( + talker_generate_codes.to(input_ids.device), + conditioning=speaker_params["cond"].to(input_ids.device).float(), + reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + +__all__ = [ + "Qwen2_5OmniConfig", + "Qwen2_5OmniThinkerConfig", + "Qwen2_5OmniTalkerConfig", + "Qwen2_5OmniToken2WavConfig", + "Qwen2_5OmniForConditionalGeneration", + "Qwen2_5OmniThinkerTextModel", + "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2_5OmniTalkerModel", + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + "Qwen2_5OmniToken2WavModel", + "Qwen2_5OmniPreTrainedModel", + "Qwen2_5OmniPreTrainedModelForConditionalGeneration", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcbb0c535f92b912e4638a7b2b5f6fd6bb58f81 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py @@ -0,0 +1,360 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Qwen2.5Omni. 
+""" + +import logging +import re +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput +from ...video_utils import VideoInput + + +class Qwen2_5_OmniVideosKwargs(VideosKwargs): + fps: Optional[list[Union[int, float]]] + use_audio_in_video: Optional[bool] + seconds_per_chunk: Optional[float] + position_id_per_seconds: Optional[int] + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2_5_OmniImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2_5OmniProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_OmniVideosKwargs + images_kwargs: Qwen2_5_OmniImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "padding_side": "left", + }, + "videos_kwargs": { + "seconds_per_chunk": 2.0, + "position_id_per_seconds": 25, + "use_audio_in_video": False, + "size": { + "shortest_edge": 128 * 28 * 28, + "longest_edge": 768 * 28 * 28, + }, + }, + "audio_kwargs": { + "sampling_rate": 16000, + "padding": "max_length", + "return_attention_mask": True, + }, + } + + +class Qwen2_5OmniProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5Omni processor. + [`Qwen2_5OmniProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5OmniProcessor.__call__`] and [`~Qwen2_5OmniProcessor.decode`] for more information. + + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor. + video_processor ([`Qwen2VLVideoProcessor`], *optional*): + The video processor. + feature_extractor ([`WhisperFeatureExtractor`], *optional*): + The audio feature extractor. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The text tokenizer. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the default chat template is used. 
+ """ + + attributes = ["image_processor", "video_processor", "feature_extractor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + feature_extractor_class = "WhisperFeatureExtractor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, image_processor=None, video_processor=None, feature_extractor=None, tokenizer=None, chat_template=None + ): + super().__init__(image_processor, video_processor, feature_extractor, tokenizer, chat_template=chat_template) + self.image_token = self.tokenizer.image_token + self.audio_token = self.tokenizer.audio_token + self.video_token = self.tokenizer.video_token + self.vision_bos_token = self.tokenizer.vision_bos_token + self.vision_eos_token = self.tokenizer.vision_eos_token + self.audio_bos_token = self.tokenizer.audio_bos_token + self.audio_eos_token = self.tokenizer.audio_eos_token + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + images: Optional[ImageInput] = None, + videos: Optional[VideoInput] = None, + audio: Optional[AudioInput] = None, + **kwargs: Unpack[Qwen2_5OmniProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to + WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. To prepare the vision inputs, + this method forwards the `vision_infos` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] + if `vision_infos` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + audio (`np.ndarray`, `list[np.ndarray]`): + The audio or batch of audio to be prepared. Each audio can be a NumPy array. 
+ """ + + if text is None: + raise ValueError("You need to specify either a `text` input to process.") + + output_kwargs = self._merge_kwargs( + Qwen2_5OmniProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk") + position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds") + use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video") + + if audio is not None: + output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here + audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + audio_inputs["feature_attention_mask"] = audio_inputs.pop( + "attention_mask" + ) # rename feature_attention_mask to prevent conflicts later on + audio_inputs["input_features"] = audio_inputs.pop( + "input_features" + ) # rename input_features to prevent conflicts later on + input_lengths = (audio_inputs["feature_attention_mask"].sum(-1) - 1) // 2 + 1 + audio_lengths = iter((input_lengths - 2) // 2 + 1) + else: + audio_inputs = {} + audio_lengths = iter([]) + + if images is not None: + images_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = iter(images_inputs["image_grid_thw"]) + else: + images_inputs = {} + image_grid_thw = iter([]) + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + + fps = output_kwargs["videos_kwargs"].get("fps", 2.0) + video_grid_thw = videos_inputs["video_grid_thw"] + second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw) + videos_inputs["video_second_per_grid"] = second_per_grid_ts + + video_grid_thw = iter(video_grid_thw) + video_second_per_grid = iter(second_per_grid_ts) + else: + videos_inputs = {} + video_grid_thw = iter([]) + video_second_per_grid = iter([]) + + if not isinstance(text, list): + text = [text] + + if images is not None or videos is not None or audio is not None: + text = self.replace_multimodal_special_tokens( + text, + audio_lengths, + image_grid_thw, + video_grid_thw, + video_second_per_grid=video_second_per_grid, + use_audio_in_video=use_audio_in_video, + position_id_per_seconds=position_id_per_seconds, + seconds_per_chunk=seconds_per_chunk, + ) + + texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature( + data={**texts_inputs, **images_inputs, **videos_inputs, **audio_inputs}, + tensor_type=kwargs.get("return_tensors"), + ) + + def replace_multimodal_special_tokens( + self, + text, + audio_lengths, + image_grid_thw, + video_grid_thw, + video_second_per_grid, + use_audio_in_video, + position_id_per_seconds, + seconds_per_chunk, + ): + # Extend mm token length + merge_length_image = self.image_processor.merge_size**2 + merge_length_video = self.video_processor.merge_size**2 + + processed_text = [] + for sample in text: + positions = [] + special_tokens = [re.escape(tok) for tok in [self.audio_token, self.image_token, self.video_token]] + pattern = "|".join(special_tokens) + positions = sorted([(match.start(), match.group()) for match in re.finditer(pattern, sample)]) + positions.sort(key=lambda x: x[0]) + + for _, special_token in positions: + if special_token == self.audio_token: + sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1) + elif special_token == self.image_token: + image_seq_length = next(image_grid_thw).prod() // 
merge_length_image + sample = sample.replace(self.image_token, "<|image_placeholder|>" * image_seq_length, 1) + elif special_token == self.video_token: + if not use_audio_in_video: + video_seq_length = next(video_grid_thw).prod() // merge_length_video + sample = sample.replace(self.video_token, "<|video_placeholder|>" * video_seq_length, 1) + else: + audio_token_indices = np.arange(next(audio_lengths)) + curr_video_grid_thw = next(video_grid_thw) + height = curr_video_grid_thw[1] // self.video_processor.merge_size + width = curr_video_grid_thw[2] // self.video_processor.merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width) + ).reshape(-1) + video_token_indices = ( + video_token_indices * next(video_second_per_grid) * position_id_per_seconds + ) + + tokens_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_token_indices, tokens_per_chunk) + audio_chunk_indexes = self.get_chunked_index(audio_token_indices, tokens_per_chunk) + + placeholder_string = self.vision_bos_token + self.audio_bos_token + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + if j < len(video_chunk_indexes): + video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0] + placeholder_string += "<|video_placeholder|>" * video_seq_length + if j < len(audio_chunk_indexes): + audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0] + placeholder_string += "<|audio_placeholder|>" * audio_seq_length + placeholder_string += self.audio_eos_token + self.vision_eos_token + sample = sample.replace( + self.vision_bos_token + self.video_token + self.vision_eos_token, + placeholder_string, + 1, + ) + + sample = sample.replace("<|audio_placeholder|>", self.audio_token) + sample = sample.replace("<|image_placeholder|>", self.image_token) + sample = sample.replace("<|video_placeholder|>", self.video_token) + processed_text.append(sample) + return processed_text + + def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]: + """ + Splits token index list into chunks based on token value ranges. + + Given a list of token indices, returns a list of (start, end) index tuples representing + slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`. + + For example, if `t_ntoken_per_chunk` is 1000, the function will create chunks such that: + - the first chunk contains token values < 1000, + - the second chunk contains values >= 1000 and < 2000, and so on. + + Parameters: + token_indices (`np.ndarray`): A monotonically increasing list of token index values. + t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold). + + Returns: + `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive) + and end (exclusive) indices of a chunk in `token_indices`. 
+ """ + + def _iter(): + i, start_idx = 0, 0 # skip bos token + current_chunk = 1 + while i < len(token_indices): # skip eos token + if token_indices[i] >= current_chunk * tokens_per_chunk: + yield (start_idx, i) + start_idx = i + current_chunk += 1 + i += 1 + yield (start_idx, len(token_indices)) + + return list(_iter()) + + def apply_chat_template(self, conversations, chat_template=None, **kwargs): + is_batched = False + if isinstance(conversations[0], dict): + conversations = [conversations] + is_batched = True + + for conversation in conversations: + if ( + conversation[0]["role"] != "system" + or conversation[0]["content"][0]["text"] + != "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." + ): + logging.warning( + "System prompt modified, audio output may not work as expected. " + + "Audio output mode only works when using default system prompt 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'" + ) + if is_batched: + conversations = conversations[0] + + return super().apply_chat_template(conversations, chat_template, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list( + dict.fromkeys( + tokenizer_input_names + + feature_extractor_input_names + + image_processor_input_names + + ["feature_attention_mask"] + + ["video_second_per_grid"] + ) + ) + + +__all__ = ["Qwen2_5OmniProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..715effd8b7f1dbe0d8cb8c9540df3229243acbc2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_moe import * + from .modeling_qwen3_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/configuration_qwen3_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/configuration_qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..a23d19c1154978b8149bcb67d237f15415f1b634 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/configuration_qwen3_moe.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen3MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a + Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [Qwen/Qwen3-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
+            For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `4`.
+
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and you expect the model to work on longer `max_position_embeddings`, we recommend updating this
+            value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 768):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 128):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.
+            The list contains layer indices from 0 to num_layers-1.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+
+    ```python
+    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig
+
+    >>> # Initializing a Qwen3MoE style configuration
+    >>> configuration = Qwen3MoeConfig()
+
+    >>> # Initializing a model from the Qwen3-15B-A2B style configuration
+    >>> model = Qwen3MoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_moe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen3Moe`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.experts.*.gate_proj": "colwise",
+        "layers.*.mlp.experts.*.up_proj": "colwise",
+        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=2048,
+        intermediate_size=6144,
+        num_hidden_layers=24,
+        num_attention_heads=32,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        use_sliding_window=False,
+        sliding_window=4096,
+        attention_dropout=0.0,
+        decoder_sparse_step=1,
+        moe_intermediate_size=768,
+        num_experts_per_tok=8,
+        num_experts=128,
+        norm_topk_prob=False,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        mlp_only_layers=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        # MoE arguments
+        self.decoder_sparse_step = decoder_sparse_step
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["Qwen3MoeConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c738744faba82db3c451b6987b01db74fadcfd
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -0,0 +1,712 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
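Editorial note (not part of the vendored file): before the modeling code below, it may help to see how the configuration's MoE placement knobs interact. This sketch reproduces the layer-selection condition used by `Qwen3MoeDecoderLayer` further down; the values passed to `Qwen3MoeConfig` are arbitrary examples, and the top-level import assumes a transformers build that re-exports the class.

```python
# Sketch: which decoder layers become sparse MoE blocks, mirroring the
# condition in Qwen3MoeDecoderLayer below. Example values are arbitrary.
from transformers import Qwen3MoeConfig

config = Qwen3MoeConfig(num_hidden_layers=8, decoder_sparse_step=2, mlp_only_layers=[3])

for layer_idx in range(config.num_hidden_layers):
    is_moe = (layer_idx not in config.mlp_only_layers) and (
        config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
    )
    print(layer_idx, "Qwen3MoeSparseMoeBlock" if is_moe else "Qwen3MoeMLP")
# Layers 1, 5 and 7 get the sparse block; layer 3 is forced dense via mlp_only_layers.
```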
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+    GenericForQuestionAnswering,
+    GenericForSequenceClassification,
+    GenericForTokenClassification,
+    GradientCheckpointingLayer,
+)
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_qwen3_moe import Qwen3MoeConfig
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class Qwen3MoeAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+        )
+        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
+        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
+        self.sliding_window = getattr(config, "sliding_window", None)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,  # diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Qwen3MoeMLP(nn.Module):
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        # gating
+        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.experts = nn.ModuleList(
+            [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+        for expert_idx in expert_hit:
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3MoeRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen3MoeRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3MoeDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3MoeAttention(config, layer_idx)
+
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(config)
+        else:
+            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
+
+        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        # For the MoE layers, we need to unpack
+        if isinstance(hidden_states, tuple):
+            hidden_states, _ = hidden_states
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Qwen3MoeRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Qwen3MoeConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class Qwen3MoePreTrainedModel(PreTrainedModel):
+    config: Qwen3MoeConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen3MoeDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "router_logits": OutputRecorder(Qwen3MoeSparseMoeBlock, index=1),
+        "hidden_states": Qwen3MoeDecoderLayer,
+        "attentions": Qwen3MoeAttention,
+    }
+
+
+@auto_docstring
+class Qwen3MoeModel(Qwen3MoePreTrainedModel):
+    def __init__(self, config: Qwen3MoeConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs()
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        causal_mask = mask_function(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+
+        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+def load_balancing_loss_func(
+    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
+
+    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits:
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+@auto_docstring
+class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Qwen3MoeModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeCausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+        >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class Qwen3MoeForSequenceClassification(GenericForSequenceClassification, Qwen3MoePreTrainedModel): + pass + + +class Qwen3MoeForTokenClassification(GenericForTokenClassification, Qwen3MoePreTrainedModel): + pass + + +class Qwen3MoeForQuestionAnswering(GenericForQuestionAnswering, Qwen3MoePreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modular_qwen3_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modular_qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..e7dd3dda00acedfc2ef9c7d8a3bbd74e5426f4bf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen3 model.""" + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg +from ..llama.modeling_llama import ( + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaRMSNorm, +) +from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel, load_balancing_loss_func +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer +from ..qwen3.modeling_qwen3 import Qwen3Attention +from .configuration_qwen3_moe import Qwen3MoeConfig + + +logger = logging.get_logger(__name__) + + +class Qwen3MoeAttention(Qwen3Attention): # This is the main diff with qwen2Moe! + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.sliding_window = getattr(config, "sliding_window", None) + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+        for expert_idx in expert_hit:
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+class Qwen3MoeRMSNorm(LlamaRMSNorm):
+    pass
+
+
+class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer):
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3MoeAttention(config, layer_idx)
+
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(config)
+        else:
+            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
+
+        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        # For the MoE layers, we need to unpack
+        if isinstance(hidden_states, tuple):
+            hidden_states, _ = hidden_states
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Qwen3MoeModel(MixtralModel):
+    pass
+
+
+class Qwen3MoeForCausalLM(MixtralForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Qwen3MoeModel(config)
+        self.num_experts = config.num_experts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MoeCausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+        >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class Qwen3MoeForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen3MoeForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen3MoeForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", # noqa: F822 + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f87f71c621d8da2dc9caffd688cf9f1d19b11c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_qwen3_next import *
+    from .modeling_qwen3_next import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/configuration_qwen3_next.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/configuration_qwen3_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..148166cbd16fc3f83663d32a66a73e375971c05d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/configuration_qwen3_next.py
@@ -0,0 +1,272 @@
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen3-Next model configuration"""
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3NextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
+    Qwen3-Next model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of
+    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+            `inputs_ids`.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 2):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and you expect the model to work on longer `max_position_embeddings`, we recommend updating this
+            value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+            Percentage of the query and keys which will have rotary embedding.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        head_dim (`int`, *optional*, defaults to 256):
+            Projection weights dimension in multi-head attention.
+        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
+            Kernel size of the convolution used in linear attention layers.
+        linear_key_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each key head in linear attention.
+        linear_value_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each value head in linear attention.
+        linear_num_key_heads (`int`, *optional*, defaults to 16):
+            Number of key heads used in linear attention layers.
+        linear_num_value_heads (`int`, *optional*, defaults to 32):
+            Number of value heads used in linear attention layers.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the routed expert.
+        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the shared expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 10):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 512):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock.
+            The list contains layer indices from 0 to num_layers-1.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+        layer_types (`list[str]`, *optional*):
+            Types of each layer (attention or linear).
+ + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=[], + layer_types=None, + **kwargs, + ): + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + interval_pattern = kwargs.get("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = 
linear_num_value_heads + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +__all__ = ["Qwen3NextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modeling_qwen3_next.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modeling_qwen3_next.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc4e203ff28f39d1776fefa497e1aa52b4ad724 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -0,0 +1,1276 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_next/modular_qwen3_next.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_next.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
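When `layer_types` is not provided, the configuration above derives it from a `full_attention_interval` kwarg (defaulting to 4): every fourth layer is a full-attention layer and the rest use linear attention. A minimal sketch of that default pattern, assuming the default interval:

```python
# Sketch of the default layer_types pattern built in Qwen3NextConfig.__init__
# when `layer_types` is None (full_attention_interval defaults to 4).
num_hidden_layers = 8          # small value for illustration only
full_attention_interval = 4

layer_types = [
    "linear_attention" if (i + 1) % full_attention_interval else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
#  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']
```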
+ +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_linear_attention_available, +) +from .configuration_qwen3_next import Qwen3NextConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_linear_attention_available(): + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +else: + chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None + FusedRMSNormGated = None + +logger = logging.get_logger(__name__) + + +class Qwen3NextRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +class Qwen3NextDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention + cache (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
+ """ + + is_compileable = False + + def __init__(self, config: Qwen3NextConfig): + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + + # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx].device + beam_idx = beam_idx.to(device) + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
+ """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + return self.conv_states[self.last_linear_layer] is not None + + +class Qwen3NextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3NextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3NextRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3NextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, 
bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3NextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output * torch.sigmoid(gate) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +is_fast_path_available = all( + (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +) + + +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = F.conv1d(hidden_states_new, 
weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def 
torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class Qwen3NextGatedDeltaNet(nn.Module): + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False) + self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update 
= causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + """ + + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) + z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim) + b = b.reshape(b.size(0), b.size(1), self.num_v_heads) + a = a.reshape(a.size(0), a.size(1), self.num_v_heads) + return query, key, value, z, b, a + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Qwen3NextDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if use_precomputed_states: + # 2. 
Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3NextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3NextMLP(config, 
intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Qwen3NextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + # token mixer + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock(config) + else: + self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3NextPreTrainedModel(PreTrainedModel): + config: Qwen3NextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3NextDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] + _can_record_outputs = { + "router_logits": OutputRecorder(Qwen3NextSparseMoeBlock, index=1), + "hidden_states": Qwen3NextDecoderLayer, + "attentions": Qwen3NextAttention, + } + _is_stateful = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Qwen3NextGatedDeltaNet): + module.dt_bias.data.fill_(1.0) + module.A_log.data.uniform_(0, 16).log_() + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3NextRMSNorm): + module.weight.data.zero_() + + +class Qwen3NextModel(Qwen3NextPreTrainedModel): + def __init__(self, config: Qwen3NextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [Qwen3NextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3NextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3NextDynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is 
not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Qwen3NextForCausalLM(Qwen3NextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3NextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Qwen3NextDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3NextForCausalLM + + >>> model = Qwen3NextForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class Qwen3NextForSequenceClassification(GenericForSequenceClassification, Qwen3NextPreTrainedModel): + pass + + +class Qwen3NextForTokenClassification(GenericForTokenClassification, Qwen3NextPreTrainedModel): + pass + + +class Qwen3NextForQuestionAnswering(GenericForQuestionAnswering, Qwen3NextPreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "Qwen3NextForCausalLM", + "Qwen3NextForQuestionAnswering", + "Qwen3NextModel", + "Qwen3NextPreTrainedModel", + "Qwen3NextForSequenceClassification", + "Qwen3NextForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modular_qwen3_next.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modular_qwen3_next.py new file mode 100644 index 0000000000000000000000000000000000000000..6411c4ffacb0ab60d3968709eb526b800281ed9a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_next/modular_qwen3_next.py @@ -0,0 +1,884 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the 
HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3-Next model.""" + +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_linear_attention_available, +) +from ..bamba.modeling_bamba import apply_mask_to_padding_states, apply_rotary_pos_emb +from ..gemma3.modeling_gemma3 import Gemma3RMSNorm +from ..llama.modeling_llama import ( + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, +) +from ..mixtral.modeling_mixtral import MixtralForCausalLM +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock +from ..qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeMLP, + Qwen3MoeRotaryEmbedding, + eager_attention_forward, +) +from .configuration_qwen3_next import Qwen3NextConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_linear_attention_available(): + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +else: + chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None + FusedRMSNormGated = None + + +is_fast_path_available = all( + (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +) + +logger = logging.get_logger(__name__) + + +class Qwen3NextRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +class Qwen3NextDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention + cache (which has a constant shape regardless of seq_len). 
+ + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + is_compileable = False + + def __init__(self, config: Qwen3NextConfig): + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + + # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx].device + beam_idx = beam_idx.to(device) + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + return self.conv_states[self.last_linear_layer] is not None + + +class Qwen3NextRotaryEmbedding(Qwen3MoeRotaryEmbedding): + pass + + +class Qwen3NextRMSNorm(Gemma3RMSNorm): + pass + + +class Qwen3NextAttention(Qwen3MoeAttention): + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + ) + del self.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output * torch.sigmoid(gate) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = 
hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], 
core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class Qwen3NextGatedDeltaNet(nn.Module): + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False) + self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + 
device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + """ + + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) + z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim) + b = b.reshape(b.size(0), b.size(1), self.num_v_heads) + a = a.reshape(a.size(0), a.size(1), self.num_v_heads) + return query, key, value, z, b, a + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Qwen3NextDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if use_precomputed_states: + # 2. 
Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3NextMLP(Qwen3MoeMLP): + pass + + +class Qwen3NextSparseMoeBlock(Qwen2MoeSparseMoeBlock): + pass + + +class Qwen3NextDecoderLayer(Qwen3MoeDecoderLayer): + def __init__(self, config: Qwen3NextConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + # token mixer + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock(config) + else: + self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + 
hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3NextPreTrainedModel(PreTrainedModel): + config: Qwen3NextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3NextDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] + _can_record_outputs = { + "router_logits": OutputRecorder(Qwen3NextSparseMoeBlock, index=1), + "hidden_states": Qwen3NextDecoderLayer, + "attentions": Qwen3NextAttention, + } + _is_stateful = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Qwen3NextGatedDeltaNet): + module.dt_bias.data.fill_(1.0) + module.A_log.data.uniform_(0, 16).log_() + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3NextRMSNorm): + module.weight.data.zero_() + + +class Qwen3NextModel(Qwen3NextPreTrainedModel): + def __init__(self, config: Qwen3NextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [Qwen3NextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3NextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + 
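+        # Descriptive note (added): Qwen3-Next uses its own hybrid cache rather than the generic DynamicCache —
+        # per-layer key/value tensors for the full-attention layers, plus conv/recurrent states for the
+        # linear-attention (gated deltanet) layers — so it is lazily created here on the first cached forward.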
if use_cache and past_key_values is None: + past_key_values = Qwen3NextDynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +class Qwen3NextForCausalLM(MixtralForCausalLM): + def __init__(self, config): + super().__init__(config) + self.num_experts = config.num_experts + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Qwen3NextDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3NextForCausalLM + + >>> model = Qwen3NextForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + +class Qwen3NextForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen3NextForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen3NextForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen3NextForCausalLM", + "Qwen3NextForQuestionAnswering", + "Qwen3NextModel", + "Qwen3NextPreTrainedModel", + "Qwen3NextForSequenceClassification", + "Qwen3NextForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e37161a2e4153b0192ca2d10d85292c8fa2ed3f2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_vl import * + from .modeling_qwen3_vl import * + from .processing_qwen3_vl import * + from .video_processing_qwen3_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/configuration_qwen3_vl.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/configuration_qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..132ffa8be150dbab496810254d9a9757c70ccc1c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/configuration_qwen3_vl.py @@ -0,0 +1,287 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Qwen3VLVisionConfig(PretrainedConfig): + model_type = "qwen3_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3VLModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig + + >>> # Initializing a Qwen3VL style configuration + >>> configuration = Qwen3VLTextConfig() + + >>> # Initializing a model from the Qwen3-VL-7B style configuration + >>> model = Qwen3VLTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_text" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. 
+ video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig + + >>> # Initializing a Qwen3-VL style configuration + >>> configuration = Qwen3VLConfig() + + >>> # Initializing a model from the Qwen3-VL-4B style configuration + >>> model = Qwen3VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl" + sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +__all__ = ["Qwen3VLConfig", "Qwen3VLTextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa91b03cea8c6f9be720177baadb95483d6f65c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -0,0 +1,1561 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
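+"""PyTorch Qwen3-VL model."""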
+ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) 
if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3VLVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 
1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. 
+ args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3VLTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3VLTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3VLTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + 
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
+
+        self.mlp = Qwen3VLTextMLP(config)
+        self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Qwen3VL outputs, with hidden states and attentions.
+    """
+)
+class Qwen3VLModelOutputWithPast(ModelOutput):
+    r"""
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Qwen3VLPreTrainedModel(PreTrainedModel): + config: Qwen3VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = 
embeddings.flatten(1)
+        return embeddings
+
+    def fast_pos_embed_interpolate(self, grid_thw):
+        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.int()
+            w_idxs_floor = w_idxs.int()
+            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            indices = [
+                (base_h[None].T + w_idxs_floor[None]).flatten(),
+                (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                ((1 - dh)[None].T * dw[None]).flatten(),
+                (dh[None].T * (1 - dw)[None]).flatten(),
+                (dh[None].T * dw[None]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+        weight_tensor = torch.tensor(
+            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+        )
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+        patch_pos_embeds_permute = []
+        merge_size = self.config.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The packed, flattened patch pixel values of the images/videos, i.e. the input to the patch embedding.
+            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                The temporal, height and width of feature shape of each image in LLM.
+
+        Returns:
+            `tuple(torch.Tensor, list[torch.Tensor])`: the merged visual hidden states and the deepstack features.
+ """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VL, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@auto_docstring +class Qwen3VLModel(Qwen3VLPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) 
+
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids."""
+
+        # Since we use timestamps to separate videos, the video_grid_thw should also be split
+        if video_grid_thw is not None:
+            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
+            video_grid_thw[:, 0] = 1
+
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+        mrope_position_deltas = []
+        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            attention_mask = attention_mask.to(total_input_ids.device)
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = input_ids[attention_mask[i] == 1]
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+                    # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
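+                # any text tokens left after the last vision block are handled just below,
+                # continuing sequential positions shared across the three (t, h, w) axes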
+ + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+            # When compiling, we can't check tensor values, thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+            prefill_compiled_stage = is_torchdynamo_compiling() and (
+                (input_ids is not None and input_ids.shape[1] != 1)
+                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+            )
+            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+                (cache_position is not None and cache_position[0] == 0)
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            )
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask=attention_mask_tensor,
+                )
+                self.rope_deltas = rope_deltas
+            # then use the previously calculated rope_deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (
+                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    if cache_position is not None
+                    else 0
+                )
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        outputs = self.language_model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+            **kwargs,
+        )
+
+        return Qwen3VLModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            rope_deltas=self.rope_deltas,
+        )
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Qwen3VL causal language model (or autoregressive) outputs.
+    """
+)
+class Qwen3VLCausalLMOutputWithPast(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3VLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+
+        Example:
+            TODO: Add example
+        """
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+        return Qwen3VLCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            rope_deltas=outputs.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        # Qwen3VL position_ids are prepared with rope_deltas in forward
+        model_inputs["position_ids"] = None
+
+        if cache_position[0] != 0:
+            model_inputs["pixel_values"] = None
+            model_inputs["pixel_values_videos"] = None
+
+        return model_inputs
+
+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, 
repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3VLVisionModel", + "Qwen3VLForConditionalGeneration", + "Qwen3VLModel", + "Qwen3VLPreTrainedModel", + "Qwen3VLTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modular_qwen3_vl.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modular_qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..8627dbdafadeb233b9d647ff9f4fe9e3bd9f54be --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -0,0 +1,1469 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
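+
+# Note: this modular file is the maintained source; the flat modeling_qwen3_vl.py is
+# auto-generated from it by the transformers modular converter, so the two stay in sync.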
+"""PyTorch Qwen3-VL model.""" + +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring, is_torchdynamo_compiling, logging +from ...utils.generic import check_model_inputs +from ...video_utils import VideoInput +from ..qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLVisionBlock, +) +from ..qwen2_vl.modeling_qwen2_vl import ( + PatchEmbed, + Qwen2VLModelOutputWithPast, + Qwen2VLPreTrainedModel, + TransformersKwargs, + VisionAttention, + VisionRotaryEmbedding, +) +from ..qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor +from ..qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3DecoderLayer, + Qwen3Model, + apply_rotary_pos_emb, + eager_attention_forward, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3VLVisionConfig(PretrainedConfig): + model_type = "qwen3_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3VL model. 
Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen3VLModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 22016):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+        head_dim (`int`, *optional*, defaults to 128):
+            The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 128000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 5000000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
+
+    >>> # Initializing a Qwen3VL style configuration
+    >>> configuration = Qwen3VLTextConfig()
+
+    >>> # Initializing a model from the Qwen3-VL-7B style configuration
+    >>> model = Qwen3VLTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_vl_text"
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        head_dim=128,
+        hidden_act="silu",
+        max_position_embeddings=128000,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=5000000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
+    Qwen3-VL model according to the specified arguments, defining the model architecture.
Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+        vision_start_token_id (`int`, *optional*, defaults to 151652):
+            The token index marking the start of a vision (image or video) prompt.
+        vision_end_token_id (`int`, *optional*, defaults to 151653):
+            The token index marking the end of a vision (image or video) prompt.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the word embeddings.
+
+    ```python
+    >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
+
+    >>> # Initializing a Qwen3-VL style configuration
+    >>> configuration = Qwen3VLConfig()
+
+    >>> # Initializing a model from the Qwen3-VL-4B style configuration
+    >>> model = Qwen3VLForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_vl"
+    sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        vision_start_token_id=151652,
+        vision_end_token_id=151653,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            self.text_config = self.sub_configs["text_config"]()
+
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_start_token_id = vision_start_token_id
+        self.vision_end_token_id = vision_end_token_id
+        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+class Qwen3VLVisionMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_state):
+        return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLVisionPatchEmbed(PatchEmbed):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
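+        # stride == kernel_size, so the Conv3d below embeds each non-overlapping
+        # (temporal_patch_size, patch_size, patch_size) voxel patch into a single embed_dim vector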
+        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+
+class Qwen3VLVisionRotaryEmbedding(VisionRotaryEmbedding):
+    pass
+
+
+class Qwen3VLVisionPatchMerger(nn.Module):
+    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+        self.act_fn = nn.GELU()
+        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+        return x
+
+
+class Qwen3VLVisionAttention(VisionAttention):
+    def __init__(self, config: Qwen3VLVisionConfig) -> None:
+        super().__init__()
+        self.dim = config.hidden_size
+
+
+class Qwen3VLVisionBlock(Qwen2_5_VLVisionBlock):
+    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.attn = Qwen3VLVisionAttention(config=config)
+        self.mlp = Qwen3VLVisionMLP(config=config)
+
+
+class Qwen3VLTextRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Qwen3VLTextConfig, device=None):
+        super().__init__()
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", "default")
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+        self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHW...TT], preserving frequency continuity.
+        args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        # In contrast to other models, Qwen3VL has different position ids for the grids
+        # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3VLTextAttention(Qwen3Attention): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + del self.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextDecoderLayer(Qwen3DecoderLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + del self.attention_type + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + return super().forward( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, 
+ ) + + +class Qwen3VLModelOutputWithPast(Qwen2VLModelOutputWithPast): + pass + + +class Qwen3VLPreTrainedModel(Qwen2VLPreTrainedModel): + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = 
w_idxs.int()
+            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            indices = [
+                (base_h[None].T + w_idxs_floor[None]).flatten(),
+                (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                ((1 - dh)[None].T * dw[None]).flatten(),
+                (dh[None].T * (1 - dw)[None]).flatten(),
+                (dh[None].T * dw[None]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+        weight_tensor = torch.tensor(
+            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+        )
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+        patch_pos_embeds_permute = []
+        merge_size = self.config.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The packed, patchified pixel values to embed.
+            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                The temporal, height and width of feature shape of each image or video in LLM.
+
+        Returns:
+            `tuple(torch.Tensor, list[torch.Tensor])`: the merged hidden states together with the
+            deepstack features taken from the intermediate blocks listed in `deepstack_visual_indexes`.
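+
+            A worked example (illustrative numbers): an image patchified into a `(1, 32, 32)`
+            grid contributes `1 * 32 * 32 = 1024` patch tokens here; the patch merger then
+            reduces them by `spatial_merge_size**2` (e.g. `4`) to `256` tokens on the
+            language-model side.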
+ """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VL, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel, Qwen3Model): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + del self.has_sliding_layers + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Qwen3VLModel(Qwen2_5_VLModel): + config: Qwen3VLConfig + _checkpoint_conversion_mapping = {} + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = 
self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + 
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
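+
+        Example (a minimal sketch -- the checkpoint name is taken from the config docstring
+        above and the image is a blank stand-in):
+
+        ```python
+        >>> from PIL import Image
+        >>> from transformers import AutoProcessor, Qwen3VLModel
+
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
+        >>> model = Qwen3VLModel.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
+
+        >>> image = Image.new("RGB", (448, 448))
+        >>> text = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> outputs.last_hidden_state.shape[-1] == model.config.text_config.hidden_size
+        True
+        ```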
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+            # When compiling, we can't check tensor values, so we check only the input length.
+            # It is safe to assume that `length != 1` means we're in pre-fill, because compiled
+            # models currently cannot do assisted decoding
+            prefill_compiled_stage = is_torchdynamo_compiling() and (
+                (input_ids is not None and input_ids.shape[1] != 1)
+                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+            )
+            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+                (cache_position is not None and cache_position[0] == 0)
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            )
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask=attention_mask_tensor,
+                )
+                self.rope_deltas = rope_deltas
+            # then use the previously pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (
+                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    if cache_position is not None
+                    else 0
+                )
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `delta` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        outputs = self.language_model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+            **kwargs,
+        )
+
+        return Qwen3VLModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            rope_deltas=self.rope_deltas,
+        )
+
+
+class Qwen3VLCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
+    pass
+
+
+class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+    config: Qwen3VLConfig
+    _checkpoint_conversion_mapping = {}
+
+    @check_model_inputs()
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+
+        Example:
+            TODO: Add example
+        """
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+        return Qwen3VLCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            rope_deltas=outputs.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        # Qwen3VL position_ids are prepared with rope_deltas in forward
+        model_inputs["position_ids"] = None
+
+        if cache_position[0] != 0:
+            model_inputs["pixel_values"] = None
+            model_inputs["pixel_values_videos"] = None
+
+        return model_inputs
+
+
+class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
+    pass
+
+
+class Qwen3VLImagesKwargs(Qwen2VLImagesKwargs):
+    pass
+
+
+class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Qwen3VLImagesKwargs
+    videos_kwargs: Qwen3VLVideosProcessorKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "return_token_type_ids": False,
+            "return_mm_token_type_ids": False,
+        },
+        "videos_kwargs": {"return_metadata": True},
+    }
+
+
+class Qwen3VLProcessor(Qwen2VLProcessor):
+    r"""
+    Constructs a Qwen3VL processor which wraps a Qwen3VL image processor and a Qwen2 tokenizer into a single processor.
+    [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+    [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information.
+    Args:
+        image_processor ([`Qwen2VLImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`Qwen2TokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        video_processor ([`Qwen3VLVideoProcessor`], *optional*):
+            The video processor is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
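+
+    A minimal usage sketch (the checkpoint name is illustrative and the image is a blank stand-in):
+
+    ```python
+    >>> from PIL import Image
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
+    >>> image = Image.new("RGB", (448, 448))
+    >>> text = "<|vision_start|><|image_pad|><|vision_end|>What is shown here?"
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+    ```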
+ """ + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs) + self.vision_start_token = ( + "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token + ) + self.vision_start_token_id = ( + tokenizer.vision_start_token_id + if getattr(tokenizer, "vision_start_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_start_token) + ) + self.vision_end_token_id = ( + tokenizer.vision_end_token_id + if getattr(tokenizer, "vision_end_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_end_token) + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen3VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. 
Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen3VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + metadata = video_metadata[index] + if metadata.fps is None: + logger.warning_once( + "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+ ) + metadata.fps = 24 if metadata.fps is None else metadata.fps + + # if timestamps are not provided, calculate them + curr_timestamp = self._calculate_timestamps( + metadata.frames_indices, + metadata.fps, + self.video_processor.merge_size, + ) + + video_placeholder = "" + frame_seqlen = video_grid_thw[index][1:].prod() // merge_length + for frame_idx in range(video_grid_thw[index][0]): + curr_time = curr_timestamp[frame_idx] + video_placeholder += f"<{curr_time:.1f} seconds>" + video_placeholder += ( + self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token + ) + if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]: + text[i] = text[i].replace( + f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1 + ) + else: + # vllm may input video token directly + text[i] = text[i].replace(self.video_token, video_placeholder, 1) + index += 1 + + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size)) + timestamps = [idx / video_fps for idx in indices] + # @JJJYmmm frames are merged by self.merge_size, \ + # so we need to average the timestamps between the first/last frame within the temporal patch + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) + ] + return timestamps + + +__all__ = [ + "Qwen3VLConfig", + "Qwen3VLTextConfig", + "Qwen3VLVisionModel", + "Qwen3VLForConditionalGeneration", + "Qwen3VLModel", + "Qwen3VLPreTrainedModel", + "Qwen3VLProcessor", + "Qwen3VLTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/processing_qwen3_vl.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/processing_qwen3_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..17bdd975eef337371aa763824edc7e6ec1987dbf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -0,0 +1,328 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False): + pass + + +class Qwen3VLImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen3VLImagesKwargs + videos_kwargs: Qwen3VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +class Qwen3VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen3VL processor which wraps a Qwen3VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Qwen3VLVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + attributes = ["image_processor", "tokenizer", "video_processor"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + self.vision_start_token = ( + "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token + ) + self.vision_start_token_id = ( + tokenizer.vision_start_token_id + if getattr(tokenizer, "vision_start_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_start_token) + ) + self.vision_end_token_id = ( + tokenizer.vision_end_token_id + if getattr(tokenizer, "vision_end_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_end_token) + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen3VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
+ + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen3VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + metadata = video_metadata[index] + if metadata.fps is None: + logger.warning_once( + "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+                        )
+                        metadata.fps = 24 if metadata.fps is None else metadata.fps
+
+                    # if timestamps are not provided, calculate them
+                    curr_timestamp = self._calculate_timestamps(
+                        metadata.frames_indices,
+                        metadata.fps,
+                        self.video_processor.merge_size,
+                    )
+
+                    video_placeholder = ""
+                    frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
+                    for frame_idx in range(video_grid_thw[index][0]):
+                        curr_time = curr_timestamp[frame_idx]
+                        video_placeholder += f"<{curr_time:.1f} seconds>"
+                        video_placeholder += (
+                            self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
+                        )
+                    if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
+                        text[i] = text[i].replace(
+                            f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
+                        )
+                    else:
+                        # vllm may input video token directly
+                        text[i] = text[i].replace(self.video_token, video_placeholder, 1)
+                    index += 1
+
+                text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+        if return_mm_token_type_ids:
+            array_ids = np.array(text_inputs["input_ids"])
+            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+            mm_token_type_ids[array_ids == self.image_token_id] = 1
+            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+        """
+        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+        Args:
+            image_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (height, width) per each image.
+            video_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (num_frames, height, width) per each video.
+        Returns:
+            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+            input modalities, along with other useful data.
+        """
+
+        vision_data = {}
+        if image_sizes is not None:
+            images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
+            images_kwargs.update(kwargs)
+            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+            num_image_patches = [
+                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+                for image_size in image_sizes
+            ]
+            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+        if video_sizes is not None:
+            videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
+            videos_kwargs.update(kwargs)
+            # `merge_size` is otherwise undefined when only videos are passed, so resolve it here
+            merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
+            num_video_patches = [
+                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+                for video_size in video_sizes
+            ]
+            num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+            vision_data["num_video_tokens"] = num_video_tokens
+
+        return MultiModalData(**vision_data)
+
+    def post_process_image_text_to_text(
+        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+    ):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+            skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+            **kwargs:
+                Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+        Returns:
+            `list[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(
+            generated_outputs,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
+        if not isinstance(indices, list):
+            indices = indices.tolist()
+        if len(indices) % merge_size != 0:
+            indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
+        timestamps = [idx / video_fps for idx in indices]
+        # @JJJYmmm frames are merged by self.merge_size,
+        # so we need to average the timestamps between the first/last frame within the temporal patch
+        timestamps = [
+            (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
+        ]
+        return timestamps
+
+
+__all__ = ["Qwen3VLProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/video_processing_qwen3_vl.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/video_processing_qwen3_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4648788c9dc9c9660b31c09d652d27a5d3773cd
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl/video_processing_qwen3_vl.py
@@ -0,0 +1,276 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""video processor class for Qwen3-VL.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size +from ...processing_utils import Unpack, VideosKwargs +from ...utils import TensorType, add_start_docstrings, logging +from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor +from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos + + +logger = logging.get_logger(__name__) + + +def smart_resize( + num_frames: int, + height: int, + width: int, + temporal_factor: int = 2, + factor: int = 32, + min_pixels: int = 128 * 128, + max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144, +): + if num_frames < temporal_factor: + raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}") + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + t_bar = round(num_frames / temporal_factor) * temporal_factor + + if t_bar * h_bar * w_bar > max_pixels: + beta = math.sqrt((num_frames * height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif t_bar * h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (num_frames * height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + return h_bar, w_bar + + +class Qwen3VLVideoProcessorInitKwargs(VideosKwargs): + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + min_frames: Optional[int] + max_frames: Optional[int] + + +@add_start_docstrings( + "Constructs a fast Qwen3-VL image processor that dynamically resizes videos based on the original videos.", + BASE_VIDEO_PROCESSOR_DOCSTRING, + """ + patch_size (`int`, *optional*, defaults to 16): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. 
+    """,
+)
+class Qwen3VLVideoProcessor(BaseVideoProcessor):
+    resample = PILImageResampling.BICUBIC
+    size = {"shortest_edge": 128 * 32 * 32, "longest_edge": 32 * 32 * 768}
+    image_mean = [0.5, 0.5, 0.5]
+    image_std = [0.5, 0.5, 0.5]
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    do_convert_rgb = True
+    patch_size = 16
+    temporal_patch_size = 2
+    merge_size = 2
+    fps = 2
+    min_frames = 4
+    max_frames = 768
+    do_sample_frames = True
+    valid_kwargs = Qwen3VLVideoProcessorInitKwargs
+    model_input_names = ["pixel_values_videos", "video_grid_thw"]
+
+    def __init__(self, **kwargs: Unpack[Qwen3VLVideoProcessorInitKwargs]):
+        super().__init__(**kwargs)
+        if self.size is not None and (
+            self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None
+        ):
+            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+    def _further_process_kwargs(
+        self,
+        size: Optional[SizeDict] = None,
+        **kwargs,
+    ) -> dict:
+        """
+        Update kwargs that need further processing before being validated.
+        Can be overridden by subclasses to customize the processing of kwargs.
+        """
+        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+        return super()._further_process_kwargs(size=size, **kwargs)
+
+    def sample_frames(
+        self,
+        metadata: VideoMetadata,
+        num_frames: Optional[int] = None,
+        fps: Optional[Union[int, float]] = None,
+        **kwargs,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            metadata (`VideoMetadata`):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int` or `float`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+        Returns:
+            np.ndarray:
+                Indices of the frames to sample.
+        """
+        if fps is not None and num_frames is not None:
+            raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
+        total_num_frames = metadata.total_num_frames
+        fps = fps if fps is not None else self.fps
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is None and fps is not None:
+            if metadata.fps is None:
+                metadata.fps = 24
+                logger.warning_once(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
+                )
+            num_frames = int(total_num_frames / metadata.fps * fps)
+            num_frames = min(min(max(num_frames, self.min_frames), self.max_frames), total_num_frames)
+
+        if num_frames is None:
+            num_frames = min(max(total_num_frames, self.min_frames), self.max_frames)
+
+        indices = np.linspace(0, total_num_frames - 1, num_frames).round().astype(int)
+
+        return indices
+
+    def _preprocess(
+        self,
+        videos: list[torch.Tensor],
+        do_convert_rgb: bool = True,
+        do_resize: bool = True,
+        size: Optional[SizeDict] = None,
+        interpolation: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255.0,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        patch_size: Optional[int] = None,
+        temporal_patch_size: Optional[int] = None,
+        merge_size: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ):
+        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+        resized_videos_grouped = {}
+
+        for shape, stacked_videos in grouped_videos.items():
+            B, T, C, H, W = stacked_videos.shape
+            num_frames, height, width = T, H, W
+            if do_resize:
+                resized_height, resized_width = smart_resize(
+                    num_frames=num_frames,
+                    height=height,
+                    width=width,
+                    temporal_factor=temporal_patch_size,
+                    factor=patch_size * merge_size,
+                    min_pixels=size.shortest_edge,
+                    max_pixels=size.longest_edge,
+                )
+                stacked_videos = stacked_videos.view(B * T, C, H, W)
+                stacked_videos = self.resize(
+                    stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
+                )
+                stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
+            resized_videos_grouped[shape] = stacked_videos
+        resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+        # Group videos by size for further processing
+        # Needed in case do_resize is False, or resize returns videos with different sizes
+        grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+        processed_videos_grouped = {}
+        processed_grids = {}
+        for shape, stacked_videos in grouped_videos.items():
+            resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
+
+            # Fused rescale and normalize
+            stacked_videos = self.rescale_and_normalize(
+                stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            patches = stacked_videos
+
+            # Pad along the temporal dimension so `num_frames` is divisible by `temporal_patch_size`
+            if patches.shape[1] % temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+                patches = torch.cat([patches, repeats], dim=1)
+            batch_size, grid_t, channel = patches.shape[:3]
+            grid_t = grid_t // temporal_patch_size
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+            patches = patches.view(
+                batch_size,
+                grid_t,
+                temporal_patch_size,
+                channel,
+                grid_h // merge_size,
+                merge_size,
+                patch_size,
+                grid_w // merge_size,
+                merge_size,
+                patch_size,
+            )
+            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+            flatten_patches = patches.reshape(
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * temporal_patch_size * patch_size * patch_size,
+            )
+
+            processed_videos_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+        processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+        processed_grids =
reorder_videos(processed_grids, grouped_videos_index) + pixel_values_videos = torch.cat(processed_videos, dim=0) + video_grid_thw = torch.tensor(processed_grids) + data = { + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": video_grid_thw, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Qwen3VLVideoProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4000cb272723d9920136a6c78465e8413a8b4d1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_vl_moe import * + from .modeling_qwen3_vl_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..25358aa79bff482437632829f6319effad36f138 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py @@ -0,0 +1,335 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Qwen3VLMoeTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
+    Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen3VLMoe model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Qwen3VLMoeTextModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 16):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 128000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 5000000.0):
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
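+            For example, with the default `decoder_sparse_step=1` every decoder layer is a MoE layer,
+            while `decoder_sparse_step=2` would make only every second layer sparse.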
+        moe_intermediate_size (`int`, *optional*, defaults to 1408):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 4):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 60):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.
+            The list contains layer indices from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        head_dim (`int`, *optional*):
+            The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3VLMoe style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen3VLMoe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + attention_bias=False, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=True, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + rope_scaling=None, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.head_dim = head_dim or hidden_size // num_attention_heads + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(PretrainedConfig): + model_type = "qwen3_vl_moe" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + 
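+        # `deepstack_visual_indexes` picks the vision blocks (here 8 / 16 / 24 out of
+        # `depth=27`) whose hidden states are tapped and injected into the early
+        # decoder layers (see the DeepStack logic in the modeling file).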
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.out_hidden_size = out_hidden_size
+        self.num_position_embeddings = num_position_embeddings
+        self.initializer_range = initializer_range
+        self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+class Qwen3VLMoeConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
+    Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+        vision_start_token_id (`int`, *optional*, defaults to 151652):
+            The start token index to encode the vision prompt.
+        vision_end_token_id (`int`, *optional*, defaults to 151653):
+            The end token index to encode the vision prompt.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the word embeddings.
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3-VL-MOE style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe" + sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +__all__ = ["Qwen3VLMoeConfig", "Qwen3VLMoeTextConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..519e888d4820780ab2b41ae65068a9a1fec7ecb5 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -0,0 +1,1832 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
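+# Routing sketch (illustrative, assuming the default top_k=4 over num_experts=60):
+# `Qwen3VLMoeTextSparseMoeBlock` below computes, per token,
+#     p = softmax(gate(h))             # (num_tokens, num_experts)
+#     w, idx = topk(p, top_k)          # keep the top_k largest probabilities
+#     w = w / w.sum(-1, keepdim=True)  # renormalize the kept weights to sum to 1
+# and scatters `w` back into a dense (num_tokens, num_experts) matrix that
+# `Qwen3VLMoeTextExperts` consumes to weight each expert's output.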
+
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3VLMoeTextRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen3VLMoeTextRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3VLMoeTextExperts(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.intermediate_size = config.moe_intermediate_size
+        self.hidden_size = config.hidden_size
+        self.expert_dim = self.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(
+        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        When training it is more efficient to just loop over the experts and compute the output for each expert
+        as otherwise the memory would explode.
+
+        For inference we can sacrifice some memory and compute the output for all experts at once, by repeating the inputs.
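+        For example, with `num_experts=60` and 1024 tokens in the batch, the inference path
+        materializes a `(60, 1024, hidden_size)` activation tensor.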
+
+        Args:
+            hidden_states (torch.Tensor): (batch_size, token_num, hidden_size)
+            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+            router_indices (torch.Tensor): (batch_size * token_num, top_k)
+        Returns:
+            torch.Tensor: (batch_size, token_num, hidden_size)
+        """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        if self.training:
+            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+            with torch.no_grad():
+                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+                expert_mask = expert_mask.permute(2, 1, 0)
+                # we sum on the top_k and on the sequence length to get which experts
+                # are hit this time around
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit[:]:
+                with torch.no_grad():
+                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
+                current_state = hidden_states[token_idx]
+                gate_up = current_state @ self.gate_up_proj[expert_idx]
+                gate, up = gate_up.chunk(2, dim=-1)
+                gated_output = up * self.act_fn(gate)
+                out = gated_output @ self.down_proj[expert_idx]
+                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
+        else:
+            hidden_states = hidden_states.repeat(self.num_experts, 1)
+            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+            gate_up = torch.bmm(hidden_states, self.gate_up_proj)
+            gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
+            next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
+            next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
+            next_states = (
+                next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
+            )
+            next_states = next_states.sum(dim=0)
+        return next_states
+
+
+class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.experts = Qwen3VLMoeTextExperts(config)
+
+        # since all the models use norm_topk_prob, we don't need to have an extra check for it
+        # self.norm_topk_prob = config.norm_topk_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        router_logits = self.gate(hidden_states)
+        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
+        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+        routed_out = self.experts(hidden_states, router_weights, router_indices)
+        return routed_out, router_logits
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLMoeTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3VLMoeTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3VLMoeTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLMoeTextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else 
config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3VLMoeTextAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3VLMoeTextSparseMoeBlock(config) + else: + self.mlp = Qwen3VLMoeTextMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class Qwen3VLMoePreTrainedModel(PreTrainedModel): + config: Qwen3VLMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(Qwen3VLMoeTextSparseMoeBlock, index=1), + "hidden_states": Qwen3VLMoeTextDecoderLayer, + "attentions": Qwen3VLMoeTextAttention, + } + + def _init_weights(self, module): + """Initialize the weights.""" + super()._init_weights(module) + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + if isinstance(module, Qwen3VLMoeTextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + + +class Qwen3VLMoeVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLMoeVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLMoeVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() 
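+        # Standard RoPE inverse frequencies: theta ** (-2i / dim) for i in [0, dim / 2).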
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3VLMoeVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen3VLMoeVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLMoeVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + 
max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLMoeVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLMoeVisionAttention(config=config) + self.mlp = Qwen3VLMoeVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel): + config: Qwen3VLMoeVisionConfig + _no_split_modules = ["Qwen3VLMoeVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLMoeVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLMoeVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLMoeVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLMoeVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = 
torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
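+            `list[torch.Tensor]`: deepstack features taken after the blocks listed in
+                `config.deepstack_visual_indexes`, later injected into the early decoder layers.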
+        """
+        hidden_states = self.patch_embed(hidden_states)
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        seq_len, _ = hidden_states.size()
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852 for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+                    hidden_states
+                )
+                deepstack_feature_lists.append(deepstack_feature)
+
+        hidden_states = self.merger(hidden_states)
+
+        return hidden_states, deepstack_feature_lists
+
+
+class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Qwen3VLMoeTextConfig, device=None):
+        super().__init__()
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", "default")
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+        self.mrope_section = (config.rope_scaling or {}).get("mrope_section", [24, 20, 20])
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        Args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        Returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        # In contrast to other models, Qwen3VLMoe has different position ids for the grids
+        # So we expand the inv_freq to shape (3, ...)
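+        # Illustrative note: with the default mrope_section=[24, 20, 20] (64 slots),
+        # `apply_interleaved_mrope` lays the frequencies out as [T H W] repeated 20
+        # times, with the trailing 4 slots left as T.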
+ if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VLMoe, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLMoeTextModel(Qwen3VLMoePreTrainedModel): + config: Qwen3VLMoeTextConfig + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer"] + + def __init__(self, config: Qwen3VLMoeTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3VLMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
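+            During the forward pass each entry is added onto the hidden states at the positions
+            selected by `visual_pos_masks` after the corresponding early decoder layer (see
+            `_deepstack_process`).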
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3VLMoe causal language model (or autoregressive) outputs. + """ +) +class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3VLMoeModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLMoeConfig + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLMoeVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLMoeTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VLMoe use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not 
None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + 
torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
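+        Both returned masks are boolean and already expanded to the shape of `inputs_embeds`,
+        so they can be passed directly to `masked_scatter`.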
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLMoeModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+ # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + return Qwen3VLMoeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
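+        Concretely, the function returns `num_experts * sum(tokens_per_expert * router_prob_per_expert)`,
+        which grows when experts that already receive a large share of the routed tokens also get a
+        high average router probability.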
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLMoeConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3VLMoeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLMoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." 
+ ```""" + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return Qwen3VLMoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VLMoe position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, 
repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3VLMoeVisionModel", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3VLMoeModel", + "Qwen3VLMoePreTrainedModel", + "Qwen3VLMoeTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..72d3452bdc50a244b6cf3fdd9bd76dda5b2fc0f4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -0,0 +1,551 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3-VL-MOE model.""" + +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ..qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeDecoderLayer, + Qwen3MoePreTrainedModel, + Qwen3MoeRMSNorm, + load_balancing_loss_func, +) +from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLCausalLMOutputWithPast, + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLTextAttention, + Qwen3VLTextModel, + Qwen3VLVisionModel, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3VLMoeTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of the routed expert. + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 60): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. 
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + head_dim (`int`, *optional*): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. 
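+        For multimodal RoPE the `rope_scaling` dict may additionally carry `mrope_section` and
+        `mrope_interleaved`; both keys are excluded from `rope_config_validation`'s checks, and the
+        text rotary embedding falls back to `mrope_section=[24, 20, 20]` when the section is not provided.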
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3VLMoe style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen3VLMoe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + attention_bias=False, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=True, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + rope_scaling=None, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.head_dim = head_dim or hidden_size // num_attention_heads + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(Qwen3VLVisionConfig): + pass + + +class Qwen3VLMoeConfig(Qwen3VLConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3-VL-MOE style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe" + sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig} + + +class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm): + pass + + +class Qwen3VLMoeTextExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.moe_intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + ) -> torch.Tensor: + """ + When training it is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. 
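+        In that inference path the tokens are tiled `num_experts` times, the expert projections are
+        applied with batched matmuls (`torch.bmm`), and the per-expert outputs are weighted by the
+        dense routing weights and summed.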
+ + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + router_indices (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit[:]: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + gate, up = gate_up.chunk(2, dim=-1) + gated_output = up * self.act_fn(gate) + out = gated_output @ self.down_proj[expert_idx] + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) + next_states = next_states.sum(dim=0) + return next_states + + +class Qwen3VLMoeTextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = Qwen3VLMoeTextExperts(config) + + # since all the models use norm_topk_prob, we don't need to have a extra check for it + # self.norm_topk_prob = config.norm_topk_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size) + routed_out = self.experts(hidden_states, router_weights, router_indices) + return routed_out, router_logits + + +class Qwen3VLMoeTextAttention(Qwen3VLTextAttention): + pass + + +class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer): + pass + + +class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel): + config: Qwen3VLMoeConfig + _no_split_modules = 
["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + + def _init_weights(self, module): + """Initialize the weights.""" + PreTrainedModel._init_weights(self, module) + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + if isinstance(module, Qwen3VLMoeTextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + + +class Qwen3VLMoeVisionModel(Qwen3VLVisionModel): + pass + + +class Qwen3VLMoeTextModel(Qwen3VLTextModel): + pass + + +class Qwen3VLMoeCausalLMOutputWithPast(Qwen3VLCausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + + +class Qwen3VLMoeModel(Qwen3VLModel): + pass + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. 
The ocean waves roll in the background." + ```""" + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return Qwen3VLMoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + rope_deltas=outputs.rope_deltas, + ) + + +__all__ = [ + "Qwen3VLMoeConfig", + "Qwen3VLMoeTextConfig", + "Qwen3VLMoeVisionModel", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3VLMoeModel", + "Qwen3VLMoePreTrainedModel", + "Qwen3VLMoeTextModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9335fc4beff8b178336ccb546bdecb4cd45171 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
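Editor's note: to make the routing math in `Qwen3VLMoeTextSparseMoeBlock` above concrete, here is a minimal, self-contained sketch of the softmax → top-k → renormalize → scatter pipeline that produces the dense `router_weights` matrix consumed by the experts. The sizes and the randomly initialized gate are hypothetical, for illustration only (not the shipped configuration):

```python
import torch

# Hypothetical toy sizes for illustration only.
num_tokens, hidden_size, num_experts, top_k = 4, 8, 6, 2

hidden_states = torch.randn(num_tokens, hidden_size)
gate = torch.nn.Linear(hidden_size, num_experts, bias=False)

router_logits = gate(hidden_states)
# Softmax over all experts, keep the k largest weights per token ...
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, router_indices = torch.topk(routing_weights, top_k, dim=-1)
# ... renormalize the kept weights so each token's weights sum to 1 ...
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
# ... and scatter them back into a dense (num_tokens, num_experts) matrix,
# which is zero everywhere except at each token's top-k expert columns.
router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)

assert torch.allclose(router_weights.sum(dim=-1), torch.ones(num_tokens))
```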
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_recurrent_gemma import *
+    from .modeling_recurrent_gemma import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef2a08699123a5f6e6ef8d226f299904fdb7d445
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RecurrentGemma model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RecurrentGemmaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`RecurrentGemmaModel`]. It is used to instantiate
+    a RecurrentGemma model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the RecurrentGemma-2B.
+
+    e.g. [google/recurrentgemma-2b](https://huggingface.co/google/recurrentgemma-2b)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        num_hidden_layers (`int`, *optional*, defaults to 26):
+            The number of hidden layers in the model.
+        vocab_size (`int`, *optional*, defaults to 256000):
+            Vocabulary size of the RecurrentGemma model. Defines the number of
+            different tokens that can be represented by the
+            `inputs_ids` passed when calling [`RecurrentGemmaModel`].
+        hidden_size (`int`, *optional*, defaults to 2560):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 7680):
+            Dimension of the MLP representations.
+        num_attention_heads (`int`, *optional*, defaults to 10):
+            The number of heads for the attention block and the number of
+            heads/blocks for the block-diagonal layers used in the RG-LRU gates.
+            This number must divide `hidden_size` and `lru_width`.
+        lru_width (`int` or `None`, *optional*):
+            Dimension of the hidden representations of the RG-LRU. If `None`
+            this will be set to `hidden_size`.
+        attention_window_size (`int`, *optional*, defaults to 2048):
+            The size of the attention window used in the attention block.
+        conv1d_width (`int`, *optional*, defaults to 4):
+            The kernel size of conv1d layers used in the recurrent blocks.
+        logits_soft_cap (`float`, *optional*, defaults to 30.0):
+            The value at which the logits are soft-capped after the transformer and LM-head computation in the Causal
+            LM architecture.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The hidden activation used in the recurrent block as well as the MLP layer of the decoder layers.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            The partial rotary factor used in the initialization of the rotary embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        block_types (`list[str]`, *optional*, defaults to `('recurrent', 'recurrent', 'attention')`):
+            List of alternating block types that will be repeated to initialize the `temporal_block` layer.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout value to use after the attention softmax.
+        num_key_value_heads (`int`, *optional*, defaults to `num_attention_heads`):
+            Number of key-value heads used for grouped-query attention (GQA).
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not the linear q, k, v projections of the attention layer should have bias.
+        w_init_variance_scale (`float`, *optional*, defaults to 0.01):
+            Weight initialization variance.
+ ```python + >>> from transformers import RecurrentGemmaModel, RecurrentGemmaConfig + + >>> # Initializing a RecurrentGemma recurrentgemma-2b style configuration + >>> configuration = RecurrentGemmaConfig() + + >>> # Initializing a model from the recurrentgemma-2b style configuration + >>> model = RecurrentGemmaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "recurrent_gemma" + + def __init__( + self, + num_hidden_layers=26, + vocab_size=256000, + hidden_size=2560, + intermediate_size=3 * 2560, + num_attention_heads=10, + lru_width=None, + attention_window_size=2048, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + hidden_activation="gelu_pytorch_tanh", + partial_rotary_factor=0.5, + rope_theta=10000.0, + block_types=("recurrent", "recurrent", "attention"), + attention_dropout=0.0, + num_key_value_heads=None, + attention_bias=False, + w_init_variance_scale=0.01, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] + + +__all__ = ["RecurrentGemmaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..88364515459af62e3bdc74d41c9c1dfe97c88ba2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -0,0 +1,785 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
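Editor's note: a quick, standalone illustration of the `layers_block_type` property defined above. The `block_types` pattern is simply tiled and truncated to `num_hidden_layers`; the values below are arbitrary example values, not a real checkpoint configuration:

```python
# Standalone sketch of RecurrentGemmaConfig.layers_block_type (arbitrary example values).
block_types = ["recurrent", "recurrent", "attention"]
num_hidden_layers = 8

layers_block_type = (block_types * 100)[:num_hidden_layers]
print(layers_block_type)
# ['recurrent', 'recurrent', 'attention', 'recurrent', 'recurrent', 'attention', 'recurrent', 'recurrent']
```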
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch RecurrentGemma model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.import_utils import is_torchdynamo_compiling +from .configuration_recurrent_gemma import RecurrentGemmaConfig + + +logger = logging.get_logger(__name__) +_MAX_SQRT_GRADIENT = 1000.0 + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->RecurrentGemma +class RecurrentGemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst RecurrentGemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class RecurrentGemmaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim, base=10000, device=None): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RecurrentGemmaSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: RecurrentGemmaConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.partial_rotary_factor = config.partial_rotary_factor + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = RecurrentGemmaRotaryEmbedding( + int(self.partial_rotary_factor * self.head_dim), + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, 
self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + + # Partial rotary embedding + query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1) + key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if use_cache and hasattr(self, "key_states"): + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=causal_mask, # pretty much a must for sliding window backend! + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _setup_cache(self, batch_size, device, dtype=None): + if dtype is None and self.config.dtype is not None: + dtype = self.config.dtype + dtype = dtype if dtype is not None else torch.float32 + cache_shape = (batch_size, self.num_key_value_heads, self.config.attention_window_size, self.head_dim) + self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) + self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + """ + torch.compile compatible sliding window. + Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`. + The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`) + """ + cache_position = cache_kwargs.get("cache_position") + if cache_position.shape[0] > self.config.attention_window_size: + # int indexing -> device sync? 
in compile, use tensor + k_out = key_states[:, :, -self.config.attention_window_size :, :] + v_out = value_states[:, :, -self.config.attention_window_size :, :] + else: + slicing = torch.ones( + self.config.attention_window_size, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.config.attention_window_size - 1) + to_shift = cache_position >= self.config.attention_window_size - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size + + k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states.to(k_out.dtype) + v_out[:, :, cache_position] = value_states.to(v_out.dtype) + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + +class SqrtBoundDerivative(torch.autograd.Function): + """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`.""" + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + """The forward pass, which is a normal `sqrt`.""" + ctx.save_for_backward(x) + return torch.sqrt(x) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """The backward pass, which clips the `sqrt` gradient.""" + (x,) = ctx.saved_tensors + clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2)) + return grad_output / torch.sqrt(clipped_x_times_4) + + +class RecurrentGemmaRglru(nn.Module): + """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" + + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.block_width = config.lru_width // self.num_attention_heads + + self.recurrent_param = nn.Parameter(torch.empty([config.lru_width])) + self.input_gate_weight = nn.Parameter( + torch.empty([self.num_attention_heads, self.block_width, self.block_width]) + ) + self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) + + self.recurrent_gate_weight = nn.Parameter( + torch.empty([self.num_attention_heads, self.block_width, self.block_width]) + ) + self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) + self.recurrent_states = None + + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, lru_width = activations.shape + reset = position_ids[:, :, None] == 0 + + reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width) + reshape_act = reshape_act.permute(1, 0, 2) + + res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight) + input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) + + res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight) + recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) + + # Compute the parameter `A` of the recurrence. + log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param) + recurrent_gate = torch.exp(log_recurrent_gate) + a_square = torch.exp(2 * log_recurrent_gate) + + # Gate the input. + gated_inputs = activations * input_gate + + # Apply gamma normalization to the input. We need to clip the derivatives of + # `sqrt` in order to prevent NaNs during training in bfloat16. 
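+        # (Editor's note on SqrtBoundDerivative above, for clarity: the forward pass is a
+        # plain sqrt(x); the backward pass returns grad / sqrt(clip(4x, min=1/_MAX_SQRT_GRADIENT**2)),
+        # i.e. the usual 1/(2*sqrt(x)) factor, but capped at _MAX_SQRT_GRADIENT = 1000 as
+        # x approaches 0, which is what keeps bfloat16 training free of NaNs.)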
TODO a bit annoying + multiplier = 1 + tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling() + if not torch.jit.is_tracing() and not tracing: + multiplier = SqrtBoundDerivative.apply(1 - a_square) + multiplier = reset + ~reset * multiplier + normalized_x = gated_inputs * multiplier.type(activations.dtype) + + hidden_states, recurrent_states = self._rnn_scan( + hidden_states=normalized_x, + recurrent_gate=recurrent_gate, + reset=reset, + recurrent_states=self.recurrent_states, + ) + self.recurrent_states = recurrent_states + return hidden_states + + # TODO refactor + def _rnn_scan( + self, + hidden_states: torch.Tensor, + recurrent_gate: torch.Tensor, + reset: torch.Tensor, + recurrent_states: Union[torch.Tensor, None], + acc_dtype: torch.dtype = torch.float32, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Runs the recurrence of a linear RNN. + + Args: + hidden_states: The input sequence. + recurrent_gate: The diagonal of the recurrence matrix `A`. + reset: Indicator of document boundaries, e.g. when to reset the hidden state + of the RNN. + recurrent_states: The initial hidden state. + acc_dtype: The data type for the accumulation. + + Returns: + The output of the linear recurrence. + """ + # Multiply `a` by the reset. + recurrent_gate = recurrent_gate * ~reset + + if hidden_states.shape[1] == 1: + # Using scan in sampling mode. + if recurrent_states is None: # same here, when decoding you always have cache + return hidden_states, hidden_states[:, 0].type(acc_dtype) + + else: + contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to( + recurrent_gate.device + ) + contextualized_states += hidden_states.type(acc_dtype) + return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1] + + else: + # Using scan in linear mode. 
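+            # (Editor's note: the loop below implements the diagonal linear recurrence
+            # h_t = a_t * h_{t-1} + x_t, with the accumulator kept in float32.)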
+ if recurrent_states is None: + recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device) + + contextualized_states = torch.zeros_like(hidden_states) + for t in range(hidden_states.shape[1]): + recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device) + recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype) + contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype) + + return contextualized_states, recurrent_states + + +class RecurrentGemmaRecurrentBlock(nn.Module): + """Griffin and Hawk's recurrent block.""" + + def __init__(self, config): + super().__init__() + self.lru_width = config.lru_width + self.hidden_size = config.hidden_size + self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) + self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) + self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size) + self.conv1d_width = config.conv1d_width + self.conv_1d = nn.Conv1d( + config.lru_width, + config.lru_width, + kernel_size=config.conv1d_width, + groups=config.lru_width, + padding=config.conv1d_width - 1, + ) + self.rg_lru = RecurrentGemmaRglru(config) + self.act_fn = ACT2FN[config.hidden_activation] + + self.conv1d_state = None + + def forward( + self, + input_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + cache_position: torch.Tensor, + use_cache: bool = True, + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + _, seq_len, _ = input_states.shape + + y_branch = self.linear_y(input_states) + y_branch = self.act_fn(y_branch) + + x_branch = self.linear_x(input_states) + x_branch = x_branch.transpose(1, 2) + + if use_cache: + if cache_position.shape[0] != 1: # prefill + self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0)) + x_branch = self.conv_1d(x_branch)[..., :seq_len] + else: # decoding + conv_state = torch.cat((self.conv1d_state, x_branch), -1) + x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias + x_branch = x_branch.unsqueeze(-1) + self.conv1d_state = conv_state[:, :, 1:] + else: + x_branch = self.conv_1d(x_branch)[..., :seq_len] + + x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids) + + hidden_states = x_branch * y_branch + hidden_states = self.linear_out(hidden_states) + return hidden_states + + def _setup_cache(self, batch, device, dtype): + # recurrent_states always computed in full precision + self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32) + self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype) + + +TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention} + + +class RecurrentGemmaMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, hidden_states): + gate = self.act_fn(self.gate_proj(hidden_states)) + return 
self.down_proj(gate * self.up_proj(hidden_states))
+
+
+class RecurrentGemmaDecoderLayer(GradientCheckpointingLayer):
+    """Griffin and Hawk's residual block."""
+
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.temporal_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.temporal_block = TEMPORAL_BLOCK_CLASSES[config.layers_block_type[layer_idx]](config)
+        self.channel_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp_block = RecurrentGemmaMlp(config)
+
+    def forward(
+        self,
+        activations: torch.Tensor,
+        position_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        raw_activations = activations
+        inputs_normalized = self.temporal_pre_norm(raw_activations)  # RMSNorm introduces slight differences
+
+        hidden_states = self.temporal_block(
+            inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache
+        )
+
+        residual = hidden_states + raw_activations
+
+        hidden_states = self.channel_pre_norm(residual)
+        hidden_states = self.mlp_block(hidden_states)
+
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+@auto_docstring
+class RecurrentGemmaPreTrainedModel(PreTrainedModel):
+    config: RecurrentGemmaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["RecurrentGemmaDecoderLayer"]
+    _skip_keys_device_placement = ["cache"]
+    _supports_flash_attn = False
+    _supports_sdpa = False  # we can't compare with eager for now
+
+    def _init_weights(self, module):
+        std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
+        if isinstance(module, nn.Conv1d):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, RecurrentGemmaSdpaAttention):
+            torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+            torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+            torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+            std = math.sqrt(self.config.final_w_init_variance_scale / self.config.hidden_size)
+            torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=std)
+        elif isinstance(module, RecurrentGemmaRecurrentBlock):
+            torch.nn.init.zeros_(module.linear_x.bias)
+            torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+            torch.nn.init.zeros_(module.linear_y.bias)
+            torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+            std = math.sqrt(self.config.final_w_init_variance_scale / self.config.lru_width)
+            torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std)
+            torch.nn.init.zeros_(module.linear_out.bias)
+        elif isinstance(module, RecurrentGemmaRglru):
+            std = math.sqrt(
+                self.config.w_init_variance_scale / (self.config.lru_width // self.config.num_attention_heads)
+            )
+            torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std)
+            torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std)
+            torch.nn.init.zeros_(module.input_gate_bias)
+            torch.nn.init.zeros_(module.recurrent_gate_bias)
+
+            module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8)
+            module.recurrent_param.data.log_().mul_(0.5)
+
module.recurrent_param.data.neg_().exp_().sub_(1.0).log_() + elif isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, RecurrentGemmaRMSNorm): + module.weight.data.zero_() + + def _setup_cache(self, config, batch, device, dtype): + layers = getattr(self, "model", self).layers + for layer in layers: + layer.temporal_block._setup_cache(batch, device, dtype) + + def reset_cache(self, batch, device, dtype): + pass + + +@auto_docstring +class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel): + def __init__(self, config: RecurrentGemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RecurrentGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.register_buffer( + "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False + ) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if use_cache and inputs_embeds.shape[1] != 1: # TODO let's maybe only call in the `generate`? 
+ self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + hidden_states = hidden_states * self.normalizer.type(hidden_states.dtype) + + all_hidden_states = () if output_hidden_states else None + for i, residual_block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) + + hidden_states = self.final_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + # Ignore copy + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = max(self.config.attention_window_size, sequence_length) + + diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = diagonal + if sequence_length != 1: + causal_mask = torch.triu(diagonal, diagonal=-1) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + # Crop the attention mask to the target length. + attention_mask = attention_mask[:, -target_length:] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"]: + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma +@auto_docstring +class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RecurrentGemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, RecurrentGemmaForCausalLM + + >>> model = RecurrentGemmaForCausalLM.from_pretrained("google/recurrentgemma-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + cache_position=cache_position, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + # Soft-cap the logits TODO remove if always done. 
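+        # (Editor's note: tanh soft-capping bounds the logits to (-cap, cap) while staying
+        # smooth and monotonic: logits = cap * tanh(logits / cap). Values far below the cap
+        # pass through almost unchanged; large values saturate at +/- cap.)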
+ # if self.config.logits_soft_cap is not None: + cap = self.config.logits_soft_cap + logits = nn.functional.tanh(logits / cap) * cap + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = ["RecurrentGemmaForCausalLM", "RecurrentGemmaModel", "RecurrentGemmaPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/reformer/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/reformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6721b6ac2367847239ab69d9b36133e6f87967 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/reformer/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_reformer import * + from .modeling_reformer import * + from .tokenization_reformer import * + from .tokenization_reformer_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/reformer/configuration_reformer.py b/venv/lib/python3.13/site-packages/transformers/models/reformer/configuration_reformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f7734b9f3fc042a7d42161c4f214ee296b6acb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/reformer/configuration_reformer.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
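Editor's note: both `recurrent_gemma/__init__.py` and `reformer/__init__.py` above rely on the same `_LazyModule` trick. Below is a rough, generic sketch of the idea only, under the assumption that attribute access should trigger the real import; the actual `_LazyModule` in `transformers.utils` is more elaborate (it registers itself in `sys.modules` and maps attributes to submodules eagerly via `define_import_structure`):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for transformers' _LazyModule: submodules are only
    imported when one of their attributes is first accessed."""

    def __init__(self, name, submodules):
        super().__init__(name)
        self._submodules = list(submodules)

    def __getattr__(self, attr):
        # Import submodules lazily and search them for the requested attribute.
        for sub in self._submodules:
            module = importlib.import_module(f"{self.__name__}.{sub}")
            if hasattr(module, attr):
                return getattr(module, attr)
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
```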
+"""Reformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ReformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ReformerModel`]. It is used to instantiate a + Reformer model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ReFormer + [google/reformer-crime-and-punishment](https://huggingface.co/google/reformer-crime-and-punishment) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attention_head_size (`int`, *optional*, defaults to 64): + Dimensionality of the projected key, query and value vectors + attn_layers (`list[str]`, *optional*, defaults to `["local", "lsh", "local", "lsh", "local", "lsh"]`): + List of attention layer types in ascending order. It can be chosen between a LSHSelfAttention layer + (`"lsh"`) and a LocalSelfAttention layer (`"local"`). + + For more information on LSHSelfAttention layer, see [LSH Self Attention](reformer#lsh-self-attention). For + more information on LocalSelfAttention layer, see [Local Self Attention](reformer#local-self-attention). + axial_pos_embds (`bool`, *optional*, defaults to `True`): + Whether or not to use axial position embeddings. For more information on how axial position embeddings + work, see [Axial Position Encodings](reformer#axial-positional-encodings). + axial_norm_std (`float`, *optional*, defaults to 1.0): + The standard deviation of the normal_initializer for initializing the weight matrices of the axial + positional encodings. + axial_pos_shape (`list[int]`, *optional*, defaults to `[64, 64]`): + The position dims of the axial position encodings. During training, the product of the position dims has to + be equal to the sequence length. + + For more information on how axial position embeddings work, see [Axial Position + Encodings](reformer#axial-positional-encodings). + axial_pos_embds_dim (`list[int]`, *optional*, defaults to `[64, 192]`): + The embedding dims of the axial position encodings. The sum of the embedding dims has to be equal to the + hidden size. + + For more information on how axial position embeddings work, see [Axial Position + Encodings](reformer#axial-positional-encodings). + chunk_size_lm_head (`int`, *optional*, defaults to 0): + The chunk size of the final language model feed forward head layer. A chunk size of 0 means that the feed + forward layer is not chunked. A chunk size of n means that the feed forward layer processes n < + sequence_length embeddings at a time. + + For more information on feed forward chunking, see [How does Feed Forward Chunking + work?](../glossary#feed-forward-chunking). + eos_token_id (`int`, *optional*, defaults to 2): + The token id for the end-of-sentence token. + feed_forward_size (`int`, *optional*, defaults to 512): + Dimensionality of the feed_forward layer in the residual attention block. + hash_seed (`int`, *optional*): + Seed that can be used to make local sensitive hashing in `LSHSelfAttention` deterministic. This should only + be set for testing purposed. For evaluation and training purposes `hash_seed` should be left as `None` to + ensure fully random rotations in local sensitive hashing scheme. 
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the feed forward layer in the residual attention
+            block. If string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.05):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        hidden_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the output hidden states of the residual attention blocks.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a causal mask in addition to the `attention_mask` passed to [`ReformerModel`]. When
+            using the Reformer for causal language modeling, this argument should be set to `True`.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        local_attn_chunk_length (`int`, *optional*, defaults to 64):
+            Length of chunk which attends to itself in `LocalSelfAttention`. Chunking reduces memory complexity from
+            sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk
+            length (chunked self attention).
+        local_num_chunks_before (`int`, *optional*, defaults to 1):
+            Number of previous neighbouring chunks a chunk in `LocalSelfAttention` attends to, in addition to itself.
+        local_num_chunks_after (`int`, *optional*, defaults to 0):
+            Number of following neighbouring chunks a chunk in `LocalSelfAttention` attends to, in addition to itself.
+        local_attention_probs_dropout_prob (`float`, *optional*, defaults to 0.05):
+            The dropout ratio for the attention probabilities in `LocalSelfAttention`.
+        lsh_attn_chunk_length (`int`, *optional*, defaults to 64):
+            Length of chunk which attends to itself in `LSHSelfAttention`. Chunking reduces memory complexity from
+            sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk
+            length (chunked self attention).
+        lsh_num_chunks_before (`int`, *optional*, defaults to 1):
+            Number of previous neighbouring chunks a chunk in `LSHSelfAttention` attends to, in addition to itself.
+        lsh_num_chunks_after (`int`, *optional*, defaults to 0):
+            Number of following neighbouring chunks a chunk in `LSHSelfAttention` attends to, in addition to itself.
+        lsh_attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities in `LSHSelfAttention`.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_buckets (`int` or `list[int]`, *optional*):
+            Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme.
+            Each query key vector is hashed into a hash in `1, ..., num_buckets`. The number of buckets can also be
+            factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a
+            hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is
+            factorized into two factors. The number of buckets (or the product of the factors) should approximately
+            equal sequence length / lsh_chunk_length. If `num_buckets` is not set, a good value is calculated on the
+            fly.
+        num_hashes (`int`, *optional*, defaults to 1):
+            Number of hashing rounds (e.g., number of random rotations) in the locality sensitive hashing scheme. The
+            higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time
+            intensive the hashing becomes.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The token id for the padding token.
+        vocab_size (`int`, *optional*, defaults to 320):
+            Vocabulary size of the Reformer model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`ReformerModel`].
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie input and output embeddings.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+
+    Examples:
+
+    ```python
+    >>> from transformers import ReformerConfig, ReformerModel
+
+    >>> # Initializing a Reformer configuration
+    >>> configuration = ReformerConfig()
+
+    >>> # Initializing a Reformer model (with random weights)
+    >>> model = ReformerModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "reformer"
+    keys_to_ignore_at_inference = ["past_buckets_states"]
+    attribute_map = {}
+
+    def __init__(
+        self,
+        attention_head_size=64,
+        attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"],
+        axial_norm_std=1.0,
+        axial_pos_embds=True,
+        axial_pos_shape=[64, 64],
+        axial_pos_embds_dim=[64, 192],
+        chunk_size_lm_head=0,
+        eos_token_id=2,
+        feed_forward_size=512,
+        hash_seed=None,
+        hidden_act="relu",
+        hidden_dropout_prob=0.05,
+        hidden_size=256,
+        initializer_range=0.02,
+        is_decoder=False,
+        layer_norm_eps=1e-12,
+        local_num_chunks_before=1,
+        local_num_chunks_after=0,
+        local_attention_probs_dropout_prob=0.05,
+        local_attn_chunk_length=64,
+        lsh_attn_chunk_length=64,
+        lsh_attention_probs_dropout_prob=0.0,
+        lsh_num_chunks_before=1,
+        lsh_num_chunks_after=0,
+        max_position_embeddings=4096,
+        num_attention_heads=12,
+        num_buckets=None,
+        num_hashes=1,
+        pad_token_id=0,
+        vocab_size=320,
+        tie_word_embeddings=False,
+        use_cache=True,
+        classifier_dropout=None,
+        **kwargs,
+    ):
+        self.hash_seed = hash_seed
+        self.vocab_size = vocab_size
+        self.attention_head_size = attention_head_size
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.num_hashes = num_hashes
+        self.num_hidden_layers = len(attn_layers)
+        self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
+        self.lsh_attn_chunk_length = lsh_attn_chunk_length
+        self.local_attn_chunk_length = local_attn_chunk_length
+        self.lsh_num_chunks_after = lsh_num_chunks_after
+        self.lsh_num_chunks_before = lsh_num_chunks_before
+        self.local_num_chunks_after = local_num_chunks_after
+        self.local_num_chunks_before = local_num_chunks_before
+        self.hidden_act = hidden_act
+        self.feed_forward_size = feed_forward_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
+        self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
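+        # (Editor's note: with axial position embeddings enabled, sum(axial_pos_embds_dim)
+        # must equal hidden_size -- here 64 + 192 == 256 -- and the product of
+        # axial_pos_shape must equal the padded sequence length during training.)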
self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.attn_layers = attn_layers + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_decoder=is_decoder, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["ReformerConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/reformer/modeling_reformer.py b/venv/lib/python3.13/site-packages/transformers/models/reformer/modeling_reformer.py new file mode 100644 index 0000000000000000000000000000000000000000..031754e9faa08096b7229f71483b38a4a9115a13 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/reformer/modeling_reformer.py @@ -0,0 +1,2781 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model.""" + +import sys +from collections import namedtuple +from collections.abc import Iterable +from dataclasses import dataclass +from functools import reduce +from operator import mul +from typing import Any, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.autograd.function import Function +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + ModelOutput, + auto_docstring, + logging, +) +from .configuration_reformer import ReformerConfig + + +logger = logging.get_logger(__name__) + + +# Define named tuples for nn.Modules here +LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) +AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) +ReformerBackwardOutput = namedtuple( + "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] +) +ReformerEncoderOutput = namedtuple( + "ReformerEncoderOutput", + ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], +) + + +class ReformerDynamicCache: + """ + A dynamic cache that stores past buckets instead of key/values. 
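+    Local attention layers have no buckets, so `update` stores an empty tensor in
+    `buckets_cache` for them; LSH attention layers cache both their hash buckets and
+    their hidden states.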
+ """ + + def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.buckets_cache: list[torch.Tensor] = [] + self.states_cache: list[torch.Tensor] = [] + + if _distributed_cache_data is not None: + for buckets, states in _distributed_cache_data: + self.buckets_cache.append(buckets) + self.states_cache.append(states) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds + to the number of layers in the model. + """ + return len(self.states_cache) + + def update( + self, + buckets: torch.Tensor, + states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `ReformerDynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += states.shape[-2] + + # Update the cache + if states is not None: + if len(self.states_cache) <= layer_idx: + self.states_cache.append(states) + else: + self.states_cache[layer_idx] = torch.cat([self.states_cache[layer_idx], states], dim=1) + + if buckets is not None: + if len(self.buckets_cache) <= layer_idx: + self.buckets_cache.append(buckets) + else: + self.buckets_cache[layer_idx] = torch.cat([self.buckets_cache[layer_idx], buckets], dim=-1) + else: + # `ReformerLocalAttn` passes `None` to buckets as the module uses no buckets + self.buckets_cache.append(torch.tensor([], device=self.states_cache[layer_idx].device)) + + return self.buckets_cache[layer_idx], self.states_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return None + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """Converts the `ReformerDynamicCache` instance into the its equivalent in the legacy cache format. 
Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + buckets, states = self.buckets_cache[layer_idx], self.states_cache[layer_idx] + buckets = buckets if buckets.numel() != 0 else None + legacy_cache += ((buckets, states),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_buckets_states: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "ReformerDynamicCache": + """Converts a cache in the legacy cache format into an equivalent `ReformerDynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_buckets_states is not None: + for layer_idx in range(len(past_buckets_states)): + buckets, states = past_buckets_states[layer_idx] + cache.update(buckets, states, layer_idx) + return cache + + +def _stable_argsort(vector, dim): + # this function scales the vector so that torch.argsort is stable. + # torch.argsort is not stable on its own + scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1) + scale_offset = scale_offset.expand(vector.shape) + scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim]) + return torch.argsort(scaled_vector, dim=dim) + + +def _get_least_common_mult_chunk_len(config): + attn_types = config.attn_layers + attn_types_set = set(attn_types) + if len(attn_types_set) == 1 and attn_types[0] == "lsh": + return config.lsh_attn_chunk_length + elif len(attn_types_set) == 1 and attn_types[0] == "local": + return config.local_attn_chunk_length + elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}: + return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) + else: + raise NotImplementedError( + f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " + "attn layer types from ['lsh', 'local'] only." + ) + + +def _get_min_chunk_len(config): + attn_types = config.attn_layers + attn_types_set = set(attn_types) + if len(attn_types_set) == 1 and attn_types[0] == "lsh": + return config.lsh_attn_chunk_length + elif len(attn_types_set) == 1 and attn_types[0] == "local": + return config.local_attn_chunk_length + elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}: + return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) + else: + raise NotImplementedError( + f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " + "attn layer types from ['lsh', 'local'] only." + ) + + +class AxialPositionEmbeddings(nn.Module): + """ + Constructs axial position embeddings. Useful for very long input sequences to save memory and time. 
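+ For example, with the default `config.axial_pos_shape = (64, 64)` and `config.axial_pos_embds_dim = (64, 192)`, + position `p` of a 4096-token sequence is mapped to grid coordinates `(p // 64, p % 64)` and encoded as the + concatenation of a 64-dim embedding for the first axis and a 192-dim embedding for the second, so only + 64 * 64 + 64 * 192 = 16,384 weights are learned instead of a full 4096 x 256 embedding table.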
+ """ + + def __init__(self, config): + super().__init__() + self.axial_pos_shape = config.axial_pos_shape + self.axial_pos_embds_dim = config.axial_pos_embds_dim + self.dropout = config.hidden_dropout_prob + + self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config) + self.weights = nn.ParameterList() + + if sum(self.axial_pos_embds_dim) != config.hidden_size: + raise ValueError( + f"Make sure that config.axial_pos_embds factors: {self.axial_pos_embds_dim} sum to " + f"config.hidden_size: {config.hidden_size}" + ) + + # create weights + for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): + # create expanded shapes + ax_shape = [1] * len(self.axial_pos_shape) + ax_shape[axis] = self.axial_pos_shape[axis] + ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,) + + # create tensor and init + self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) + + def forward(self, position_ids): + # broadcast weights to correct shape + batch_size = position_ids.shape[0] + sequence_length = position_ids.shape[1] + + broadcasted_weights = [ + weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights + ] + + if self.training is True: + if reduce(mul, self.axial_pos_shape) != sequence_length: + raise ValueError( + f"If training, make sure that config.axial_pos_shape factors: {self.axial_pos_shape} multiply to " + f"sequence length. Got prod({self.axial_pos_shape}) != sequence_length: {sequence_length}. " + f"You might want to consider padding your sequence length to {reduce(mul, self.axial_pos_shape)} " + "or changing config.axial_pos_shape." + ) + + if self.dropout > 0: + weights = torch.cat(broadcasted_weights, dim=-1) + # permute weights so that 2D correctly drops dims 1 and 2 + transposed_weights = weights.transpose(2, 1) + # drop entire matrix of last two dims (prev dims 1 and 2) + dropped_transposed_weights = nn.functional.dropout2d( + transposed_weights, p=self.dropout, training=self.training + ) + dropped_weights = dropped_transposed_weights.transpose(2, 1) + + position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1)) + + else: + position_encodings = torch.cat( + [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], + dim=-1, + ) + + else: + if reduce(mul, self.axial_pos_shape) < sequence_length: + raise ValueError( + f"Make sure that config.axial_pos_shape factors: {self.axial_pos_shape} multiply at least to " + f"max(sequence_length, least_common_mult_chunk_length): max({sequence_length}, " + f"{self.least_common_mult_chunk_length})." 
+ ) + + # compute how many columns are needed + max_position_id = position_ids.max().item() + required_pos_encodings_columns = -(-(max_position_id + 1) // self.axial_pos_shape[1]) + + # cut to columns that are needed + position_encodings = torch.cat( + [weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1 + ) + position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1])) + + # select correct position encodings + position_encodings = torch.cat( + [ + torch.index_select(position_encodings[i], 0, position_ids[i]).unsqueeze(0) + for i in range(batch_size) + ], + dim=0, + ) + + return position_encodings + + +class PositionEmbeddings(nn.Module): + """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`.""" + + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + def forward(self, position_ids): + position_embeddings = self.embedding(position_ids) + position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training) + return position_embeddings + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.max_position_embeddings = config.max_position_embeddings + self.dropout = config.hidden_dropout_prob + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = ( + AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) + ) + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0): + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + + seq_length = input_shape[1] + if position_ids is None: + position_ids = torch.arange( + start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if position_ids.shape[-1] > self.max_position_embeddings: + raise ValueError( + f"Sequence Length: {position_ids.shape[-1]} has to be less than or equal to " + f"config.max_position_embeddings {self.max_position_embeddings}." + ) + + # dropout + embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training) + + # add positional embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + + +class EfficientAttentionMixin: + """ + A few utilities for nn.Modules in Reformer, to be used as a mixin. + """ + + def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): + """ + Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + num_chunks_before: chunks before current chunk to include in attention + num_chunks_after: chunks after current chunk to include in attention + + Returns: + tensor of shape [batch_size, num_attention_heads, n_chunks, N * chunk_len, ...], where N = (1 + num_chunks_before + num_chunks_after).
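+ For example, with chunks [A, B, C, D], `num_chunks_before=1` and `num_chunks_after=0`, each chunk is + concatenated (cyclically) with its predecessor along the chunk-length dimension: [DA, AB, BC, CD].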
+ """ + if num_chunks_before == 0 and num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-num_chunks_before, num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size): + """ + splits hidden_size dim into attn_head_size and num_attn_heads + """ + new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) + x = x.view(*new_x_shape) + return x.transpose(2, 1) + + def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size): + """ + merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) + + def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): + """ + splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims + """ + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (attn_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError(f"Input vector rank should be one of [3, 4], but is: {len(vectors.shape)}") + + +class LSHSelfAttention(nn.Module, EfficientAttentionMixin): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + + self.chunk_length = config.lsh_attn_chunk_length + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.num_chunks_before = config.lsh_num_chunks_before + self.num_chunks_after = config.lsh_num_chunks_after + self.hash_seed = config.hash_seed + self.is_decoder = config.is_decoder + self.max_position_embeddings = config.max_position_embeddings + self.layer_idx = layer_idx + + self.dropout = config.lsh_attention_probs_dropout_prob + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = config.hidden_size + + # projection matrices + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + # save mask value here. 
Need fp32 and fp16 mask values + self.register_buffer("self_mask_value_float16", torch.tensor(-1e3), persistent=False) + self.register_buffer("self_mask_value_float32", torch.tensor(-1e5), persistent=False) + self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False) + self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + buckets=None, + past_buckets_states=None, + use_cache=False, + output_attentions=False, + cache_position=None, + **kwargs, + ): + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + # num hashes can optionally be overwritten by user + num_hashes = num_hashes if num_hashes is not None else self.num_hashes + + # check if cache shall be used and that hidden states are already cached + exists_cache = past_buckets_states is not None and len(past_buckets_states) > self.layer_idx + if exists_cache: + assert sequence_length == 1, ( + "At the moment, auto-regressive language generation is only possible one word at a time. Make sure" + f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed." + ) + + # get query vector + query_vectors = self.query_key(hidden_states) + query_vectors = self._split_hidden_size_dim( + query_vectors, self.num_attention_heads, self.attention_head_size + ) + + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + + if past_buckets.numel() != 0: + key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( + query_vectors=query_vectors, + attention_mask=attention_mask, + num_hashes=num_hashes, + hidden_states=hidden_states, + past_states=past_states, + past_buckets=past_buckets, + ) + + query_key_vectors = self._query_per_attn_head(key_value_hidden_states) + value_vectors = self._value_per_attn_head(key_value_hidden_states) + + # split key & value vectors by num hashes to apply + # self attention on each separately + query_key_vectors = self._split_seq_length_dim_to( + query_key_vectors, + num_hashes, + -1, + self.num_attention_heads, + self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, + num_hashes, + -1, + self.num_attention_heads, + self.attention_head_size, + ) + # repeat query vectors across hash dimension + query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1) + else: + key_value_hidden_states = torch.cat([past_states, hidden_states], dim=1) + + query_key_vectors = self.query_key(key_value_hidden_states) + value_vectors = self.value(key_value_hidden_states) + + else: + # project hidden_states to query_key and value + query_vectors = None + query_key_vectors = self.query_key(hidden_states) + value_vectors = self.value(hidden_states) + + # if query key is not already split + if not exists_cache or past_buckets.numel() == 0: + query_key_vectors = self._split_hidden_size_dim( + query_key_vectors, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._split_hidden_size_dim( + value_vectors, self.num_attention_heads, self.attention_head_size + ) + + # cache buckets for next incremental decoding + if exists_cache and key_value_hidden_states.shape[1] >= self.chunk_length: + buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) + + # free memory + del hidden_states + + assert query_key_vectors.shape[-1] == 
self.attention_head_size, ( + f"last dim of query_key_vectors is {query_key_vectors.shape[-1]} but should be {self.attention_head_size}." + ) + assert value_vectors.shape[-1] == self.attention_head_size, ( + f"last dim of value_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}." + ) + + do_standard_self_attention = (sequence_length <= self.chunk_length) or ( + exists_cache and past_states is not None + ) + # LSH attention only makes sense if chunked attention should be performed + if not do_standard_self_attention: + # set `num_buckets` on the fly, recommended way to do it + if self.num_buckets is None: + self._set_num_buckets(sequence_length) + + # use cached buckets for backprop only + if buckets is None: + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) + else: + # make sure buckets has correct shape for LSH attention + buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes * sequence_length) + + assert int(buckets.shape[-1]) == num_hashes * sequence_length, ( + f"last dim of buckets is {buckets.shape[-1]}, but should be {num_hashes * sequence_length}" + ) + + sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( + sequence_length, buckets, num_hashes + ) + + # make sure bucket idx is not longer than sequence length + sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length + + # cluster query key value vectors according to hashed buckets + query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes) + value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes) + query_key_vectors = self._split_seq_length_dim_to( + query_key_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, + ) + + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0, ( + "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" + " `config.num_chunks_before` are set to 0."
+ ) + elif exists_cache and past_buckets.numel() != 0: + # use max sequence length + sorted_bucket_idx_per_hash = sorted_bucket_idx + else: + # get sequence length indices + sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat( + batch_size, self.num_attention_heads, 1 + ) + + # scale key vectors + sqrt_num = np.sqrt(self.attention_head_size) + key_vectors = self._len_and_dim_norm(query_key_vectors, sqrt_num) + + # set query_vectors to query key vectors if LSH self attention + query_vectors = query_vectors if query_vectors is not None else query_key_vectors + + # free memory + del query_key_vectors + + # get attention probs + out_vectors, logits, attention_probs = self._attend( + query_vectors=query_vectors, + key_vectors=key_vectors, + value_vectors=value_vectors, + sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash, + attention_mask=attention_mask, + head_mask=head_mask, + do_standard_self_attention=do_standard_self_attention, + use_cache=exists_cache, + ) + + # free memory + del key_vectors, value_vectors + + # re-order out_vectors and logits + if not do_standard_self_attention: + # sort clusters back to correct ordering + out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) + + if not do_standard_self_attention or (exists_cache and past_buckets.numel() != 0): + # sum up all hash rounds + if num_hashes > 1: + out_vectors = self._split_seq_length_dim_to( + out_vectors, + num_hashes, + sequence_length, + self.num_attention_heads, + self.attention_head_size, + ) + logits = self._split_seq_length_dim_to( + logits, + num_hashes, + sequence_length, + self.num_attention_heads, + self.attention_head_size, + ).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + # free memory + del probs_vectors + + # free memory + del logits + + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ), ( + "out_vectors have to be of shape `[batch_size, config.num_attention_heads, sequence_length," + " config.attention_head_size]`."
+ ) + + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) + + if output_attentions is False: + attention_probs = () + + if buckets is not None: + buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes, -1) + + return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) + + def _query_per_attn_head(self, hidden_states): + per_head_query_key = self.query_key.weight.reshape( + self.num_attention_heads, self.attention_head_size, self.hidden_size + ).transpose(-2, -1) + # only relevant for inference and no bias => we can use einsum here + query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key) + return query_key_vectors + + def _value_per_attn_head(self, hidden_states): + per_head_value = self.value.weight.reshape( + self.num_attention_heads, self.attention_head_size, self.hidden_size + ).transpose(-2, -1) + # only relevant for inference and no bias => we can use einsum here + value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value) + return value_vectors + + def _hash_vectors(self, vectors, num_hashes, attention_mask, increase_num_buckets=False): + batch_size = vectors.shape[0] + + # See https://huggingface.co/papers/1509.02897 + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert self.num_buckets % 2 == 0, ( + f"There should be an even number of buckets, but `self.num_buckets`: {self.num_buckets}" + ) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.num_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for bucket_factor in self.num_buckets: + assert bucket_factor % 2 == 0, ( + f"The number of buckets should be even, but `num_bucket`: {bucket_factor}" + ) + rotation_size = rotation_size + bucket_factor + num_buckets = num_buckets * bucket_factor + + # remove gradient + vectors = vectors.detach() + + if self.hash_seed is not None: + # for determinism + torch.manual_seed(self.hash_seed) + + rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) + # create a random self.attention_head_size x num_hashes x num_buckets/2 + random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. 
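+ # Worked example (illustrative): for num_buckets = [4, 8], the first 4 // 2 = 2 rotation dims are mirrored + # into 4 values whose argmax gives b0 in 0..3, the next 8 // 2 = 4 dims give b1 in 0..7, and the combined + # bucket id is b0 + 4 * b1, i.e. 4 * 8 = 32 distinct buckets per hash round.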
+ buckets, cur_sum, cur_product = None, 0, 1 + for bucket_factor in self.num_buckets: + rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)] + cur_sum = cur_sum + bucket_factor // 2 + rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1) + if buckets is None: + buckets = torch.argmax(rotated_vectors_factor, dim=-1) + else: + buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1)) + + cur_product = cur_product * bucket_factor + + if attention_mask is not None and (attention_mask.sum().item() < batch_size * attention_mask.shape[-1]): + # add an extra bucket for padding tokens only + num_buckets = num_buckets + 1 + # assign padding tokens extra bucket + buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape) + buckets = torch.where( + buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device) + ) + elif increase_num_buckets: + num_buckets = num_buckets + 1 + + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(num_hashes, device=vectors.device) + offsets = (offsets * num_buckets).view((1, 1, -1, 1)) + + # expand to batch size and num attention heads + offsets = offsets.expand((batch_size, self.num_attention_heads) + offsets.shape[-2:]) + offset_buckets = (buckets + offsets).flatten(start_dim=2, end_dim=3) + + return offset_buckets + + def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): + # no gradients are needed + with torch.no_grad(): + # hash-based sort + sorted_bucket_idx = _stable_argsort(buckets, dim=-1) + + # create simple indices to scatter to, to have undo sort + indices = ( + torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) + .view(1, 1, -1) + .expand(sorted_bucket_idx.shape) + ) + + # get undo sort + undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) + undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) + + return sorted_bucket_idx, undo_sorted_bucket_idx + + def _set_num_buckets(self, sequence_length): + # `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper + num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1 + # make sure buckets are power of 2 + num_buckets = 2**num_buckets_pow_2 + + # factorize `num_buckets` if `num_buckets` becomes too large + num_buckets_limit = 2 * max( + int((self.max_position_embeddings // self.chunk_length) ** (0.5)), + self.chunk_length, + ) + if num_buckets > num_buckets_limit: + num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)] + + logger.warning(f"config.num_buckets is not set. 
Setting config.num_buckets to {num_buckets}...") + + # set num buckets in config to be properly saved + self.config.num_buckets = num_buckets + self.num_buckets = num_buckets + + def _attend( + self, + query_vectors, + key_vectors, + value_vectors, + sorted_bucket_idx_per_hash, + attention_mask, + head_mask, + do_standard_self_attention, + use_cache, + ): + # look at previous and following chunks if chunked attention + if not do_standard_self_attention: + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + + # get logits and dots + # (BS, NumAttn, NumHash x NumChunk, Chunk_L x Hidden),(BS, NumAttn, NumHash x NumChunk, Chunk_L * (Num_bef + Num_aft + 1) x Hidden) -> (BS, NumAttn, NumHash x NumChunk, Chunk_L, Chunk_L * (1 + Num_bef + Num_aft)) + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # free memory + del query_vectors, key_vectors + + # if chunked attention split bucket idxs to query and key + if not do_standard_self_attention: + query_bucket_idx = self._split_seq_length_dim_to( + sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads + ) + key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) + elif use_cache and query_key_dots.ndim > 4: + key_value_bucket_idx = sorted_bucket_idx_per_hash + query_bucket_idx = ( + key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max() + ) + elif use_cache and query_key_dots.ndim <= 4: + query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1] + key_value_bucket_idx = torch.arange( + query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device + )[None, None, :].expand(query_bucket_idx.shape[:2] + (-1,)) + else: + query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash + + # get correct mask values depending on precision + if query_key_dots.dtype == torch.float16: + self_mask_value = self.self_mask_value_float16.half() + mask_value = self.mask_value_float16.half() + else: + self_mask_value = self.self_mask_value_float32 + mask_value = self.mask_value_float32 + + if not use_cache: + mask = self._compute_attn_mask( + query_bucket_idx, + key_value_bucket_idx, + attention_mask, + query_key_dots.shape, + do_standard_self_attention, + ) + + if mask is not None: + query_key_dots = torch.where(mask, query_key_dots, mask_value) + + # free memory + del mask + + # Self mask is ALWAYS applied. + # From the reformer paper (https://huggingface.co/papers/2001.04451): + # " While attention to the future is not allowed, typical implementations of the + # Transformer do allow a position to attend to itself. + # Such behavior is undesirable in a shared-QK formulation because the dot-product + # of a query vector with itself will almost always be greater than the dot product of a + # query vector with a vector at another position. We therefore modify the masking + # to forbid a token from attending to itself, except in situations + # where a token has no other valid attention targets (e.g. 
the first token in a sequence) " + + self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( + query_bucket_idx.device + ) + + # apply self_mask + query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value) + + # free memory + del self_mask + + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` + attention_probs = torch.exp(query_key_dots - logits) + + # free memory + del query_key_dots + + # dropout + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # attend values + out_vectors = torch.matmul(attention_probs, value_vectors) + + # free memory + del value_vectors + + # merge chunk length + if out_vectors.ndim > 4: + logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + + return out_vectors, logits, attention_probs + + def _compute_attn_mask( + self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention + ): + # attention mask for LSH + if attention_mask is not None: + # if chunked attention, the attention mask has to correspond to LSH order + attention_mask = attention_mask.to(torch.bool)[:, None, :] + if not do_standard_self_attention: + # expand attn_mask to fit with key_value_bucket_idx shape + attention_mask = attention_mask[:, None, :] + attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) + # extract attention mask from LSH sorted key_indices + attention_mask = torch.gather(attention_mask, -1, key_indices) + + attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape) + + # Causal mask + if self.is_decoder is True: + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # add attention mask if not None + if attention_mask is not None: + attention_mask = causal_mask * attention_mask + else: + attention_mask = causal_mask + + return attention_mask + + def _get_relevant_hid_states_and_buckets( + self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets + ): + # concat hidden states + hidden_states = torch.cat([past_states, hidden_states], dim=1) + + # batch_size hidden + batch_size = hidden_states.shape[0] + sequence_length = hidden_states.shape[1] + + # check if cached buckets include pad bucket + max_bucket = self.num_buckets if isinstance(self.num_buckets, int) else reduce(mul, self.num_buckets) + + # if pad bucket was cached => need to increase num buckets for caching + increase_num_buckets = past_buckets.max() > num_hashes * max_bucket - 1 + + # retrieve query buckets + query_buckets = self._hash_vectors( + query_vectors, num_hashes, attention_mask, increase_num_buckets=increase_num_buckets + ) + + # concat buckets + concat_buckets = torch.cat([past_buckets, query_buckets.unsqueeze(-1)], dim=-1) + + # hash-based sort + bucket_idx = _stable_argsort(concat_buckets, dim=-1) + + # bucket_idx has shape: BatchSize x NumAttnHeads x NumHashes x SequenceLength + assert bucket_idx.shape == ( + batch_size, + self.num_attention_heads, + num_hashes, + sequence_length, + ), ( + f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but" + f" has shape 
{bucket_idx.shape}." + ) + + # find indices of new bucket indices + relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero() + + # expand relevant bucket indices to its chunks + relevant_bucket_idx_chunk = self._expand_to_indices_in_relevant_chunk(relevant_bucket_idx, sequence_length) + relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))] + + # adapt bucket_idx for batch and hidden states for index select + offset = torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long) + bucket_idx_batch_offset = sequence_length * ( + batch_size * torch.div(offset, relevant_bucket_idx_chunk.shape[-1], rounding_mode="floor") + ) + + # add batch offset + relevant_bucket_idx_chunk_all_batch = relevant_bucket_idx_chunk + bucket_idx_batch_offset + hidden_states = hidden_states.reshape((-1, self.hidden_size)) + + # select all relevant hidden states + relevant_hidden_states = hidden_states.index_select(0, relevant_bucket_idx_chunk_all_batch) + + # reshape hidden states and bucket_idx to correct output + relevant_hidden_states = relevant_hidden_states.reshape( + batch_size, self.num_attention_heads, -1, self.hidden_size + ) + relevant_bucket_idx_chunk = relevant_bucket_idx_chunk.reshape( + batch_size, self.num_attention_heads, num_hashes, -1 + ) + + assert ( + relevant_hidden_states.shape[2] + == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes + ), ( + "There should be" + f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`," + f" there are {relevant_hidden_states.shape[2]} `hidden_states`." + ) + + assert ( + relevant_bucket_idx_chunk.shape[-1] + == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length + ), ( + "There should be" + f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are" + f" {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`." 
+ ) + + return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets + + def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length): + # get relevant indices of where chunk starts and its size + start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length + total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after) + + # expand start indices and add correct chunk offset via arange + expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size) + chunk_sequence_indices = expanded_start_indices + torch.arange( + total_chunk_size, device=indices.device, dtype=torch.long + ).unsqueeze(0).expand(indices.shape[0], total_chunk_size) + + # make sure that circular logic holds via % seq len + chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length + + # expand indices and set indices correctly + indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone() + indices[:, -1] = chunk_sequence_indices + + return indices + + def _len_and_dim_norm(self, vectors, sqrt_num): + """ + length and attention head size dim normalization + """ + vectors = self._len_norm(vectors) + vectors = vectors / sqrt_num + return vectors + + def _len_norm(self, x, epsilon=1e-6): + """ + length normalization + """ + variance = torch.mean(x**2, -1, keepdim=True) + norm_x = x * torch.rsqrt(variance + epsilon) + return norm_x + + def _gather_by_expansion(self, vectors, idxs, num_hashes): + """ + expand dims of idxs and vectors for all hashes and gather + """ + expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + vectors = vectors.repeat(1, 1, num_hashes, 1) + return torch.gather(vectors, 2, expanded_idxs) + + +class ReverseSort(Function): + """ + After chunked attention, which sorts the clusters, has been applied, the original ordering has to be restored. Since a + customized backward function is used for Reformer, the gradients of the output vectors have to be explicitly sorted here.
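+ Concretely, the forward pass gathers with `undo_sorted_bucket_idx` to scatter attention outputs back to their + original token positions, while the backward pass gathers the incoming gradients with the saved + `sorted_bucket_idx` so that they line up again with the bucket-sorted layout used inside attention.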
+ """ + + @staticmethod + def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx): + # save sorted_bucket_idx for backprop + with torch.no_grad(): + ctx.sorted_bucket_idx = sorted_bucket_idx + + # undo sort to have correct order for next layer + expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_sorted_bucket_idx) + return out_vectors, logits + + @staticmethod + def backward(ctx, grad_out_vectors, grad_logits): + # get parameters saved in ctx + sorted_bucket_idx = ctx.sorted_bucket_idx + + expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) + # reverse sort of forward + grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) + grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx) + + # return grad and `None` fillers for last 2 forward args + return grad_out_vectors, grad_logits, None, None + + +class LocalSelfAttention(nn.Module, EfficientAttentionMixin): + def __init__(self, config, layer_idx=None): + super().__init__() + + self.num_attention_heads = config.num_attention_heads + self.chunk_length = config.local_attn_chunk_length + self.num_chunks_before = config.local_num_chunks_before + self.num_chunks_after = config.local_num_chunks_after + self.is_decoder = config.is_decoder + self.pad_token_id = config.pad_token_id + + self.attention_head_size = config.attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # projection matrices + self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + self.dropout = config.local_attention_probs_dropout_prob + + # save mask value here + self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False) + self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + past_buckets_states=None, + use_cache=False, + output_attentions=False, + **kwargs, + ): + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + # check if cache shall be used and that hidden states are already cached + if past_buckets_states is not None and len(past_buckets_states) > self.layer_idx: + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + + assert past_buckets.numel() == 0, ( + "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching" + " hidden_states_and_buckets." 
+ ) + key_value_hidden_states = self._retrieve_relevant_hidden_states( + past_states, self.chunk_length, self.num_chunks_before + ) + key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1) + + # only query vector for last token + query_vectors = self.query(hidden_states) + # compute key and value for relevant chunk + key_vectors = self.key(key_value_hidden_states) + value_vectors = self.value(key_value_hidden_states) + + # free memory + del key_value_hidden_states + else: + # project hidden_states to query, key and value + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + # split last dim into `config.num_attention_heads` and `config.attention_head_size` + query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) + key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) + + assert query_vectors.shape[-1] == self.attention_head_size, ( + f"last dim of query_vectors is {query_vectors.shape[-1]} but should be {self.attention_head_size}." + ) + assert key_vectors.shape[-1] == self.attention_head_size, ( + f"last dim of key_vectors is {key_vectors.shape[-1]} but should be {self.attention_head_size}." + ) + assert value_vectors.shape[-1] == self.attention_head_size, ( + f"last dim of value_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}." + ) + + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0, ( + "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" + " `config.num_chunks_before` are set to 0."
+ ) + + # normalize key vectors + key_vectors = key_vectors / np.sqrt(self.attention_head_size) + + # get sequence length indices + indices = torch.arange(sequence_length, device=query_vectors.device).repeat( + batch_size, self.num_attention_heads, 1 + ) + + # if one should do normal n^2 self-attention + do_standard_self_attention = sequence_length <= self.chunk_length + + # if input should be chunked + if not do_standard_self_attention: + # chunk vectors + # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + query_vectors = self._split_seq_length_dim_to( + query_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, + ) + key_vectors = self._split_seq_length_dim_to( + key_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, + ) + + # chunk indices + query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + + # append chunks before and after + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) + else: + query_indices = key_indices = indices + + # query-key matmul: QK^T + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # free memory + del query_vectors, key_vectors + + mask = self._compute_attn_mask( + query_indices, key_indices, attention_mask, query_key_dots.shape, do_standard_self_attention + ) + + if mask is not None: + # get mask tensor depending on half precision or not + if query_key_dots.dtype == torch.float16: + mask_value = self.mask_value_float16.half() + else: + mask_value = self.mask_value_float32 + + query_key_dots = torch.where(mask, query_key_dots, mask_value) + + # free memory + del mask + + # softmax + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + attention_probs = torch.exp(query_key_dots - logits) + + # free memory + del logits + + # dropout + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # attend values + out_vectors = torch.matmul(attention_probs, value_vectors) + + # free memory + del value_vectors + + # merge chunk length + if not do_standard_self_attention: + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ) + + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) + + if output_attentions is False: + attention_probs = () + + return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) + + def _compute_attn_mask( + self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention + ): + # chunk attention mask and look before and after + if attention_mask is not None: + attention_mask = attention_mask.to(torch.bool)[:, None, :] + + if not do_standard_self_attention: + 
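# bring the padding mask into the same chunked layout as the key vectors: split the sequence dimension + # into chunks, then append the masks of the neighboring chunks that each query chunk also attends to +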
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) + attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + # create attn_mask + attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape) + + # Causal mask + if self.is_decoder is True: + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # add attention mask if not None + if attention_mask is not None: + attention_mask = causal_mask * attention_mask + else: + attention_mask = causal_mask + + return attention_mask + + @staticmethod + def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before): + start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length + return previous_hidden_states[:, start_position:] + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + all_head_size = config.num_attention_heads * config.attention_head_size + self.dropout = config.hidden_dropout_prob + + self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ReformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.layer_id = layer_id + self.attn_layers = config.attn_layers + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) + elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) + elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}: + # get correct attn layers + if self.attn_layers[self.layer_id] == "lsh": + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) + else: + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) + else: + raise NotImplementedError( + f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. " + "Select attn layer types from ['lsh', 'local'] only." 
+ ) + self.output = ReformerSelfOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, + output_attentions=False, + buckets=None, + cache_position=None, + ): + hidden_states = self.layer_norm(hidden_states) + + # use cached buckets for backprop if buckets is not None for LSHSelfAttention + self_attention_outputs = self.self_attention( + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + output_attentions=output_attentions, + buckets=buckets, + cache_position=cache_position, + ) + + # add buckets if necessary + if hasattr(self_attention_outputs, "buckets"): + buckets = self_attention_outputs.buckets + else: + buckets = None + + # cache hidden states for future use + if use_cache and past_buckets_states is not None: + # padded input should not be cached during prefill + states = ( + hidden_states[:, :orig_sequence_length] + if len(past_buckets_states.states_cache) <= self.layer_id + else hidden_states + ) + buckets = ( + buckets[:, :, :, :orig_sequence_length] + if ( + len(past_buckets_states.buckets_cache) <= self.layer_id + and buckets is not None + and orig_sequence_length > 1 + ) + else buckets + ) + buckets, hidden_states = past_buckets_states.update( + buckets, states[:, :orig_sequence_length], self.layer_id + ) + + # compute attention output projection + attention_output = self.output(self_attention_outputs.hidden_states) + + return AttentionOutput( + hidden_states=attention_output, + attention_probs=self_attention_outputs.attention_probs, + buckets=buckets, + ) + + +class ReformerFeedForwardDense(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.act_fn(hidden_states) + return hidden_states + + +class ReformerFeedForwardOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ChunkReformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = ReformerFeedForwardDense(config) + self.output = ReformerFeedForwardOutput(config) + + def forward(self, attention_output): + return apply_chunking_to_forward( + self.forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + + def forward_chunk(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dense(hidden_states) + return self.output(hidden_states) + + +class ReformerLayer(nn.Module): + def __init__(self, config, layer_id=0):
+ super().__init__() + self.attention = ReformerAttention(config, layer_id) + # dropout requires to have the same + # seed for forward and backward pass + self.attention_seed = None + self.feed_forward_seed = None + + self.feed_forward = ChunkReformerFeedForward(config) + + def _init_attention_seed(self): + """ + This function sets a new seed for the attention layer to make dropout deterministic for both forward calls: 1 + normal forward call and 1 forward call in backward to recalculate activations. + """ + + # randomize seeds + # use cuda generator if available + if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0: + # GPU + device_idx = torch.cuda.current_device() + self.attention_seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + self.attention_seed = int(torch.seed() % sys.maxsize) + + torch.manual_seed(self.attention_seed) + + def _init_feed_forward_seed(self): + """ + This function sets a new seed for the feed forward layer to make dropout deterministic for both forward calls: + 1 normal forward call and 1 forward call in backward to recalculate activations. + """ + # randomize seeds + # use cuda generator if available + if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0: + # GPU + device_idx = torch.cuda.current_device() + self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + self.feed_forward_seed = int(torch.seed() % sys.maxsize) + + torch.manual_seed(self.feed_forward_seed) + + def forward( + self, + prev_attn_output, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, + output_attentions=False, + ): + with torch.no_grad(): + # every forward pass we sample a different seed + # for dropout and save for forward fn in backward pass + # to have correct dropout + if self.training: + self._init_attention_seed() + + attn_outputs = self.attention( + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + orig_sequence_length=orig_sequence_length, + output_attentions=output_attentions, + ) + attn_output = attn_outputs.hidden_states + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # Y_1 = X_1 + f(X_2) + attn_output = prev_attn_output + attn_output + + # free memory + del prev_attn_output + + # every forward pass we sample a different seed + # for dropout and save seed for forward fn in backward + # to have correct dropout + if self.training: + self._init_feed_forward_seed() + # Y_2 = X_2 + g(Y_1) + hidden_states = hidden_states + self.feed_forward(attn_output) + + return ReformerOutput( + attn_output=attn_output, + hidden_states=hidden_states, + attention_probs=attn_outputs.attention_probs, + buckets=attn_outputs.buckets, + ) + + def backward_pass( + self, + next_attn_output, + hidden_states, + grad_attn_output, + grad_hidden_states, + attention_mask=None, + head_mask=None, + buckets=None, + ): + # Implements the backward pass for reversible ResNets. + # A good blog post on how this works can be found here: + # Implementation of RevNet (see Fig. 
6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + + assert self.training, ( + "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the" + " model into training mode." + ) + + with torch.enable_grad(): + next_attn_output.requires_grad = True + + # set seed to have correct dropout + torch.manual_seed(self.feed_forward_seed) + # g(Y_1) + res_hidden_states = self.feed_forward(next_attn_output) + res_hidden_states.backward(grad_hidden_states, retain_graph=True) + + with torch.no_grad(): + # X_2 = Y_2 - g(Y_1) + hidden_states = hidden_states - res_hidden_states + del res_hidden_states + + grad_attn_output = grad_attn_output + next_attn_output.grad + next_attn_output.grad = None + + with torch.enable_grad(): + hidden_states.requires_grad = True + + # set seed to have correct dropout + torch.manual_seed(self.attention_seed) + # f(X_2) + # use cached buckets for backprop if buckets is not None for LSHSelfAttention + output = self.attention( + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + buckets=buckets, + ).hidden_states + output.backward(grad_attn_output, retain_graph=True) + + with torch.no_grad(): + # X_1 = Y_1 - f(X_2) + attn_output = next_attn_output - output + del output, next_attn_output + + grad_hidden_states = grad_hidden_states + hidden_states.grad + hidden_states.grad = None + hidden_states = hidden_states.detach() + + return ReformerBackwardOutput( + attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) + + +class _ReversibleFunction(Function): + """ + To prevent PyTorch from performing the usual backpropagation, a customized backward function is implemented here. + This ensures that no memory-expensive activations are saved during the forward pass.
This function is + heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + """ + + @staticmethod + def forward( + ctx, + hidden_states, + layers, + attention_mask, + head_mask, + num_hashes, + all_hidden_states, + all_attentions, + past_buckets_states, + use_cache, + orig_sequence_length, + output_hidden_states, + output_attentions, + ): + all_buckets = () + + # split duplicated tensor + hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) + + for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)): + if output_hidden_states is True: + all_hidden_states.append(hidden_states) + + layer_outputs = layer( + prev_attn_output=attn_output, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + orig_sequence_length=orig_sequence_length, + output_attentions=output_attentions, + ) + + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + all_buckets = all_buckets + (layer_outputs.buckets,) + + if output_attentions: + all_attentions.append(layer_outputs.attention_probs) + + # Add last layer + if output_hidden_states is True: + all_hidden_states.append(hidden_states) + + # attach params to ctx for backward + ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) + ctx.layers = layers + ctx.all_buckets = all_buckets + ctx.head_mask = head_mask + ctx.attention_mask = attention_mask + + # Concatenate 2 RevNet outputs + return torch.cat([attn_output, hidden_states], dim=-1) + + @staticmethod + def backward(ctx, grad_hidden_states): + grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) + + # retrieve params from ctx for backward + attn_output, hidden_states = ctx.saved_tensors + + # create tuple + output = ReformerBackwardOutput( + attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) + + # free memory + del grad_attn_output, grad_hidden_states, attn_output, hidden_states + + layers = ctx.layers + all_buckets = ctx.all_buckets + head_mask = ctx.head_mask + attention_mask = ctx.attention_mask + + for idx, layer in enumerate(layers[::-1]): + # pop last buckets from stack + buckets = all_buckets[-1] + all_buckets = all_buckets[:-1] + + # backprop + output = layer.backward_pass( + next_attn_output=output.attn_output, + hidden_states=output.hidden_states, + grad_attn_output=output.grad_attn_output, + grad_hidden_states=output.grad_hidden_states, + head_mask=head_mask[len(layers) - idx - 1], + attention_mask=attention_mask, + buckets=buckets, + ) + + assert all_buckets == (), "buckets have to be empty after backpropagation" + grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1) + + # num of return vars has to match num of forward() args + # return gradient for hidden_states arg and None for other args + return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.layer_norm = nn.LayerNorm(2 * config.hidden_size, 
eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, + output_hidden_states=False, + output_attentions=False, + ): + # hidden_states and attention lists to be filled if wished + all_hidden_states = [] + all_attentions = [] + + # init cached hidden states if necessary + if use_cache and past_buckets_states is None: + past_buckets_states = ReformerDynamicCache() + elif use_cache and isinstance(past_buckets_states, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `ReformerDynamicCache` instead, e.g. " + "`past_key_values=ReformerDynamicCache.from_legacy_cache(past_key_values)`." + ) + past_buckets_states = ReformerDynamicCache.from_legacy_cache(past_buckets_states) + + # concat same tensor for reversible ResNet + hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) + hidden_states = _ReversibleFunction.apply( + hidden_states, + self.layers, + attention_mask, + head_mask, + num_hashes, + all_hidden_states, + all_attentions, + past_buckets_states, + use_cache, + orig_sequence_length, + output_hidden_states, + output_attentions, + ) + + # Apply layer norm to concatenated hidden states + hidden_states = self.layer_norm(hidden_states) + + # Apply dropout + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + next_cache = past_buckets_states if use_cache else None + + return ReformerEncoderOutput( + hidden_states=hidden_states, + all_hidden_states=all_hidden_states, + all_attentions=all_attentions, + past_buckets_states=next_cache, + ) + + +class ReformerOnlyLMHead(nn.Module): + def __init__(self, config): + super().__init__() + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.seq_len_dim = 1 + self.chunk_size_lm_head = config.chunk_size_lm_head + self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, hidden_states): + return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) + + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias + + +@auto_docstring +class ReformerPreTrainedModel(PreTrainedModel): + config: ReformerConfig + base_model_prefix = "reformer" + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, AxialPositionEmbeddings): + for weight in module.weights: + nn.init.normal_(weight, std=self.config.axial_norm_std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`ReformerModel`]. + """ +) +class ReformerModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`): + Sequence of hidden-states at the last layer of the model. + + `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` + corresponds to `sequence_length`. + past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element + being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the + second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). + + Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed + up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor + past_buckets_states: Optional[list[tuple[torch.LongTensor, torch.FloatTensor]]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`ReformerModelWithLMHead`]. + """ +) +class ReformerModelWithLMHeadOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` + corresponds to `sequence_length`. + past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element + being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the + second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). + + Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed + up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_buckets_states: Optional[list[tuple[torch.LongTensor, torch.FloatTensor]]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring +class ReformerModel(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + assert self.config.num_hidden_layers > 0, ( + "`config.attn_layers` is empty. 
Select at least one attn layer form ['lsh', 'local']" + ) + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + num_hashes: Optional[int] = None, + past_buckets_states: Optional[list[tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ReformerModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be + a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices + are automatically padded to be a multiple of the chunk length. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + num_hashes (`int`, *optional*): + The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites + the default defined in `config.num_hashes`. + + For more information, see `num_hashes` in [`ReformerConfig`]. + past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*): + List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element + being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the + second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). + + Contains precomputed hidden-states and buckets (only relevant for LSH Self-Attention). Can be used to speed + up sequential decoding. 
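+        Example (an illustrative usage sketch; it assumes the `google/reformer-crime-and-punishment` checkpoint
+        referenced elsewhere in this diff is available):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, ReformerModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/reformer-crime-and-punishment")
+        >>> model = ReformerModel.from_pretrained("google/reformer-crime-and-punishment")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> # the last dimension is 2 * config.hidden_size because the two reversible residual streams are concatenated
+        >>> last_hidden_state = outputs.last_hidden_state
+        ```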
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + assert len(input_shape) == 2, ( + f"`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {input_shape}" + ) + + if past_buckets_states is not None: + assert not self.training, "`past_buckets_states` can only be used for inference, not for training`." + + # prepare head mask + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True) + + # original sequence length for padding + orig_sequence_length = input_shape[-1] + + # if needs padding + least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) + min_chunk_length = _get_min_chunk_len(self.config) + + must_pad_to_match_chunk_length = ( + input_shape[-1] % least_common_mult_chunk_length != 0 + and input_shape[-1] > min_chunk_length + and past_buckets_states is None + ) + + if must_pad_to_match_chunk_length: + padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length + + if self.training is True: + raise ValueError( + f"If training, sequence length {input_shape[-1]} has to be a multiple of least common multiple " + f"chunk_length {least_common_mult_chunk_length}. Please consider padding the input to a length " + f"of {input_shape[-1] + padding_length}." 
+                )
+
+            # pad input
+            input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                input_shape=input_shape,
+                padding_length=padding_length,
+                padded_seq_length=least_common_mult_chunk_length,
+                device=device,
+            )
+
+        # start index for position encoding depends on incremental decoding
+        if past_buckets_states is not None:
+            start_idx_pos_encodings = past_buckets_states[0][1].shape[1]
+        else:
+            start_idx_pos_encodings = 0
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            start_idx_pos_encodings=start_idx_pos_encodings,
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            head_mask=head_mask,
+            attention_mask=attention_mask,
+            num_hashes=num_hashes,
+            past_buckets_states=past_buckets_states,
+            use_cache=use_cache,
+            orig_sequence_length=orig_sequence_length,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.hidden_states
+
+        # if padding was applied
+        if must_pad_to_match_chunk_length:
+            sequence_output = sequence_output[:, :orig_sequence_length]
+
+        past_buckets_states = encoder_outputs.past_buckets_states if use_cache else None
+        hidden_states = encoder_outputs.all_hidden_states if output_hidden_states else None
+        attentions = encoder_outputs.all_attentions if output_attentions else None
+
+        if not return_dict:
+            return tuple(v for v in [sequence_output, past_buckets_states, hidden_states, attentions] if v is not None)
+        return ReformerModelOutput(
+            last_hidden_state=sequence_output,
+            past_buckets_states=past_buckets_states,
+            hidden_states=hidden_states,
+            attentions=attentions,
+        )
+
+    def _pad_to_mult_of_chunk_length(
+        self,
+        input_ids,
+        inputs_embeds=None,
+        attention_mask=None,
+        position_ids=None,
+        input_shape=None,
+        padding_length=None,
+        padded_seq_length=None,
+        device=None,
+    ):
+        logger.warning_once(
+            f"Input ids are automatically padded from {input_shape[-1]} to {input_shape[-1] + padding_length} to be a "
+            f"multiple of `config.chunk_length`: {padded_seq_length}"
+        )
+
+        padded_input_ids = torch.full(
+            (input_shape[0], padding_length),
+            self.config.pad_token_id,
+            device=device,
+            dtype=torch.long,
+        )
+
+        # Extend `attention_mask`
+        if attention_mask is not None:
+            pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype)
+
+            attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1)
+        else:
+            attention_mask = torch.cat(
+                [
+                    torch.ones(input_shape, device=device, dtype=torch.bool),
+                    torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
+                ],
+                dim=-1,
+            )
+
+        # Extend `input_ids` with padding to match least common multiple chunk_length
+        if input_ids is not None:
+            input_ids = torch.cat([input_ids, padded_input_ids], dim=-1)
+            input_shape = input_ids.size()
+
+            # Pad position ids if given
+            if position_ids is not None:
+                padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device)
+                padded_position_ids = padded_position_ids.unsqueeze(0).expand(input_shape[0], padding_length)
+                position_ids = torch.cat([position_ids, padded_position_ids], dim=-1)
+
+        # Extend `inputs_embeds` with padding to match least common multiple chunk_length
+        if inputs_embeds is not None:
+            padded_inputs_embeds =
self.get_input_embeddings()(padded_input_ids) + inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) + input_shape = inputs_embeds.size() + return input_ids, inputs_embeds, attention_mask, position_ids, input_shape + + +@auto_docstring( + custom_intro=""" + Reformer Model with a `language modeling` head on top. + """ +) +class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`." + assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, ( + "If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not" + f" {config.local_num_chunks_after}." + ) + assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, ( + "If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not" + f" {config.lsh_num_chunks_after}." + ) + + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + num_hashes: Optional[int] = None, + past_buckets_states: Optional[list[tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, CausalLMOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be + a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices + are automatically padded to be a multiple of the chunk length. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + num_hashes (`int`, *optional*): + The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites + the default defined in `config.num_hashes`. + + For more information, see `num_hashes` in [`ReformerConfig`]. + past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*): + List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element + being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the + second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). + + Contains precomputed hidden-states and buckets (only relevant for LSH Self-Attention). Can be used to speed + up sequential decoding. 
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next-token prediction). Indices should be
+            in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is
+            only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        reformer_outputs = self.reformer(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            num_hashes=num_hashes,
+            past_buckets_states=past_buckets_states,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=return_dict,
+        )
+
+        sequence_output = reformer_outputs[0]
+        logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (logits,) + reformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ReformerModelWithLMHeadOutput(
+            loss=loss,
+            logits=logits,
+            past_buckets_states=reformer_outputs.past_buckets_states,
+            hidden_states=reformer_outputs.hidden_states,
+            attentions=reformer_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
+    ):
+        # Overridden -- different expected inputs/outputs
+
+        # only last token for input_ids if past is defined in kwargs
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        model_inputs = {
+            "input_ids": input_ids,
+            "past_buckets_states": past_key_values,
+            "use_cache": use_cache,
+            "num_hashes": num_hashes,
+        }
+
+        # Attention mask is computed on ReformerModel.forward()
+        kwargs.pop("attention_mask", None)
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reord_past_buckets_states = []
+        for buckets, hidden_states in past_key_values:
+            # buckets
+            if buckets is not None and buckets.numel() > 0:
+                reord_buckets = buckets.index_select(0, beam_idx.to(buckets.device))
+            else:
+                reord_buckets = None
+
+            # hidden states
+            reord_hidden_states = hidden_states.index_select(0, beam_idx.to(hidden_states.device))
+            reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
+
+        if isinstance(past_key_values, ReformerDynamicCache):
+            reord_past_buckets_states = ReformerDynamicCache.from_legacy_cache(reord_past_buckets_states)
+        return reord_past_buckets_states
+
+
+@auto_docstring
+class ReformerForMaskedLM(ReformerPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert not config.is_decoder, (
+            "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional"
+            " self-attention."
+ ) + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + num_hashes: Optional[int] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be + a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices + are automatically padded to be a multiple of the chunk length. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + num_hashes (`int`, *optional*): + The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites + the default defined in `config.num_hashes`. + + For more information, see `num_hashes` in [`ReformerConfig`]. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels + + + + This example uses a false checkpoint since we don't have any available pretrained model for the masked language + modeling task with the Reformer architecture. + + + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, ReformerForMaskedLM + + >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer") + >>> model = ReformerForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-reformer") + + >>> # add mask_token + >>> tokenizer.add_special_tokens({"mask_token": "[MASK]"}) # doctest: +IGNORE_RESULT + >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt") + + >>> # resize model's embedding matrix + >>> model.resize_token_embeddings(new_num_tokens=model.config.vocab_size + 1) # doctest: +IGNORE_RESULT + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # retrieve index of [MASK] + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> predicted_token = tokenizer.decode(predicted_token_id) + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] + >>> # mask labels of non-[MASK] tokens + >>> labels = torch.where( + ... inputs.input_ids == tokenizer.mask_token_id, labels[:, : inputs["input_ids"].shape[-1]], -100 + ... 
) + + >>> outputs = model(**inputs, labels=labels) + >>> loss = round(outputs.loss.item(), 2) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + use_cache=False, # no causal mask + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + sequence_output = reformer_outputs[0] + logits = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + reformer_outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=reformer_outputs.hidden_states, + attentions=reformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Reformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class ReformerForSequenceClassification(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.reformer = ReformerModel(config) + self.classifier = ReformerClassificationHead(config) + if config.is_decoder is True: + logger.warning("You might want to disable causal masking for sequence classification") + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + num_hashes: Optional[int] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be + a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices + are automatically padded to be a multiple of the chunk length. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + num_hashes (`int`, *optional*): + The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites + the default defined in `config.num_hashes`. + + For more information, see `num_hashes` in [`ReformerConfig`]. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Example of single-label classification: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, ReformerForSequenceClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("google/reformer-crime-and-punishment") + >>> model = ReformerForSequenceClassification.from_pretrained("google/reformer-crime-and-punishment") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_class_id = logits.argmax().item() + >>> label = model.config.id2label[predicted_class_id] + ``` + + ```python + >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` + >>> num_labels = len(model.config.id2label) + >>> model = ReformerForSequenceClassification.from_pretrained( + ... "google/reformer-crime-and-punishment", num_labels=num_labels + ... ) + + >>> labels = torch.tensor(1) + >>> loss = model(**inputs, labels=labels).loss + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ReformerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, hidden_states, **kwargs): + hidden_states = hidden_states[:, 0, :] # take token (equiv. 
to [CLS]) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +@auto_docstring +class ReformerForQuestionAnswering(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.reformer = ReformerModel(config) + # 2 * config.hidden_size because we use reversible residual layers + self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + num_hashes: Optional[int] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, QuestionAnsweringModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be + a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices + are automatically padded to be a multiple of the chunk length. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + num_hashes (`int`, *optional*): + The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites + the default defined in `config.num_hashes`. + + For more information, see `num_hashes` in [`ReformerConfig`]. 
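+        Example (an illustrative sketch; the tiny random checkpoint is reused from the masked-LM example earlier in
+        this file, so the QA head is freshly initialized and the extracted span is meaningless):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, ReformerForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> model = ReformerForQuestionAnswering.from_pretrained("hf-internal-testing/tiny-random-reformer")
+
+        >>> question, text = "Who wrote the novel?", "The novel was written by Fyodor Dostoevsky."
+        >>> inputs = tokenizer(question + " " + text, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> start_index = int(outputs.start_logits.argmax(dim=-1))
+        >>> end_index = int(outputs.end_logits.argmax(dim=-1))
+        >>> answer = tokenizer.decode(inputs.input_ids[0, start_index : end_index + 1])
+        ```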
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + use_cache=False, # no causal mask + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + sequence_output = reformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + reformer_outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=reformer_outputs.hidden_states, + attentions=reformer_outputs.attentions, + ) + + +__all__ = [ + "ReformerAttention", + "ReformerForMaskedLM", + "ReformerForQuestionAnswering", + "ReformerForSequenceClassification", + "ReformerLayer", + "ReformerModel", + "ReformerModelWithLMHead", + "ReformerPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer.py b/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer.py new file mode 100644 index 0000000000000000000000000000000000000000..458b72df4ff6e84896efc01669fdea19a50560fc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
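Before the tokenizer file, a brief aside on the mechanism implemented by `ReformerLayer.forward` and `backward_pass` above: the reversible coupling Y_1 = X_1 + f(X_2), Y_2 = X_2 + g(Y_1) can be inverted exactly, so layer inputs are recomputed during the backward pass instead of being cached. The sketch below is illustrative only and not part of this diff; `f` and `g` are stand-ins for the attention and chunked feed-forward blocks.

```python
import torch
from torch import nn

torch.manual_seed(0)
f = nn.Linear(8, 8)  # stand-in for the attention block
g = nn.Linear(8, 8)  # stand-in for the chunked feed-forward block

x1, x2 = torch.randn(2, 8), torch.randn(2, 8)

# forward coupling of the two residual streams
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# backward-time reconstruction: invert the coupling instead of caching x1 / x2
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert torch.allclose(x1, x1_rec, atol=1e-6)
assert torch.allclose(x2, x2_rec, atol=1e-6)
```

With dropout inside `f` or `g`, the recomputed values only match the forward pass if the same RNG state is replayed, which is exactly what the saved `attention_seed` / `feed_forward_seed` and the `torch.manual_seed(...)` calls in `backward_pass` take care of.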
+"""Tokenization class for model Reformer.""" + +import os +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +@requires(backends=("sentencepiece",)) +class ReformerTokenizer(PreTrainedTokenizer): + """ + Construct a Reformer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) . + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + additional_special_tokens (`list[str]`, *optional*, defaults to `[]`): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. 
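+    Example (an illustrative sketch, assuming the `google/reformer-crime-and-punishment` checkpoint used elsewhere
+    in this diff, which ships a compatible SentencePiece model):
+
+    ```python
+    >>> from transformers import ReformerTokenizer
+
+    >>> tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
+    >>> encoding = tokenizer("Hello, my dog is cute")
+    >>> tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
+    >>> text = tokenizer.decode(encoding["input_ids"])
+    ```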
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + additional_special_tokens=[], + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self) -> dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index < self.sp_model.get_piece_size(): + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["ReformerTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer_fast.py b/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..d68528de5872c24985b8cbbda69c9d63c026130d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/reformer/tokenization_reformer_fast.py @@ 
-0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model Reformer.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_reformer import ReformerTokenizer +else: + ReformerTokenizer = None + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +class ReformerTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Reformer tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`list[str]`, *optional*): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = ReformerTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + unk_token="", + additional_special_tokens=[], + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["ReformerTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/resnet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/resnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..625e93a25543701db300db039d736589f2148e6d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/resnet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_resnet import * + from .modeling_flax_resnet import * + from .modeling_resnet import * + from .modeling_tf_resnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/resnet/configuration_resnet.py b/venv/lib/python3.13/site-packages/transformers/models/resnet/configuration_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f4dcc6e0c6bcadf3786035a0141fd2754a70e237 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/resnet/configuration_resnet.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
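The `resnet/__init__.py` added above uses the library's lazy-module pattern; as a hedged illustration (not part of the diff), attribute access on the package is what triggers the real import:

```python
# Illustrative only: the __init__.py above replaces the package module with a _LazyModule proxy,
# so importing the package is cheap and submodules load on first attribute access.
import importlib

resnet_pkg = importlib.import_module("transformers.models.resnet")
print(type(resnet_pkg))               # the _LazyModule proxy installed by __init__.py
config_cls = resnet_pkg.ResNetConfig  # first access resolves configuration_resnet and returns the class
```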
+"""ResNet model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class ResNetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an + ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ResNet + [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"bottleneck"`): + The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or + `"bottleneck"` (used for larger models like resnet-50 and above). + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. + downsample_in_bottleneck (`bool`, *optional*, defaults to `False`): + If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
+ + Example: + ```python + >>> from transformers import ResNetConfig, ResNetModel + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = ResNetConfig() + + >>> # Initializing a model (with random weights) from the resnet-50 style configuration + >>> model = ResNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage + self.downsample_in_bottleneck = downsample_in_bottleneck + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class ResNetOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-3 + + +__all__ = ["ResNetConfig", "ResNetOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_flax_resnet.py b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_flax_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a9418b7cf2957e5f3d52b4978014864fef30ed --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_flax_resnet.py @@ -0,0 +1,704 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
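A hedged sketch (not part of the file above) of how the `ResNetOnnxConfig` defined here exposes its export metadata; the default `ResNetConfig()` is used purely for illustration:

```python
from transformers import ResNetConfig
from transformers.models.resnet.configuration_resnet import ResNetOnnxConfig

config = ResNetConfig()                  # resnet-50 style defaults
onnx_config = ResNetOnnxConfig(config)

print(dict(onnx_config.inputs))          # {'pixel_values': {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'}}
print(onnx_config.atol_for_validation)   # 1e-3, the tolerance used when validating an exported model
```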
+ +from functools import partial +from typing import Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_resnet import ResNetConfig + + +RESNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x, **kwargs): + return x + + +class FlaxResNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + dtype=self.dtype, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype), + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxResNetConvLayer( + self.config.embedding_size, + kernel_size=7, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values, deterministic=deterministic) + embedding = self.max_pool(embedding) + return embedding + + +class FlaxResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxResNetBasicLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layer = [ + FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype), + FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
+ """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + self.layer = FlaxResNetBasicLayerCollection( + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state, deterministic: bool = True): + residual = hidden_state + hidden_state = self.layer(hidden_state, deterministic=deterministic) + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetBottleNeckLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + reduces_channels = self.out_channels // self.reduction + + self.layer = [ + FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"), + FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"), + FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBottleNeckLayer(nn.Module): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the + input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution + remaps the reduced features to `out_channels`. + """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + + self.layer = FlaxResNetBottleNeckLayerCollection( + self.out_channels, + stride=self.stride, + activation=self.activation, + reduction=self.reduction, + dtype=self.dtype, + ) + + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state = self.layer(hidden_state, deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetStageLayersCollection(nn.Module): + """ + A ResNet stage composed by stacked layers. 
+ """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer + + layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.in_channels, + self.out_channels, + stride=self.stride, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + ] + + for i in range(self.depth - 1): + layers.append( + layer( + self.out_channels, + self.out_channels, + activation=self.config.hidden_act, + dtype=self.dtype, + name=str(i + 1), + ) + ) + + self.layers = layers + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetStage(nn.Module): + """ + A ResNet stage composed by stacked layers. + """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxResNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) + + +class FlaxResNetStageCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + stages = [ + FlaxResNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ) + ] + + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +class FlaxResNetEncoder(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class 
FlaxResNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: ResNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: Optional[dict] = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, # Returning tuple with batch_stats only when train is True + ) + + +class FlaxResNetModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype) + + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + 
window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class FlaxResNetModel(FlaxResNetPreTrainedModel): + module_class = FlaxResNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50") + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig +) + + +class FlaxResNetClassifierCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +class FlaxResNetForImageClassificationModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype) + else: + self.classifier = Identity() + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + RESNET_START_DOCSTRING, +) +class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel): + module_class = FlaxResNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig +) + + +__all__ = ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_resnet.py b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..59a509fe03cd538487ee44eea7f59a2b4177407c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_resnet.py @@ -0,0 +1,435 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
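Because `FlaxResNetPreTrainedModel.__call__` marks the BatchNorm `batch_stats` collection as mutable only when `train=True`, a training-mode call returns the updated statistics alongside the model outputs, while an inference call returns a single output object. A minimal sketch of the difference (it assumes `jax`/`flax` are installed and the `microsoft/resnet-50` weights are reachable; the dummy input is arbitrary):

```python
import jax.numpy as jnp
from transformers import FlaxResNetModel

model = FlaxResNetModel.from_pretrained("microsoft/resnet-50")
pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)  # dummy NCHW batch

# Inference: batch_stats are read-only, a single ModelOutput is returned.
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 2048, 7, 7)

# Training mode: batch_stats are mutable, so the mutated collection comes back too.
outputs, mutated_variables = model(pixel_values, train=True)
print("batch_stats" in mutated_variables)  # True
```

Keeping the running statistics out of `params` and threading them back explicitly is the usual Flax pattern for BatchNorm, which is why the call signature differs from the PyTorch model.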
+"""PyTorch ResNet model.""" + +import math +from typing import Optional + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.backbone_utils import BackboneMixin +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + + +class ResNetConvLayer(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: ResNetConfig): + super().__init__() + self.embedder = ResNetConvLayer( + config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act + ) + self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) + embedding = self.pooler(embedding) + return embedding + + +class ResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class ResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
+ """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + self.shortcut = ( + ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + ResNetConvLayer(in_channels, out_channels, stride=stride), + ResNetConvLayer(out_channels, out_channels, activation=None), + ) + self.activation = ACT2FN[activation] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetBottleNeckLayer(nn.Module): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If + `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + activation: str = "relu", + reduction: int = 4, + downsample_in_bottleneck: bool = False, + ): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + self.shortcut = ( + ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + ResNetConvLayer( + in_channels, reduces_channels, kernel_size=1, stride=stride if downsample_in_bottleneck else 1 + ), + ResNetConvLayer(reduces_channels, reduces_channels, stride=stride if not downsample_in_bottleneck else 1), + ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[activation] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetStage(nn.Module): + """ + A ResNet stage composed by stacked layers. 
+ """ + + def __init__( + self, + config: ResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer + + if config.layer_type == "bottleneck": + first_layer = layer( + in_channels, + out_channels, + stride=stride, + activation=config.hidden_act, + downsample_in_bottleneck=config.downsample_in_bottleneck, + ) + else: + first_layer = layer(in_channels, out_channels, stride=stride, activation=config.hidden_act) + self.layers = nn.Sequential( + first_layer, *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)] + ) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class ResNetEncoder(nn.Module): + def __init__(self, config: ResNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages.append( + ResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +@auto_docstring +class ResNetPreTrainedModel(PreTrainedModel): + config: ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. 
+ elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +@auto_docstring +class ResNetModel(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedder = ResNetEmbeddings(config) + self.encoder = ResNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class ResNetForImageClassification(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.resnet = ResNetModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +@auto_docstring( + custom_intro=""" + ResNet backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): + has_attentions = False + + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = ResNetEmbeddings(config) + self.encoder = ResNetEncoder(config) + + # initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = ["ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", "ResNetBackbone"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_tf_resnet.py b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_tf_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c415f97b05ef0d23bbb55811c9ac0aeed968cb --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/resnet/modeling_tf_resnet.py @@ -0,0 +1,596 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. 
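The `labels` argument described above drives the loss computation in `ResNetForImageClassification.forward`: when labels are passed, a cross-entropy loss over `config.num_labels` classes is returned together with the logits. A minimal sketch with a randomly initialized (untrained) model; `num_labels=10` and the dummy tensors are arbitrary illustration values:

```python
import torch
from transformers import ResNetConfig, ResNetForImageClassification

config = ResNetConfig(num_labels=10)          # untrained, randomly initialized weights
model = ResNetForImageClassification(config)
model.eval()

pixel_values = torch.randn(2, 3, 224, 224)    # dummy batch of two images
labels = torch.tensor([1, 7])                 # one class index per image

with torch.no_grad():
    outputs = model(pixel_values, labels=labels)

print(outputs.logits.shape)  # torch.Size([2, 10])
print(outputs.loss)          # scalar cross-entropy loss
```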
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow ResNet model.""" + +from typing import Optional, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPoolingAndNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ResNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" + + +class TFResNetConvLayer(keras.layers.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + activation: str = "relu", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.pad_value = kernel_size // 2 + self.conv = keras.layers.Conv2D( + out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.activation = ACT2FN[activation] if activation is not None else keras.layers.Activation("linear") + self.in_channels = in_channels + self.out_channels = out_channels + + def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: + # Pad to match that done in the PyTorch Conv2D model + height_pad = width_pad = (self.pad_value, self.pad_value) + hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) + hidden_state = self.conv(hidden_state) + return hidden_state + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFResNetEmbeddings(keras.layers.Layer): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. 
+ """ + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.embedder = TFResNetConvLayer( + config.num_channels, + config.embedding_size, + kernel_size=7, + stride=2, + activation=config.hidden_act, + name="embedder", + ) + self.pooler = keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler") + self.num_channels = config.num_channels + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + _, _, _, num_channels = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = pixel_values + hidden_state = self.embedder(hidden_state) + hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]]) + hidden_state = self.pooler(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +class TFResNetShortCut(keras.layers.Layer): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs) -> None: + super().__init__(**kwargs) + self.convolution = keras.layers.Conv2D( + out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = x + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFResNetBasicLayer(keras.layers.Layer): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
+ """ + + def __init__( + self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs + ) -> None: + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + self.conv1 = TFResNetConvLayer(in_channels, out_channels, stride=stride, name="layer.0") + self.conv2 = TFResNetConvLayer(out_channels, out_channels, activation=None, name="layer.1") + self.shortcut = ( + TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else keras.layers.Activation("linear", name="shortcut") + ) + self.activation = ACT2FN[activation] + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = hidden_state + hidden_state = self.conv1(hidden_state, training=training) + hidden_state = self.conv2(hidden_state, training=training) + residual = self.shortcut(residual, training=training) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build(None) + if getattr(self, "conv2", None) is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build(None) + if getattr(self, "shortcut", None) is not None: + with tf.name_scope(self.shortcut.name): + self.shortcut.build(None) + + +class TFResNetBottleNeckLayer(keras.layers.Layer): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + activation: str = "relu", + reduction: int = 4, + **kwargs, + ) -> None: + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + self.conv0 = TFResNetConvLayer(in_channels, reduces_channels, kernel_size=1, name="layer.0") + self.conv1 = TFResNetConvLayer(reduces_channels, reduces_channels, stride=stride, name="layer.1") + self.conv2 = TFResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None, name="layer.2") + self.shortcut = ( + TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else keras.layers.Activation("linear", name="shortcut") + ) + self.activation = ACT2FN[activation] + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = hidden_state + hidden_state = self.conv0(hidden_state, training=training) + hidden_state = self.conv1(hidden_state, training=training) + hidden_state = self.conv2(hidden_state, training=training) + residual = self.shortcut(residual, training=training) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv0", None) is not None: + with tf.name_scope(self.conv0.name): + self.conv0.build(None) + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build(None) + if getattr(self, "conv2", None) is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build(None) + if getattr(self, "shortcut", None) is not None: + with tf.name_scope(self.shortcut.name): + self.shortcut.build(None) + + +class TFResNetStage(keras.layers.Layer): + """ + A ResNet stage composed of stacked layers. 
+ """ + + def __init__( + self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs + ) -> None: + super().__init__(**kwargs) + + layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer + + layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")] + layers += [ + layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}") + for i in range(depth - 1) + ] + self.stage_layers = layers + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + for layer in self.stage_layers: + hidden_state = layer(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "stage_layers", None) is not None: + for layer in self.stage_layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFResNetEncoder(keras.layers.Layer): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages = [ + TFResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + name="stages.0", + ) + ] + for i, (in_channels, out_channels, depth) in enumerate( + zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:]) + ): + self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) + + def call( + self, + hidden_state: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> TFBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state, training=training) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "stages", None) is not None: + for layer in self.stages: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFResNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + + @property + def input_signature(self): + return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)} + + +RESNET_START_DOCSTRING = r""" + This model is a TensorFlow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@keras_serializable +class TFResNetMainLayer(keras.layers.Layer): + config_class = ResNetConfig + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + self.embedder = TFResNetEmbeddings(config, name="embedder") + self.encoder = TFResNetEncoder(config, name="encoder") + self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True) + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TF 2.0 image layers can't use NCHW format when running on CPU. + # We transpose to NHWC format and then transpose back after the full forward pass. + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) + embedding_output = self.embedder(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + # Transpose all the outputs to the NCHW format + # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) + last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2)) + pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2)) + hidden_states = () + for hidden_state in encoder_outputs[1:]: + hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + hidden_states + + hidden_states = hidden_states if output_hidden_states else None + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class TFResNetModel(TFResNetPreTrainedModel): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.resnet = TFResNetMainLayer(config=config, name="resnet") + + 
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + resnet_outputs = self.resnet( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return resnet_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "resnet", None) is not None: + with tf.name_scope(self.resnet.name): + self.resnet.build(None) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + RESNET_START_DOCSTRING, +) +class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.num_labels = config.num_labels + self.resnet = TFResNetMainLayer(config, name="resnet") + # classification head + self.classifier_layer = ( + keras.layers.Dense(config.num_labels, name="classifier.1") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier.1") + ) + self.config = config + + def classifier(self, x: tf.Tensor) -> tf.Tensor: + x = keras.layers.Flatten()(x) + logits = self.classifier_layer(x) + return logits + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "resnet", None) is not None: + with tf.name_scope(self.resnet.name): + self.resnet.build(None) + if getattr(self, "classifier_layer", None) is not None: + with tf.name_scope(self.classifier_layer.name): + self.classifier_layer.build([None, None, self.config.hidden_sizes[-1]]) + + +__all__ = ["TFResNetForImageClassification", "TFResNetModel", "TFResNetPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9418d33d354ab6192d749d06ee32e39f5de670 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_roberta import * + from .modeling_flax_roberta import * + from .modeling_roberta import * + from .modeling_tf_roberta import * + from .tokenization_roberta import * + from .tokenization_roberta_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/configuration_roberta.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/configuration_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..04917804a2252ccc52201fd7f4dbcf6b480f0fd0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/configuration_roberta.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
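As the comment in `TFResNetMainLayer.call` notes, the TensorFlow port keeps the PyTorch-style NCHW input layout and transposes to NHWC internally (and back for the outputs), so callers pass `(batch, channels, height, width)` tensors just as with the PyTorch model. A minimal sketch (it assumes TensorFlow is installed and the `microsoft/resnet-50` weights are reachable; the zero tensor is a placeholder input):

```python
import tensorflow as tf
from transformers import TFResNetModel

model = TFResNetModel.from_pretrained("microsoft/resnet-50")
pixel_values = tf.zeros((1, 3, 224, 224))  # dummy NCHW batch, matching input_signature

outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 2048, 7, 7)
print(outputs.pooler_output.shape)      # (1, 2048, 1, 1)
```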
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RoBERTa configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class RobertaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "roberta" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class RobertaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["RobertaConfig", "RobertaOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_flax_roberta.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_flax_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..3b46c0fa682ff2b6441984d5b285d7e1cad3c2d7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_flax_roberta.py @@ -0,0 +1,1500 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + +remat = nn_partitioning.remat + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = (input_ids != padding_idx).astype("i4") + + if mask.ndim > 2: + mask = mask.reshape((-1, mask.shape[-1])) + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + incremental_indices = incremental_indices.reshape(input_ids.shape) + else: + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + + return incremental_indices.astype("i4") + padding_idx + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. 
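To make the padding-aware position numbering of `create_position_ids_from_input_ids` above concrete, here is a small self-contained sanity check; the token ids are illustrative and `padding_idx=1` matches RoBERTa's default:

```python
import jax.numpy as jnp

# Real tokens are numbered padding_idx + 1, padding_idx + 2, ... while padded
# positions keep the padding index itself.
input_ids = jnp.array([[0, 31414, 232, 2, 1, 1]])  # last two positions are padding
padding_idx = 1

mask = (input_ids != padding_idx).astype("i4")
position_ids = jnp.cumsum(mask, axis=1) * mask + padding_idx
print(position_ids)  # [[2 3 4 5 1 1]]
```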
+ + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
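A short sketch of how these inputs are typically produced; the checkpoint name is the documented `FacebookAI/roberta-base` and the sentences are placeholders:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
batch = tokenizer(
    ["Hello world!", "A second, longer example sentence."],
    padding=True,
    return_tensors="np",  # NumPy arrays work directly with the Flax models
)
# batch["input_ids"] and batch["attention_mask"] both have shape (batch_size, sequence_length)
```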
+""" + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta +class FlaxRobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta +class FlaxRobertaSelfAttention(nn.Module): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. 
This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and 
cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta +class FlaxRobertaSelfOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta +class FlaxRobertaAttention(nn.Module): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += 
(attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta +class FlaxRobertaIntermediate(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta +class FlaxRobertaOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta +class FlaxRobertaLayer(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta +class FlaxRobertaLayerCollection(nn.Module): + config: 
RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta +class FlaxRobertaEncoder(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxRobertaLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta 
+class FlaxRobertaPooler(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +class FlaxRobertaLMHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN["gelu"](hidden_states) + hidden_states = self.layer_norm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +class FlaxRobertaClassificationHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = nn.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
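`FlaxRobertaLMHead` above can reuse the word-embedding matrix as its output projection instead of the decoder's own kernel. A standalone sketch of that Flax pattern, with illustrative shapes:

```python
import jax.numpy as jnp
import flax.linen as nn

vocab_size, hidden_size = 50265, 768
shared_embedding = jnp.ones((vocab_size, hidden_size))  # stand-in for the embedding table
hidden_states = jnp.ones((2, 5, hidden_size))           # (batch, seq_len, hidden)

decoder = nn.Dense(vocab_size, use_bias=False)
# Supply the transposed embedding matrix as the kernel at apply time, which is
# what the LM head does when `config.tie_word_embeddings` is set.
logits = decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
print(logits.shape)  # (2, 5, 50265)
```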
+ """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + + module_class: nn.Module = None + + def __init__( + self, + config: RobertaConfig, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
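A rough sketch of preparing the cache by hand; in practice `generate()` calls `prepare_inputs_for_generation`, which does this internally, and the checkpoint name and sizes below are only placeholders:

```python
from transformers import FlaxRobertaForCausalLM

# `is_decoder=True` enables the causal attention path; expect a warning about
# freshly initialized LM-head weights when starting from an encoder checkpoint.
model = FlaxRobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True)
past_key_values = model.init_cache(batch_size=1, max_length=32)
# `past_key_values` now holds zero-filled key/value buffers of length 32 that the
# attention layers update step by step during decoding.
```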
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: Optional[dict] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxRobertaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta +class FlaxRobertaModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if 
self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaModel(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaModule + + +append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) + + +class FlaxRobertaForMaskedLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) +class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForMaskedLMModule + + +append_call_sample_docstring( + FlaxRobertaForMaskedLM, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutputWithPooling, + _CONFIG_FOR_DOC, + mask="", +) + + +class FlaxRobertaForSequenceClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + 
logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxRobertaForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForMultipleChoiceModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxRobertaForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForTokenClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForTokenClassificationModule + + +append_call_sample_docstring( + FlaxRobertaForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForQuestionAnsweringModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
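In the question-answering module above, a single `Dense(num_labels)` projection produces two logits per token that are split into start and end scores; picking a span is then roughly:

```python
import jax.numpy as jnp

# start_logits / end_logits as returned by the model, shape (batch, seq_len);
# the values here are illustrative.
start_logits = jnp.array([[0.1, 2.5, 0.3, 0.0]])
end_logits = jnp.array([[0.0, 0.2, 3.1, 0.4]])

start = int(jnp.argmax(start_logits, axis=-1)[0])  # 1
end = int(jnp.argmax(end_logits, axis=-1)[0])      # 2
# answer = tokens[start : end + 1], mapped back to text via the tokenizer's offsets
```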
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxRobertaForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRobertaForCausalLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. 
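# (A fixed-size mask also keeps the traced shapes constant across decoding steps, so XLA does not recompile as the generated sequence grows.)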
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxRobertaForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxRobertaForCausalLM", + "FlaxRobertaForMaskedLM", + "FlaxRobertaForMultipleChoice", + "FlaxRobertaForQuestionAnswering", + "FlaxRobertaForSequenceClassification", + "FlaxRobertaForTokenClassification", + "FlaxRobertaModel", + "FlaxRobertaPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_roberta.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..33fb44118a904cbce0e69ca5071180fc0435bb92 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_roberta.py @@ -0,0 +1,1562 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RoBERTa model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + + +class RobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
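+                # (RoBERTa checkpoints typically use pad_token_id = 1, so real tokens are numbered from
+                # padding_idx + 1 = 2, while padded positions keep position padding_idx.)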
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta +class RobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
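+        # query_layer: (batch, num_heads, query_len, head_dim); key_layer: (batch, num_heads, key_len, head_dim),
+        # so the matmul below produces attention_scores of shape (batch, num_heads, query_len, key_len).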
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) + self.dropout_prob = config.attention_probs_dropout_prob + + # Adapted from RobertaSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. 
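+            # torch.nn.functional.scaled_dot_product_attention fuses softmax(QK^T / sqrt(d)) @ V, so it cannot
+            # return per-head attention probabilities or apply a per-head mask; those cases fall back to the
+            # eager implementation below.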
+ logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + past_key_values, + output_attentions, + cache_position, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + return attn_output, None + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class RobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA +class RobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class RobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 
config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class RobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta +class RobertaLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RobertaAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) 
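+        # apply_chunking_to_forward splits attention_output along the sequence dimension into chunks of
+        # config.chunk_size_feed_forward and runs feed_forward_chunk on each chunk to reduce peak memory;
+        # with a chunk size of 0 it is equivalent to calling feed_forward_chunk directly.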
+ outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta +class RobertaEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class RobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class RobertaPreTrainedModel(PreTrainedModel): + config: RobertaConfig + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"] + _supports_sdpa = True + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, RobertaLMHead): + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
+ + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ +) +# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA +class RobertaModel(RobertaPreTrainedModel): + _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = RobertaEmbeddings(config) + self.encoder = RobertaEncoder(config) + + self.pooler = RobertaPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + 
else past_key_values.get_seq_length() + ) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base") + >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base") + >>> config.is_decoder = True + >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class RobertaForMaskedLM(RobertaPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaLMHead(nn.Module): + """Roberta Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = 
self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@auto_docstring( + custom_intro=""" + RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """ +) +class RobertaForSequenceClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class RobertaForMultipleChoice(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roberta = RobertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. 
+ + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(reshaped_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class RobertaForTokenClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + 
position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@auto_docstring +class RobertaForQuestionAnswering(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding 
symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "RobertaForCausalLM", + "RobertaForMaskedLM", + "RobertaForMultipleChoice", + "RobertaForQuestionAnswering", + "RobertaForSequenceClassification", + "RobertaForTokenClassification", + "RobertaModel", + "RobertaPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_tf_roberta.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_tf_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c56b85d5f37433c627572e495a12a034022ab0 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/modeling_tf_roberta.py @@ -0,0 +1,1782 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 RoBERTa model.""" + +from __future__ import annotations + +import math +import warnings + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + + +class TFRobertaEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
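+    The tweak: position ids are derived from the input ids so that real tokens are numbered from
+    `padding_idx + 1` upwards, while padding tokens keep `padding_idx` as their position id.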
+ """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+ position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Roberta +class TFRobertaPooler(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Roberta +class TFRobertaSelfAttention(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + 
encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFRobertaModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
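+        # attention_probs = softmax(Q @ K^T / sqrt(attention_head_size) + additive attention_mask), computed
+        # above; the dropout below is applied to these normalized probabilities, not to the raw scores.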
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Roberta +class TFRobertaSelfOutput(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta +class TFRobertaAttention(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFRobertaSelfAttention(config, name="self") + self.dense_output = TFRobertaSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], 
input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Roberta +class TFRobertaIntermediate(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Roberta +class TFRobertaOutput(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta +class TFRobertaLayer(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFRobertaAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRobertaAttention(config, name="crossattention") + self.intermediate = TFRobertaIntermediate(config, name="intermediate") + self.bert_output = TFRobertaOutput(config, name="output") + + def 
call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta +class TFRobertaEncoder(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + self.config 
= config + self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: tuple[tuple[tf.Tensor]] | None, + use_cache: bool | None, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFRobertaMainLayer(keras.layers.Layer): + config_class = RobertaConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFRobertaEncoder(config, name="encoder") + self.pooler = TFRobertaPooler(config, name="pooler") if add_pooling_layer else None + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFRobertaEmbeddings(config, name="embeddings") + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + 
self.embeddings.vocab_size = shape_list(value)[0] + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
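+        # The code below turns the 2D padding mask into an additive mask (0.0 for attended positions,
+        # -10000.0 for masked ones) and, for decoder configurations, combines it with a causal mask.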
+ attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + + +class TFRobertaPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. 
Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class TFRobertaModel(TFRobertaPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.roberta = TFRobertaMainLayer(config, name="roberta") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + + +class TFRobertaLMHead(keras.layers.Layer): + """Roberta Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) +class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + expected_output="' Paris'", + expected_loss=0.1, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config: RobertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + logits = self.lm_head(hidden_states=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +class TFRobertaClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.out_proj = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" + ) + self.config = config + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.classifier = TFRobertaClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="cardiffnlp/twitter-roberta-base-emotion", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'optimism'", + expected_loss=0.08, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFRobertaMainLayer(config, name="roberta") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + outputs = self.roberta( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="ydshieh/roberta-large-ner-english", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.qa_outputs = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="ydshieh/roberta-base-squad2", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="' puppet'", + expected_loss=0.86, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
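+
+        A minimal usage sketch (illustrative only; "ydshieh/roberta-base-squad2" is the checkpoint referenced
+        by the `add_code_sample_docstrings` decorator above, and the question/context strings are made up):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFRobertaForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("ydshieh/roberta-base-squad2")
+        >>> model = TFRobertaForQuestionAnswering.from_pretrained("ydshieh/roberta-base-squad2")
+        >>> inputs = tokenizer("Who released RoBERTa?", "RoBERTa was released by Facebook AI.", return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
+        >>> end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
+        >>> answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
+        ```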
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFRobertaForCausalLM", + "TFRobertaForMaskedLM", + "TFRobertaForMultipleChoice", + "TFRobertaForQuestionAnswering", + "TFRobertaForSequenceClassification", + "TFRobertaForTokenClassification", + "TFRobertaMainLayer", + "TFRobertaModel", + "TFRobertaPreTrainedModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..67cdcbbf488a53abcd51db017de963cea48ea647 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta.py @@ -0,0 +1,402 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RoBERTa.""" + +import json +import os +from functools import lru_cache +from typing import Optional + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +@lru_cache +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. 
This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class RobertaTokenizer(PreTrainedTokenizer): + """ + Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import RobertaTokenizer + + >>> tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). 
It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (RoBERTa tokenizer detect beginning of words by the preceding space). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + # these special tokens are not part of the vocab.json, let's add them in the correct order + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + vocab = dict(self.encoder).copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + 
return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A RoBERTa sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
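+
+        Example (illustrative; the expected output follows directly from the implementation below):
+
+        ```python
+        >>> from transformers import RobertaTokenizer
+
+        >>> tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
+        >>> tokenizer.create_token_type_ids_from_sequences([31414], [232])
+        [0, 0, 0, 0, 0, 0]
+        ```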
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + +__all__ = ["RobertaTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta_fast.py b/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ddcfc82d49e43e3fbe7c32313e67f41ed1be70 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/roberta/tokenization_roberta_fast.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for RoBERTa.""" + +import json +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_roberta import RobertaTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class RobertaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" RoBERTa tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 + tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import RobertaTokenizerFast + + >>> tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. 
+ errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (RoBERTa tokenizer detect beginning of words by the preceding space). + trim_offsets (`bool`, *optional*, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. 
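+
+    A short illustrative snippet for the `add_prefix_space` / `is_split_into_words` interaction described above
+    (the shown ids are assumed to mirror the " Hello world" encoding earlier in this docstring):
+
+    ```python
+    >>> from transformers import RobertaTokenizerFast
+
+    >>> tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
+    >>> tokenizer(["Hello", "world"], is_split_into_words=True)["input_ids"]
+    [0, 20920, 232, 2]
+    ```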
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = RobertaTokenizer + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + trim_offsets=True, + **kwargs, + ): + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + **kwargs, + ) + + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + @property + def mask_token(self) -> str: + """ + `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not + having been set. + + Roberta tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily + comprise the space before the **. + """ + if self._mask_token is None: + if self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on Roberta. + """ + # Mask token behave like a normal word, i.e. include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._batch_encode_plus(*args, **kwargs) + + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." 
+ ) + + return super()._encode_plus(*args, **kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["RobertaTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efc3772ca6d015271102638c3948958777498fd9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_rt_detr_v2 import * + from .modeling_rt_detr_v2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4a53483b1f7f0f10a2cf208df241165d2e9411 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py @@ -0,0 +1,387 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class RTDetrV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR architecture. + + e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. 
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to compute the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. 
+ label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes + based on the predictions from the previous layer. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + decoder_n_levels (`int`, *optional*, defaults to 3): + The number of feature levels used by the decoder. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Scaling factor applied to the attention offsets in the decoder. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. 
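+
+            For instance (an illustrative snippet, not part of the upstream docstring), the discrete
+            sampling variant can be selected directly on the config:
+
+            ```python
+            >>> from transformers import RTDetrV2Config
+
+            >>> config = RTDetrV2Config(decoder_method="discrete")
+            >>> config.decoder_method
+            'discrete'
+            ```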
+ + Examples: + + ```python + >>> from transformers import RTDetrV2Config, RTDetrV2Model + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrV2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_v2" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrV2Transformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + decoder_method="default", + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone." 
+ ) + backbone_model_type = "rt_detr_resnet" + config_class = CONFIG_MAPPING[backbone_model_type] + # this will map it to RTDetrResNetConfig + # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1 + # and we would need to create RTDetrV2ResNetModel + backbone_config = config_class( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + + if not hasattr(self, "d_model"): + self.d_model = d_model + + if not hasattr(self, "encoder_attention_heads"): + self.encoder_attention_heads = encoder_attention_heads + # add the new attributes with the given values or defaults + self.decoder_n_levels = decoder_n_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + + @property + 
def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RTDetrV2Config`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) + + +__all__ = ["RTDetrV2Config"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..548a9378a2c0ca5d1a488cba55c356e1ff33f67f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -0,0 +1,1998 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
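+
+# Illustrative note (not part of the upstream file): the companion `RTDetrV2Config` defined in
+# configuration_rt_detr_v2.py above can also be built from an existing backbone configuration via
+# `RTDetrV2Config.from_backbone_configs`, e.g. (hypothetical values):
+#
+#     from transformers import RTDetrResNetConfig, RTDetrV2Config
+#
+#     backbone_config = RTDetrResNetConfig(out_indices=[2, 3, 4])
+#     config = RTDetrV2Config.from_backbone_configs(backbone_config=backbone_config)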
+import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, torch_int +from ...utils.backbone_utils import load_backbone +from .configuration_rt_detr_v2 import RTDetrV2Config + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights 
= attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# the main change +class RTDetrV2MultiscaleDeformableAttention(nn.Module): + """ + RTDetrV2 version of multiscale deformable attention, extending the base implementation + with improved offset handling and initialization. + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + num_heads = config.decoder_attention_heads + n_points = config.decoder_n_points + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + + # V2-specific attributes + self.n_levels = config.decoder_n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.offset_scale = config.decoder_offset_scale + self.method = config.decoder_method + + # Initialize n_points list and scale + n_points_list = [self.n_points for _ in range(self.n_levels)] + self.n_points_list = n_points_list + n_points_scale = [1 / n for n in n_points_list for _ in range(n)] + self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # Process inputs up to sampling locations calculation using parent class logic + if position_embeddings is not None: + hidden_states = hidden_states + position_embeddings + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + + # V2-specific sampling offsets shape + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2 + ) + + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, 
self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1) + + # V2-specific sampling locations calculation + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + # V2-specific attention implementation choice + output = multi_scale_deformable_attention_v2( + value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method + ) + + output = self.output_proj(output) + return output, attention_weights + + +class RTDetrV2MultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + # expand attention_mask + if attention_mask is not None: + # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size()) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( + attention_mask, -torch.inf + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class RTDetrV2DecoderLayer(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + # self-attention + self.self_attn = RTDetrV2MultiheadAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # override only the encoder attention module with v2 version + self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config) + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # feedforward neural networks + self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model) + self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class RTDetrV2PreTrainedModel(PreTrainedModel): + config: RTDetrV2Config + base_model_prefix = "rt_detr_v2" + main_input_name = "pixel_values" + _no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(layer.weight) + nn.init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + nn.init.constant_(layer.layers[-1].weight, 0) + nn.init.constant_(layer.layers[-1].bias, 0) + + elif isinstance(module, RTDetrV2MultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + 
nn.init.constant_(module.output_proj.bias.data, 0.0) + + elif isinstance(module, RTDetrV2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(module.enc_score_head.weight) + nn.init.constant_(module.enc_score_head.bias, bias) + + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + nn.init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + nn.init.xavier_uniform_(module.denoising_class_embed.weight) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RTDetrV2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class RTDetrV2DecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_logits: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + intermediate_predicted_corners: Optional[torch.FloatTensor] = None + initial_reference_points: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class RTDetrV2Decoder(RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.query_pos_head = RTDetrV2MLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer. + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): + Ratio of valid area in each feature level. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + reference_points = F.sigmoid(reference_points) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L252 + for idx, decoder_layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze(2) + position_embeddings = self.query_pos_head(reference_points) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + predicted_corners = self.bbox_embed[idx](hidden_states) + new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points)) + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += ( + (new_reference_points,) if self.bbox_embed is not None else (reference_points,) + ) + + if self.class_embed is not None: + logits = self.class_embed[idx](hidden_states) + intermediate_logits += (logits,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + if self.class_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return RTDetrV2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder 
model. + """ +) +class RTDetrV2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_logits: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + intermediate_predicted_corners: Optional[torch.FloatTensor] = None + initial_reference_points: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + init_reference_points: Optional[torch.FloatTensor] = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[dict] = None + + +class RTDetrV2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrV2FrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = RTDetrV2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class RTDetrV2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_rt_detr_v2_resnet.py. + + nn.BatchNorm2d layers are replaced by RTDetrV2FrozenBatchNorm2d as defined above. 
+ https://github.com/lyuwenyu/RT-DETR/blob/main/RTDetrV2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class RTDetrV2ConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrV2EncoderLayer(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = RTDetrV2MultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class RTDetrV2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + + activation = config.activation_function + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class RTDetrV2CSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. 
+ """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.activation_function + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[RTDetrV2RepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = RTDetrV2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1) + hidden_state_2 = self.conv2(hidden_state) + return self.conv3(hidden_state_1 + hidden_state_2) + + +class RTDetrV2Encoder(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + + self.layers = nn.ModuleList([RTDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor: + hidden_states = src + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + return hidden_states + + +class RTDetrV2HybridEncoder(nn.Module): + """ + Decoder consisting of a projection layer, a set of `RTDetrV2Encoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069 + + Args: + config: RTDetrV2Config + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.num_fpn_stages = len(self.in_channels) - 1 + self.num_pan_stages = len(self.in_channels) - 1 + activation = config.activation_function + + # encoder transformer + self.encoder = nn.ModuleList([RTDetrV2Encoder(config) for _ in range(len(self.encode_proj_layers))]) + + # top-down FPN + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(self.num_fpn_stages): + lateral_conv = RTDetrV2ConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=1, + stride=1, + activation=activation, + ) + fpn_block = RTDetrV2CSPRepLayer(config) + self.lateral_convs.append(lateral_conv) + self.fpn_blocks.append(fpn_block) + + # bottom-up PAN + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(self.num_pan_stages): + downsample_conv = RTDetrV2ConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=3, + stride=2, + activation=activation, + ) + pan_block = RTDetrV2CSPRepLayer(config) + self.downsample_convs.append(downsample_conv) + self.pan_blocks.append(pan_block) + + 
@staticmethod + def build_2d_sincos_position_embedding( + width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ): + grid_w = torch.arange(torch_int(width), device=device).to(dtype) + grid_h = torch.arange(torch_int(height), device=device).to(dtype) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # top-down FPN + fpn_feature_maps = [hidden_states[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + # apply lateral block + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + # apply fpn block + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps = fpn_feature_maps[::-1] + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + if not return_dict: + return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions + ) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`list[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. 
+ num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. + """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. 
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries]) + # positive and negative mask + negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = torch.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +@auto_docstring( + custom_intro=""" + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. 
+ """ +) +class RTDetrV2Model(RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + + # Create backbone + self.backbone = RTDetrV2ConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + + # Create encoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/hybrid_encoder.py#L212 + num_backbone_outs = len(intermediate_channel_sizes) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + + # Create encoder + self.encoder = RTDetrV2HybridEncoder(config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = RTDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + # Create decoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412 + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list) + # decoder + self.decoder = RTDetrV2Decoder(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + 
grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], RTDetrV2ModelOutput]: + r""" + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RTDetrV2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd") + >>> model = RTDetrV2Model.from_pretrained("PekingU/RTDetrV2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if output_hidden_states else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 + else encoder_outputs[1] + if output_attentions + else None, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs[0]): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + 
targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return RTDetrV2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + 
initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class RTDetrV2MLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`RTDetrV2ForObjectDetection`]. + """ +) +class RTDetrV2ObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~RTDetrV2ImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). 
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. 
+ denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_logits: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + intermediate_predicted_corners: Optional[torch.FloatTensor] = None + initial_reference_points: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor]] = None + init_reference_points: Optional[tuple[torch.FloatTensor]] = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[dict] = None + + +@auto_docstring( + custom_intro=""" + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further + decoded into scores and classes. + """ +) +class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + # RTDETR encoder-decoder model + self.model = RTDetrV2Model(config) + + # Detection heads on top + class_embed = partial(nn.Linear, config.d_model, config.num_labels) + bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)]) + self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)]) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
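+ # builds one {"logits", "pred_boxes"} dict per intermediate decoder layer, consumed by the auxiliary losses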
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]: + r""" + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import RTDetrV2ImageProcessor, RTDetrV2ForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = RTDetrV2ImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd") + >>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/RTDetrV2_r50vd") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21] + Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5] + Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22] + Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48] + Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = ( + outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + ) + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4] + initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5] + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + predicted_corners=predicted_corners, + initial_reference_points=initial_reference_points, + **kwargs, + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return RTDetrV2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + intermediate_predicted_corners=outputs.intermediate_predicted_corners, + initial_reference_points=outputs.initial_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + 
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + +__all__ = ["RTDetrV2Model", "RTDetrV2PreTrainedModel", "RTDetrV2ForObjectDetection"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..af28015d146265529d24a56db61181526ad98f75 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py @@ -0,0 +1,636 @@ +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...configuration_utils import PretrainedConfig +from ...utils import is_torchdynamo_compiling, logging +from ...utils.backbone_utils import ( + verify_backbone_config_arguments, +) +from ..auto import CONFIG_MAPPING +from ..rt_detr.modeling_rt_detr import ( + RTDetrDecoder, + RTDetrDecoderLayer, + RTDetrForObjectDetection, + RTDetrMLPPredictionHead, + RTDetrModel, + RTDetrPreTrainedModel, +) + + +logger = logging.get_logger(__name__) + + +class RTDetrV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR architecture. + + e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. 
If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to compute the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. 
+ decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes + based on the predictions from the previous layer. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + decoder_n_levels (`int`, *optional*, defaults to 3): + The number of feature levels used by the decoder. 
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Scaling factor applied to the attention offsets in the decoder. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + + Examples: + + ```python + >>> from transformers import RTDetrV2Config, RTDetrV2Model + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrV2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_v2" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrV2Transformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + decoder_method="default", + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone." 
+ ) + backbone_model_type = "rt_detr_resnet" + config_class = CONFIG_MAPPING[backbone_model_type] + # this will map it to RTDetrResNetConfig + # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1 + # and we would need to create RTDetrV2ResNetModel + backbone_config = config_class( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + + if not hasattr(self, "d_model"): + self.d_model = d_model + + if not hasattr(self, "encoder_attention_heads"): + self.encoder_attention_heads = encoder_attention_heads + # add the new attributes with the given values or defaults + self.decoder_n_levels = decoder_n_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + + @property + 
def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RTDetrV2Config`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, 
num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# the main change +class RTDetrV2MultiscaleDeformableAttention(nn.Module): + """ + RTDetrV2 version of multiscale deformable attention, extending the base implementation + with improved offset handling and initialization. + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + num_heads = config.decoder_attention_heads + n_points = config.decoder_n_points + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + + # V2-specific attributes + self.n_levels = config.decoder_n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.offset_scale = config.decoder_offset_scale + self.method = config.decoder_method + + # Initialize n_points list and scale + n_points_list = [self.n_points for _ in range(self.n_levels)] + self.n_points_list = n_points_list + n_points_scale = [1 / n for n in n_points_list for _ in range(n)] + self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # Process inputs up to sampling locations calculation using parent class logic + if position_embeddings is not None: + hidden_states = hidden_states + position_embeddings + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + + # V2-specific sampling offsets shape + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2 + ) + + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = 
F.softmax(attention_weights, -1) + + # V2-specific sampling locations calculation + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + # V2-specific attention implementation choice + output = multi_scale_deformable_attention_v2( + value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method + ) + + output = self.output_proj(output) + return output, attention_weights + + +class RTDetrV2DecoderLayer(RTDetrDecoderLayer): + def __init__(self, config: RTDetrV2Config): + # initialize parent class + super().__init__(config) + # override only the encoder attention module with v2 version + self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config) + + +class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel): + pass + + +class RTDetrV2Decoder(RTDetrDecoder): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + + +class RTDetrV2Model(RTDetrModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + # decoder + self.decoder = RTDetrV2Decoder(config) + + +class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead): + pass + + +class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + RTDetrV2PreTrainedModel.__init__(self, config) + # RTDETR encoder-decoder model + self.model = RTDetrV2Model(config) + + # Detection heads on top + class_embed = partial(nn.Linear, config.d_model, config.num_labels) + bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)]) + self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)]) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + +__all__ = [ + "RTDetrV2Config", + "RTDetrV2Model", + "RTDetrV2PreTrainedModel", + "RTDetrV2ForObjectDetection", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8a2b98e636b7653747861461b7d648b91164c7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sam import * + from .image_processing_sam import * + from .image_processing_sam_fast import * + from .modeling_sam import * + from .modeling_tf_sam import * + from .processing_sam import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/configuration_sam.py b/venv/lib/python3.13/site-packages/transformers/models/sam/configuration_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..11a3e421d42eab2d98b67e6b0be00dbcc0bbc469 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam/configuration_sam.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAM model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SamPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield + a similar configuration to that of the SAM-vit-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. 
+ """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.image_embedding_size = image_size // patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + + +class SamMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM + mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults + will yield a similar configuration to that of the SAM-vit-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function used inside the `SamMaskDecoder` module. + mlp_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsampling rate of the attention layer. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3. + iou_head_depth (`int`, *optional*, defaults to 3): + The number of layers in the IoU head module. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The dimensionality of the hidden states in the IoU head module. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + + """ + + base_config_key = "mask_decoder_config" + + def __init__( + self, + hidden_size=256, + hidden_act="relu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + +class SamVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM + vision encoder according to the specified arguments, defining the model architecture. 
Instantiating a configuration + defaults will yield a similar configuration to that of the SAM ViT-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + num_pos_feats (`int`, *optional*, defaults to 128): + The dimensionality of the position embedding. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + + Example: + + ```python + >>> from transformers import ( + ... SamVisionConfig, + ... SamVisionModel, + ... 
) + + >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration + >>> configuration = SamVisionConfig() + + >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration + >>> model = SamVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + base_config_key = "vision_config" + model_type = "sam_vision_model" + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=128, + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim + + +class SamConfig(PretrainedConfig): + r""" + [`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a + SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `SamVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamVisionConfig`]. + prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`]. + + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... SamVisionConfig, + ... SamPromptEncoderConfig, + ... SamMaskDecoderConfig, + ... SamModel, + ... 
)
+
+    >>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamConfig()
+
+    >>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
+
+    >>> # Initializing SAM vision, prompt encoder and mask decoder configurations
+    >>> vision_config = SamVisionConfig()
+    >>> prompt_encoder_config = SamPromptEncoderConfig()
+    >>> mask_decoder_config = SamMaskDecoderConfig()
+
+    >>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+    ```"""
+
+    model_type = "sam"
+    sub_configs = {
+        "prompt_encoder_config": SamPromptEncoderConfig,
+        "mask_decoder_config": SamMaskDecoderConfig,
+        "vision_config": SamVisionConfig,
+    }
+
+    def __init__(
+        self,
+        vision_config=None,
+        prompt_encoder_config=None,
+        mask_decoder_config=None,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        vision_config = vision_config if vision_config is not None else {}
+        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+        mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+        if isinstance(vision_config, SamVisionConfig):
+            vision_config = vision_config.to_dict()
+        if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
+            prompt_encoder_config = prompt_encoder_config.to_dict()
+        if isinstance(mask_decoder_config, SamMaskDecoderConfig):
+            mask_decoder_config = mask_decoder_config.to_dict()
+
+        self.vision_config = SamVisionConfig(**vision_config)
+        self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
+        self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
+        self.initializer_range = initializer_range
+
+
+__all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam.py b/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acb775b43db458765892b48fe8fa372513aaa81
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam.py
@@ -0,0 +1,1497 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SAM.""" + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_tf_available, + is_torch_available, + is_torchvision_available, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + import torch.nn.functional as F + +if is_torchvision_available(): + from torchvision.ops.boxes import batched_nms + +if is_tf_available(): + import tensorflow as tf + from tensorflow.experimental import numpy as tnp + + from ...tf_utils import flatten, shape_list + +logger = logging.get_logger(__name__) + + +class SamImageProcessor(BaseImageProcessor): + r""" + Constructs a SAM image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`): + Size of the output image after resizing. Resizes the longest edge of the image to match + `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the + `preprocess` method. + mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`): + Size of the output segmentation map after resizing. Resizes the longest edge of the image to match + `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter + in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. 
This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the + `preprocess` method. + pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`): + Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess` + method. + mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`): + Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in + the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + mask_size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: bool = True, + pad_size: Optional[int] = None, + mask_pad_size: Optional[int] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 1024} + size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size + + pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024} + pad_size = get_size_dict(pad_size, default_to_square=True) + + mask_size = mask_size if mask_size is not None else {"longest_edge": 256} + mask_size = ( + get_size_dict(max_size=mask_size, default_to_square=False) + if not isinstance(mask_size, dict) + else mask_size + ) + + mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256} + mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.mask_size = mask_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + self.mask_pad_size = mask_pad_size + self.do_convert_rgb = do_convert_rgb + + def pad_image( + self, + image: np.ndarray, + pad_size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom. + + Args: + image (`np.ndarray`): + Image to pad. + pad_size (`dict[str, int]`): + Size of the output image after padding. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the + `data_format` of the `image` will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. 
If not provided, it will be inferred. + """ + output_height, output_width = pad_size["height"], pad_size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + pad_width = output_width - input_width + pad_height = output_height - input_height + + padded_image = pad( + image, + ((0, pad_height), (0, pad_width)), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return padded_image + + def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int): + """ + Compute the output size given input size and target long side length. + """ + oldh, oldw = old_shape + scale = longest_edge * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + newh = int(newh + 0.5) + neww = int(neww + 0.5) + return (newh, neww) + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest + edge of the image will be resized to the specified size, while the other edge will be resized to + maintain the aspect ratio. + resample: + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "longest_edge" not in size: + raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. 
Got {size.keys()}") + input_size = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"]) + return resize( + image, + size=(output_height, output_width), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool, + do_rescale: bool, + do_normalize: bool, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + rescale_factor: Optional[float] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + reshaped_input_size = get_image_size(image, channel_dim=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + if do_pad: + image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format) + + return image, reshaped_input_size + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]: + # PIL RGBA images are converted to RGB + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + original_size = get_image_size(image, channel_dim=input_data_format) + + image, reshaped_input_size = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + input_data_format=input_data_format, + ) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image, original_size, reshaped_input_size + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: Optional[bool] = None, + mask_size: Optional[dict[str, int]] = None, + do_pad: Optional[bool] = None, + mask_pad_size: Optional[dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + segmentation_map = to_numpy_array(segmentation_map) + + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + + original_size = get_image_size(segmentation_map, channel_dim=input_data_format) + + segmentation_map, _ = self._preprocess( + image=segmentation_map, + do_resize=do_resize, + size=mask_size, + resample=PILImageResampling.NEAREST, + do_rescale=False, + do_normalize=False, + do_pad=do_pad, + pad_size=mask_pad_size, + input_data_format=input_data_format, + ) + + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + segmentation_map = segmentation_map.astype(np.int64) + + return segmentation_map, original_size + + def __call__(self, images, segmentation_maps=None, **kwargs): + # Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both + # be passed in as positional arguments. + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + mask_size: Optional[dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[dict[str, int]] = None, + mask_pad_size: Optional[dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. 
+ segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The longest edge of the image is resized to + `size["longest_edge"]` whilst preserving the aspect ratio. + mask_size (`dict[str, int]`, *optional*, defaults to `self.mask_size`): + Controls the size of the segmentation map after `resize`. The longest edge of the image is resized to + `size["longest_edge"]` whilst preserving the aspect ratio. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values by rescaling factor. + rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to apply to the image pixel values. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. + pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`): + Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and + `pad_size["width"]` if `do_pad` is set to `True`. + mask_pad_size (`dict[str, int]`, *optional*, defaults to `self.mask_pad_size`): + Controls the size of the padding applied to the segmentation map. The image is padded to + `mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
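+
+        Example (a minimal sketch; the URL points to a commonly used sample image and is only illustrative):
+
+        ```python
+        >>> import requests
+        >>> from PIL import Image
+        >>> from transformers import SamImageProcessor
+
+        >>> image_processor = SamImageProcessor()
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> raw_image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = image_processor(raw_image, return_tensors="pt")
+        >>> list(inputs.keys())
+        ['pixel_values', 'original_sizes', 'reshaped_input_sizes']
+        ```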
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size + mask_size = mask_size if mask_size is not None else self.mask_size + mask_size = ( + get_size_dict(max_size=mask_size, default_to_square=False) + if not isinstance(mask_size, dict) + else mask_size + ) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + pad_size = pad_size if pad_size is not None else self.pad_size + pad_size = get_size_dict(pad_size, default_to_square=True) + mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size + mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True) + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if segmentation_maps is not None: + segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) + + if not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images, original_sizes, reshaped_input_sizes = zip( + *( + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ) + ) + + data = { + "pixel_values": images, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + segmentation_maps, original_mask_sizes = zip( + *( + self._preprocess_mask( + segmentation_map=mask, + do_resize=do_resize, + mask_size=mask_size, + do_pad=do_pad, + mask_pad_size=mask_pad_size, + input_data_format=input_data_format, + ) + for mask in segmentation_maps + ) + ) + + # masks should start out the same size as input images + assert all( + original_im_size == original_mask_size + for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes) + ), "Segmentation maps should be the same size as input images." 
+ + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_masks( + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + return_tensors="pt", + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[list[torch.Tensor], list[np.ndarray], list[tf.Tensor]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + return_tensors (`str`, *optional*, defaults to `"pt"`): + If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors. + Returns: + (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where + (height, width) is given by original_size. + """ + if return_tensors == "pt": + return self._post_process_masks_pt( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + elif return_tensors == "tf": + return self._post_process_masks_tf( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'") + + def _post_process_masks_pt( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[list[torch.Tensor], list[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
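+
+        Example (an illustrative sketch with hypothetical shapes; `image_processor` stands for a default
+        `SamImageProcessor` instance):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SamImageProcessor
+
+        >>> image_processor = SamImageProcessor()
+        >>> low_res_masks = [torch.randn(1, 3, 256, 256)]  # low-resolution masks from the mask decoder
+        >>> original_sizes = [(480, 640)]  # image size before resizing
+        >>> reshaped_input_sizes = [(768, 1024)]  # size after resizing the longest edge to 1024
+        >>> out = image_processor._post_process_masks_pt(low_res_masks, original_sizes, reshaped_input_sizes)
+        >>> out[0].shape
+        torch.Size([1, 3, 480, 640])
+        ```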
+ """ + requires_backends(self, ["torch"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + if isinstance(original_sizes, (torch.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)): + reshaped_input_sizes = reshaped_input_sizes.tolist() + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = torch.from_numpy(masks[i]) + elif not isinstance(masks[i], torch.Tensor): + raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") + interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False) + interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] + interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + def _post_process_masks_tf( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`tf.Tensor`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`tf.Tensor`): + The original size of the images before resizing for input to the model, in (height, width) format. + reshaped_input_sizes (`tf.Tensor`): + The size of the image input to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is + given by original_size. + """ + requires_backends(self, ["tf"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + + output_masks = [] + for i, original_size in enumerate(original_sizes): + # tf.image expects NHWC, we transpose the NCHW inputs for it + mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) + interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") + interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :] + interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + # And then we transpose them back at the end + output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2])) + + return output_masks + + def post_process_for_mask_generation( + self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt" + ): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. 
+ + Args: + all_masks (`Union[list[torch.Tensor], list[tf.Tensor]]`): + List of all predicted segmentation masks + all_scores (`Union[list[torch.Tensor], list[tf.Tensor]]`): + List of all predicted iou scores + all_boxes (`Union[list[torch.Tensor], list[tf.Tensor]]`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + if return_tensors == "pt": + return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) + elif return_tensors == "tf": + return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) + + def generate_crop_boxes( + self, + image, + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + device: Optional["torch.device"] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + return_tensors: str = "pt", + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`np.ndarray`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sample from each crop. + crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. 
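+
+        Example (a minimal sketch using a dummy all-zero image; the shapes shown follow from the default arguments):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import SamImageProcessor
+
+        >>> image_processor = SamImageProcessor()
+        >>> image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HWC image, for illustration only
+        >>> crop_boxes, points_per_crop, cropped_images, input_labels = image_processor.generate_crop_boxes(
+        ...     image, target_size=1024, return_tensors="pt"
+        ... )
+        >>> # With crop_n_layers=0 there is a single crop (the full image) and a 32x32 grid of points
+        >>> crop_boxes.shape, points_per_crop.shape
+        (torch.Size([1, 4]), torch.Size([1, 1024, 1, 2]))
+        ```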
+ """ + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + input_data_format, + ) + if return_tensors == "pt": + if device is None: + device = torch.device("cpu") + crop_boxes = torch.tensor(crop_boxes, device=device) + points_per_crop = torch.tensor(points_per_crop, device=device) + # cropped_images stays as np + input_labels = torch.tensor(input_labels, device=device) + + elif return_tensors == "tf": + if device is not None: + raise ValueError("device is not a supported argument when return_tensors is tf!") + crop_boxes = tf.convert_to_tensor(crop_boxes) + points_per_crop = tf.convert_to_tensor(points_per_crop) + # cropped_images stays as np + input_labels = tf.convert_to_tensor(input_labels) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'.") + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + return_tensors="pt", + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`Union[torch.Tensor, tf.Tensor]`): + Input masks. + iou_scores (`Union[torch.Tensor, tf.Tensor]`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`np.ndarray`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + if return_tensors == "pt": + return self._filter_masks_pt( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + elif return_tensors == "tf": + return self._filter_masks_tf( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + + def _filter_masks_pt( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. 
The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`torch.Tensor`): + Input masks. + iou_scores (`torch.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`np.ndarray`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + requires_backends(self, ["torch"]) + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle_pytorch(masks) + + return masks, scores, converted_boxes + + def _filter_masks_tf( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`tf.Tensor`): + Input masks. + iou_scores (`tf.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. 
+ stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + requires_backends(self, ["tf"]) + original_height, original_width = original_size + iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) + masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + batch_size = masks.shape[0] + + keep_mask = tf.ones(batch_size, dtype=tf.bool) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box_tf(masks) + + keep_mask = ~_is_box_near_crop_edge_tf( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle_tf(masks) + + return masks, scores, converted_boxes + + +def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + stability_scores = intersections / unions + return stability_scores + + +def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): + # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure + # we get the right division results. + intersections = tf.count_nonzero( + masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32 + ) + unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) + stability_scores = intersections / unions + return stability_scores + + +def _build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def _normalize_coordinates( + target_size: int, coords: np.ndarray, original_size: tuple[int, int], is_bounding_box=False +) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. 
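+
+    Example (illustrative numbers only):
+
+    ```python
+    >>> import numpy as np
+    >>> # A point at (x=100, y=200) in an 800x500 (width x height) image, rescaled so the longest edge becomes 1024
+    >>> _normalize_coordinates(1024, np.array([[100.0, 200.0]]), original_size=(500, 800))
+    array([[128., 256.]])
+    ```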
+ """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[list[list[int]], list[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sample per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + + if isinstance(image, list): + raise TypeError("Only one image is allowed for crop generation.") + image = to_numpy_array(image) + original_size = get_image_size(image, input_data_format) + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format + ) + crop_boxes = np.array(crop_boxes) + crop_boxes = crop_boxes.astype(np.float32) + points_per_crop = np.array([point_grid_per_crop]) + points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) + + input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. 
Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None +): + """ + Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are + also passed. + """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + + channel_dim = infer_channel_dimension_format(image, input_data_format) + if channel_dim == ChannelDimension.LAST: + cropped_im = image[top:bottom, left:right, :] + else: + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = get_image_size(cropped_im, channel_dim) + points_scale = np.array(cropped_im_size)[None, ::-1] + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return torch.nn.functional.pad(masks, pad, value=0) + + +def _pad_masks_tf(masks, crop_box: list[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return tf.pad(masks, pad, constant_values=0) + + +def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + + left, top, _, _ = crop_box + offset = torch.tensor([[left, top, left, top]], device=boxes.device) + # Check if boxes has a 
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    boxes = (boxes + offset).float()
+
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
+    orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
+
+    left, top, _, _ = crop_box
+    # Match the dtype of `boxes` so the addition below does not fail on a dtype mismatch
+    offset = tf.convert_to_tensor([[left, top, left, top]], dtype=boxes.dtype)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = tf.expand_dims(offset, 1)
+    boxes = tf.cast(boxes + offset, tf.float32)
+
+    near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
+    near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
+    near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
+    return tf.reduce_any(near_crop_edge, axis=1)
+
+
+def _batched_mask_to_box(masks: "torch.Tensor"):
+    """
+    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+        - LEFT: left hand side of the bounding box
+        - TOP: top of the bounding box
+        - RIGHT: right of the bounding box
+        - BOTTOM: bottom of the bounding box
+
+    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+    is channel_1 x channel_2 x ... x 4.
+
+    Args:
+        - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to Cxheightxwidth
+    shape = masks.shape
+    height, width = shape[-2:]
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + height * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + width * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    out = out.reshape(*shape[:-2], 4)
+    return out
+
+
+def _batched_mask_to_box_tf(masks: "tf.Tensor"):
+    """
+    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+        - LEFT: left hand side of the bounding box
+        - TOP: top of the bounding box
+        - RIGHT: right of the bounding box
+        - BOTTOM: bottom of the bounding box
+
+    Return [0,0,0,0] for an empty mask.
+    For input shape channel_1 x channel_2 x ... x height x width, the output shape is channel_1 x channel_2 x ... x 4.
+
+    Args:
+        - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
+    """
+
+    if tf.size(masks) == 0:
+        return tf.zeros([*masks.shape[:-2], 4])
+
+    # Normalize shape to Cxheightxwidth
+    shape = shape_list(masks)
+    height, width = shape[-2:]
+
+    # Get top and bottom edges
+    in_height = tf.reduce_max(masks, axis=-1)
+    in_height_coords = in_height * tf.range(height)[None, :]
+    bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
+    in_height_coords = in_height_coords + height * (~in_height)
+    top_edges = tf.reduce_min(in_height_coords, axis=-1)
+
+    # Get left and right edges (tf.reduce_max/tf.reduce_min return a single tensor, unlike torch.max/torch.min)
+    in_width = tf.reduce_max(masks, axis=-2)
+    in_width_coords = in_width * tf.range(width)[None, :]
+    right_edges = tf.reduce_max(in_width_coords, axis=-1)
+    in_width_coords = in_width_coords + width * (~in_width)
+    left_edges = tf.reduce_min(in_width_coords, axis=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
+    out = out * tf.expand_dims(~empty_filter, -1)
+
+    # Return to original shape (tf.reshape takes the target shape as a single argument)
+    out = tf.reshape(out, [*shape[:-2], 4])
+    return out
+
+
+def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
+    """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
+    """
+    # Put in fortran order and flatten height and width
+    batch_size, height, width = input_mask.shape
+    input_mask = input_mask.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(batch_size):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+        if len(cur_idxs) == 0:
+            # No changes => either all 0 or all 1
+            # If the entire mask is 0, RLE is [height*width]; if the entire mask is 1, RLE is [0, height*width].
+            if input_mask[i, 0] == 0:
+                out.append({"size": [height, width], "counts": [height * width]})
+            else:
+                out.append({"size": [height, width], "counts": [0, height * width]})
+            continue
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if input_mask[i, 0] == 0 else [0]
+        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
+        out.append({"size": [height, width], "counts": counts})
+    return out
+
+
+def _mask_to_rle_tf(input_mask: "tf.Tensor"):
+    """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
+    """
+    # Put in fortran order and flatten height and width
+    batch_size, height, width = input_mask.shape
+    input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
+
+    # Compute change indices
+    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+    change_indices = tf.where(diff)
+
+    # Encode run length
+    out = []
+    for i in range(batch_size):
+        cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
+        if len(cur_idxs) == 0:
+            # No changes => either all 0 or all 1
+            # If the entire mask is 0, RLE is [height*width]; if the entire mask is 1, RLE is [0, height*width].
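+            # e.g. a 2x3 all-zero mask encodes as {"size": [2, 3], "counts": [6]}, an all-one mask as
+            # {"size": [2, 3], "counts": [0, 6]}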
+            if input_mask[i, 0] == 0:
+                out.append({"size": [height, width], "counts": [height * width]})
+            else:
+                out.append({"size": [height, width], "counts": [0, height * width]})
+            continue
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if input_mask[i, 0] == 0 else [0]
+        counts += (
+            [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+        )
+        out.append({"size": [height, width], "counts": counts})
+    return out
+
+
+def _rle_to_mask(rle: dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    height, width = rle["size"]
+    mask = np.empty(height * width, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity = not parity
+    mask = mask.reshape(width, height)
+    return mask.transpose()  # Reshape to original shape
+
+
+def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+    """
+    Perform NMS (Non Maximum Suppression) on the outputs.
+
+    Args:
+        rle_masks (`torch.Tensor`):
+            binary masks in the RLE format
+        iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
+            iou_scores predicted by the model
+        mask_boxes (`torch.Tensor`):
+            The bounding boxes corresponding to segmentation masks
+        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+            NMS threshold.
+    """
+    keep_by_nms = batched_nms(
+        boxes=mask_boxes.float(),
+        scores=iou_scores,
+        idxs=torch.zeros(mask_boxes.shape[0]),
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = iou_scores[keep_by_nms]
+    rle_masks = [rle_masks[i] for i in keep_by_nms]
+    mask_boxes = mask_boxes[keep_by_nms]
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+    return masks, iou_scores, rle_masks, mask_boxes
+
+
+def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+    """
+    Perform NMS (Non Maximum Suppression) on the outputs.
+
+    Args:
+        rle_masks (`tf.Tensor`):
+            binary masks in the RLE format
+        iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
+            iou_scores predicted by the model
+        mask_boxes (`tf.Tensor`):
+            The bounding boxes corresponding to segmentation masks
+        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+            NMS threshold.
+    """
+    # All boxes belong to a single class here, so plain `tf.image.non_max_suppression` is the TensorFlow
+    # counterpart of the torchvision `batched_nms` call used in `_postprocess_for_mg`.
+    scores = tf.reshape(iou_scores, [-1])
+    keep_by_nms = tf.image.non_max_suppression(
+        boxes=tf.cast(mask_boxes, tf.float32),
+        scores=scores,
+        max_output_size=tf.shape(mask_boxes)[0],
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = tf.gather(iou_scores, keep_by_nms)
+    rle_masks = [rle_masks[i] for i in keep_by_nms.numpy()]
+    mask_boxes = tf.gather(mask_boxes, keep_by_nms)
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+    return masks, iou_scores, rle_masks, mask_boxes
+
+
+__all__ = ["SamImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam_fast.py b/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cfd5314899c4d9df1064ffbb11dcc0fbe9d990e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/sam/image_processing_sam_fast.py
@@ -0,0 +1,749 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SAM.""" + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Optional, Union + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.ops.boxes import batched_nms +from torchvision.transforms.v2 import functional as F_t + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + pil_torch_interpolation_mapping, +) +from ...processing_utils import Unpack +from ...utils import auto_docstring + + +class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + mask_size (`dict[str, int]`, *optional*): + The size `{"longest_edge": int}` to resize the segmentation maps to. + mask_pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation + map size provided for preprocessing. + """ + + mask_size: Optional[dict[str, int]] + mask_pad_size: Optional[dict[str, int]] + + +@auto_docstring +class SamImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"longest_edge": 1024} + mask_size = {"longest_edge": 256} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + valid_kwargs = SamFastImageProcessorKwargs + + do_pad = True + pad_size = {"height": 1024, "width": 1024} + mask_pad_size = {"height": 256, "width": 256} + + def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int): + """ + Compute the output size given input size and target long side length. + """ + oldh, oldw = old_shape + scale = longest_edge * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + newh = int(newh + 0.5) + neww = int(neww + 0.5) + return (newh, neww) + + def resize( + self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F_t.InterpolationMode"], **kwargs + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest + edge of the image will be resized to the specified size, while the other edge will be resized to + maintain the aspect ratio. + interpolation: + `F_t.InterpolationMode` filter to use when resizing the image e.g. `F_t.InterpolationMode.BICUBIC`. + + Returns: + `torch.Tensor`: The resized image. + """ + if not size.longest_edge: + raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. 
Got {size.keys()}") + input_size = image.shape[-2:] + output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge) + return super().resize( + image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs + ) + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + pad_size: Optional[SizeDict] = None, + mask_size: Optional[SizeDict] = None, + mask_pad_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if pad_size is not None: + pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size")) + if mask_size is not None: + mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) + if mask_pad_size is not None: + mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["pad_size"] = pad_size + kwargs["mask_size"] = mask_size + kwargs["mask_pad_size"] = mask_pad_size + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + + # torch resize uses interpolation instead of resample + # Check if resample is an int before checking if it's an instance of PILImageResampling + # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. + # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + return kwargs + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[SamFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[SamFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + original_sizes = [image.shape[-2:] for image in images] + images_kwargs = kwargs.copy() + pixel_values = self._preprocess(images, **images_kwargs)["pixel_values"] + reshaped_input_sizes = [image.shape[-2:] for image in images] + data = { + "pixel_values": pixel_values, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + "interpolation": F_t.InterpolationMode.NEAREST_EXACT, + "size": segmentation_maps_kwargs.pop("mask_size"), + "pad_size": segmentation_maps_kwargs.pop("mask_pad_size"), + } + ) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + data["labels"] = processed_segmentation_maps["pixel_values"].squeeze(1).to(torch.int64) + + return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) + + def generate_crop_boxes( + self, + image: "torch.Tensor", + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + device: Optional["torch.device"] = None, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`torch.Tensor`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sample from each crop. + crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. 
+ """ + image = self._process_image(image) + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + ) + if device is None: + device = torch.device("cpu") + crop_boxes = crop_boxes.to(device) + points_per_crop = points_per_crop.to(device) + # cropped_images stays as torch.Tensor + input_labels = input_labels.to(device) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`torch.Tensor`): + Input masks. + iou_scores (`torch.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`torch.Tensor`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle(masks) + + return masks, scores, converted_boxes + + def post_process_masks( + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. 
+            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+                The original sizes of each image before it was resized to the model's expected input shape, in (height,
+                width) format.
+            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
+            mask_threshold (`float`, *optional*, defaults to 0.0):
+                The threshold to use for binarizing the masks.
+            binarize (`bool`, *optional*, defaults to `True`):
+                Whether to binarize the masks.
+            pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
+                The target size the images were padded to before being passed to the model. If None, the target size is
+                assumed to be the processor's `pad_size`.
+        Returns:
+            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
+            is given by original_size.
+        """
+        pad_size = self.pad_size if pad_size is None else pad_size
+        target_image_size = (pad_size["height"], pad_size["width"])
+        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+            original_sizes = original_sizes.tolist()
+        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
+            reshaped_input_sizes = reshaped_input_sizes.tolist()
+
+        output_masks = []
+        for i, original_size in enumerate(original_sizes):
+            if isinstance(masks[i], np.ndarray):
+                masks[i] = torch.from_numpy(masks[i])
+            elif not isinstance(masks[i], torch.Tensor):
+                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
+            interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
+            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
+            interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
+            if binarize:
+                interpolated_mask = interpolated_mask > mask_threshold
+            output_masks.append(interpolated_mask)
+
+        return output_masks
+
+    def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
+        """
+        Post-processes the generated masks by running Non Maximum Suppression (NMS) on the predicted masks.
+
+        Args:
+            all_masks (`torch.Tensor`):
+                List of all predicted segmentation masks
+            all_scores (`torch.Tensor`):
+                List of all predicted iou scores
+            all_boxes (`torch.Tensor`):
+                List of all bounding boxes of the predicted masks
+            crops_nms_thresh (`float`):
+                Threshold for NMS (Non Maximum Suppression) algorithm.
+        """
+        return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
+
+
+def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
+    intersections = (
+        (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    )
+    unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    stability_scores = intersections / unions
+    return stability_scores
+
+
+def _mask_to_rle(input_mask: "torch.Tensor"):
+    """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
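+    For example, a `2x2` mask whose left column is set encodes as `{"size": [2, 2], "counts": [0, 2, 2]}`
+    (column-major runs: no leading background, two foreground pixels, then two background pixels).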
+ """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = input_mask.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()] + out.append({"size": [height, width], "counts": counts}) + return out + + +def _batched_mask_to_box(masks: "torch.Tensor"): + """ + Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... x 4. + + Args: + - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + # torch.max below raises an error on empty inputs, just skip in this case + + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to Cxheightxwidth + shape = masks.shape + height, width = shape[-2:] + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. 
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    out = out.reshape(*shape[:-2], 4)
+    return out
+
+
+def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+
+    left, top, _, _ = crop_box
+    offset = torch.tensor([[left, top, left, top]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    boxes = (boxes + offset).float()
+
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
+    left, top, right, bottom = crop_box
+    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+    pad = (left, pad_x - left, top, pad_y - top)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def _generate_crop_boxes(
+    image,
+    target_size: int,  # the length of the longest edge after resizing, passed to `_normalize_coordinates`
+    crop_n_layers: int = 0,
+    overlap_ratio: float = 512 / 1500,
+    points_per_crop: Optional[int] = 32,
+    crop_n_points_downscale_factor: Optional[int] = 1,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+    Args:
+        image (`torch.Tensor`):
+            Image to generate crops for.
+        target_size (`int`):
+            Size of the longest edge after resizing; used to normalize the sampled point coordinates.
+        crop_n_layers (`int`, *optional*):
+            If `crop_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
+            to run, where each layer has 2**i_layer number of image crops.
+        overlap_ratio (`float`, *optional*):
+            Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
+            image length. Later layers with more crops scale down this overlap.
+        points_per_crop (`int`, *optional*):
+            Number of points to sample per crop.
+        crop_n_points_downscale_factor (`int`, *optional*):
+            The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
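+
+    Returns the crop boxes (XYXY, as a float tensor), the per-crop point prompts, the stacked cropped images, and
+    all-ones input labels for the points.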
+ """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + original_size = image.shape[-2:] + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + crop_boxes = torch.tensor(crop_boxes) + crop_boxes = crop_boxes.float() + points_per_crop = torch.stack(point_grid_per_crop) + points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3) + cropped_images = torch.stack(cropped_images) + + input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _build_point_grid(n_per_side: int) -> torch.Tensor: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = torch.linspace(offset, 1 - offset, n_per_side) + points_x = torch.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = torch.tile(points_one_side[:, None], (1, n_per_side)) + points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2) + return points + + +def _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None +): + """ + Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are + also passed. 
+ """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = cropped_im.shape[-2:] + points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0) + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _normalize_coordinates( + target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False +) -> torch.Tensor: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. + """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor: + """Compute a binary mask from an uncompressed RLE.""" + height, width = rle["size"] + mask = torch.empty(height * width, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity = not parity + mask = mask.reshape(width, height) + return mask.transpose(0, 1) # Reshape to original shape + + +def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. + + Args: + rle_masks (`torch.Tensor`): + binary masks in the RLE format + iou_scores (`torch.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`torch.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. + """ + keep_by_nms = batched_nms( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=torch.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes + + +__all__ = ["SamImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_sam.py b/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..034f7485f426c5ea7ab344af885f8954f1b8c408 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_sam.py @@ -0,0 +1,1368 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SAM model.""" + +import collections +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + auto_docstring, + logging, +) +from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + """ +) +class SamVisionEncoderOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Segment-Anything model's output + """ +) +class SamImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. Needs to be post-processed by the processor + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class SamPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SamMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam +class SamLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SamAttention(nn.Module): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + self.is_causal = False + + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_tokens, n_heads, c_per_head = hidden_states.shape + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_similarity: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = 
self._separate_heads(value, self.num_attention_heads)
+
+        # SamAttention
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query,
+            key,
+            value,
+            attention_mask=attention_similarity,
+            dropout=0.0,
+            scaling=self.scaling,
+            is_causal=self.is_causal,
+            **kwargs,
+        )
+
+        attn_output = self._recombine_heads(attn_output, point_batch_size)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+class SamTwoWayAttentionBlock(nn.Module):
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
+        """
+        A transformer block with four layers:
+            (1) self-attention of sparse inputs
+            (2) cross attention of sparse inputs -> dense inputs
+            (3) mlp block on sparse inputs
+            (4) cross attention of dense inputs -> sparse inputs
+
+        Arguments:
+            config (`SamMaskDecoderConfig`):
+                The configuration file used to instantiate the block
+            attention_downsample_rate (`int`, *optional*, defaults to 2):
+                The downsample ratio of the block used to reduce the inner dim of the attention.
+            skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+                Whether or not to skip the addition of the query_point_embedding on the first layer.
+        """
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.layer_norm_eps = config.layer_norm_eps
+
+        self.self_attn = SamAttention(config, downsample_rate=1)
+        self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
+        self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.mlp = SamMLPBlock(config)
+        self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        query_point_embedding: Tensor,
+        key_point_embedding: Tensor,
+        attention_similarity: Tensor,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+        else:
+            query = queries + query_point_embedding
+            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+            queries = queries + attn_out
+        queries = self.layer_norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        query = queries + query_point_embedding
+        key = keys + key_point_embedding
+
+        attn_out, _ = self.cross_attn_token_to_image(
+            query=query, key=key, value=keys, attention_similarity=attention_similarity
+        )
+        queries = queries + attn_out
+
+        queries = self.layer_norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.layer_norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        query = queries + query_point_embedding
+        key = keys + key_point_embedding
+
+        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+        keys = keys + attn_out
+
+        keys = self.layer_norm4(keys)
+        return queries, keys, attn_out
+
+
+class SamTwoWayTransformer(nn.Module):
+    def __init__(self, config: 
SamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = SamAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class SamFeedForward(nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ): + super().__init__() + self.num_layers = num_layers + self.activation = nn.ReLU() + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +class SamMaskDecoder(nn.Module): + def __init__(self, config: SamMaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = SamTwoWayTransformer(config) + + # should we create a new class for this? 
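+        # (two ConvTranspose2d layers that upscale the image embeddings 4x before mask prediction)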
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = SamFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + the embeddings from the image encoder + image_positional_embedding (`torch.Tensor`): + positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes + dense_prompt_embeddings (`torch.Tensor`): + the embeddings of the mask inputs + multimask_output (bool): + Whether to return multiple masks or a single mask. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1 + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings is not None: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-point + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + # Run the transformer, image_positional_embedding are consumed + point_embedding, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, 
num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + return masks, iou_pred + + +class SamPositionalEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.scale = config.hidden_size // 2 + self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class SamMaskEmbedding(nn.Module): + def __init__(self, config: SamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = SamLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = SamLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class SamPromptEncoder(nn.Module): + def __init__(self, config: SamConfig): + super().__init__() + self.shared_embedding = SamPositionalEmbedding(config.vision_config) + config = config.prompt_encoder_config + self.mask_embed = SamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList( + [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)] + ) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if 
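# A standalone sketch of the random-Fourier-feature encoding that SamPositionalEmbedding applies
# above, with assumed toy sizes; numpy is used purely for illustration, this is not upstream code.
import numpy as np

num_pos_feats = 128  # half of the output channel dimension
rng = np.random.default_rng(0)
gaussian_matrix = rng.normal(size=(2, num_pos_feats))  # fixed random projection of (x, y)

def encode_points(coords_01):
    # coords_01: (..., 2) points normalized to [0, 1]
    coords = 2.0 * coords_01 - 1.0                      # map to [-1, 1]
    coords = coords @ gaussian_matrix                   # (..., num_pos_feats)
    coords = 2.0 * np.pi * coords
    return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)

print(encode_points(np.array([[0.5, 0.25]])).shape)     # (1, 256)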
pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, device=points.device) + padding_label = -torch.ones(target_labels_shape, device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding)) + + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = boxes.reshape(batch_size, nb_boxes, 2, 2) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding[:, :, 0, :] += self.point_embed[2].weight + corner_embedding[:, :, 1, :] += self.point_embed[3].weight + return corner_embedding + + def forward( + self, + input_points: Optional[tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. 
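# A small sketch (toy tensors, not upstream code) of how _embed_points above selects embeddings
# per point label with torch.where: 1 = foreground, 0 = background, -1 = "not a point" padding,
# -10 = skip the point entirely.
import torch

hidden = 8
labels = torch.tensor([[[1, 0, -1, -10]]])       # (batch, point_batch, num_points)
fourier_pe = torch.randn(1, 1, 4, hidden)        # stand-in for the shared positional encoding
not_a_point = torch.randn(hidden)
bg_embed, fg_embed = torch.randn(hidden), torch.randn(hidden)

out = torch.where(labels[..., None] == -1, not_a_point, fourier_pe)
out = torch.where(labels[..., None] != -10, out, torch.zeros_like(out))
out = torch.where((labels == 0)[..., None], out + bg_embed, out)
out = torch.where((labels == 1)[..., None], out + fg_embed, out)
print(out.shape)  # torch.Size([1, 1, 4, 8])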
+ boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class SamVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def get_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos + + def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + return attn_output, attn_weights + + +class SamVisionSdpaAttention(SamVisionAttention): + """ + Multi-head Attention block with relative position embeddings. + Using SDPA instead of the default attention. + """ + + def __init__(self, config, window_size): + super().__init__(config, window_size) + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + if output_attentions: + logger.warning_once( + "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
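# A toy-size sketch of the decomposed relative-position bias built by get_decomposed_rel_pos above;
# in the model the batch axis is really batch * num_heads and the tables come from get_rel_pos.
# Illustrative only, not upstream code.
import torch

bsz_heads, h, w, c = 2, 4, 4, 16
query = torch.randn(bsz_heads, h * w, c)
rel_h_table = torch.randn(h, h, c)   # (q_h, k_h, head_dim)
rel_w_table = torch.randn(w, w, c)   # (q_w, k_w, head_dim)

q = query.reshape(bsz_heads, h, w, c)
rel_h = torch.einsum("bhwc,hkc->bhwk", q, rel_h_table)        # (B, H, W, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", q, rel_w_table)        # (B, H, W, k_w)
bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]      # (B, H, W, k_h, k_w)
bias = bias.reshape(bsz_heads, h * w, h * w)                  # added to the attention logits
print(bias.shape)  # torch.Size([2, 16, 16])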
+ ) + return super().forward( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_bias = None + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape( + batch_size, self.num_attention_heads, height * width, height * width + ) + attn_bias = decomposed_rel_pos + + query = query.view(batch_size, self.num_attention_heads, height * width, -1) + key = key.view(batch_size, self.num_attention_heads, height * width, -1) + value = value.view(batch_size, self.num_attention_heads, height * width, -1) + + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias) + + attn_output = ( + attn_output.view(batch_size, self.num_attention_heads, height, width, -1) + .permute(0, 2, 3, 1, 4) + .reshape(batch_size, height, width, -1) + ) + + attn_output = self.proj(attn_output) + return attn_output, None + + +SAM_VISION_ATTENTION_CLASSES = { + "eager": SamVisionAttention, + "sdpa": SamVisionSdpaAttention, +} + + +class SamVisionLayer(GradientCheckpointingLayer): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SamMLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). 
+ original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + return hidden_states + + +class SamVisionNeck(nn.Module): + def __init__(self, config: SamVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +@auto_docstring +class SamPreTrainedModel(PreTrainedModel): + config: SamConfig + base_model_prefix = "sam" + main_input_name = "pixel_values" + _no_split_modules = ["SamVisionAttention"] + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, SamVisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + elif isinstance(module, SamVisionEncoder): + if self.config.use_abs_pos: + module.pos_embed.data.zero_() + + +class SamVisionEncoder(SamPreTrainedModel): + _can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention} + + def __init__(self, config: SamVisionConfig): + super().__init__(config) + self.config = config + self.image_size = config.image_size + self.patch_embed = SamPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
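# A self-contained round-trip check (toy sizes, not upstream code) of the window_partition /
# window_unpartition pair defined above: pad to a multiple of the window size, split into windows,
# then undo it.
import torch
import torch.nn.functional as F

batch, height, width, channels, window = 1, 10, 14, 8, 7
x = torch.randn(batch, height, width, channels)

pad_h = (window - height % window) % window   # 4
pad_w = (window - width % window) % window    # 0
padded = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
ph, pw = height + pad_h, width + pad_w

windows = (
    padded.reshape(batch, ph // window, window, pw // window, window, channels)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window, window, channels)    # (num_windows, 7, 7, 8)
)

restored = (
    windows.reshape(batch, ph // window, pw // window, window, window, channels)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(batch, ph, pw, channels)[:, :height, :width, :]
)
print(torch.equal(restored, x))  # True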
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = SamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = SamVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + @check_model_inputs(tie_last_hidden_states=False) + def forward( + self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs] + ) -> SamVisionEncoderOutput: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + hidden_states = self.neck(hidden_states) + return SamVisionEncoderOutput( + last_hidden_state=hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The vision model from Sam without any head or projection on top. + """ +) +class SamVisionModel(SamPreTrainedModel): + config: SamVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SamVisionConfig): + super().__init__(config) + self.vision_encoder = SamVisionEncoder(config) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_encoder.patch_embed + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, SamVisionEncoderOutput]: + return self.vision_encoder(pixel_values, **kwargs) + + +@auto_docstring( + custom_intro=""" + Segment Anything Model (SAM) for generating segmentation masks, given an input image and + input points and labels, boxes, or masks. 
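# A hedged usage sketch for the vision tower on its own. It assumes the "facebook/sam-vit-base"
# checkpoint and that SamProcessor / SamVisionModel are importable from transformers; the printed
# shape is indicative (output_channels x 64 x 64 for a 1024 x 1024 input), not guaranteed.
import requests
import torch
from PIL import Image
from transformers import SamProcessor, SamVisionModel

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamVisionModel.from_pretrained("facebook/sam-vit-base")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])
print(outputs.last_hidden_state.shape)  # e.g. torch.Size([1, 256, 64, 64])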
+ """ +) +class SamModel(SamPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)} + + def __init__(self, config: SamConfig): + super().__init__(config) + self.shared_image_embedding = SamPositionalEmbedding(config.vision_config) + + self.vision_encoder = SamVisionEncoder(config.vision_config) + self.prompt_encoder = SamPromptEncoder(config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = SamMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + vision_output = self.vision_encoder( + pixel_values, + **kwargs, + ) + image_embeddings = vision_output[0] + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. 
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @check_model_inputs() + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SamImageSegmentationOutput: + r""" + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. 
The mask will be embedded by the prompt encoder to + generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be + provided manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, used by the mask decoder to generate masks and IoU scores. For more memory-efficient + computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to output just a single mask, corresponding to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("facebook/sam-vit-base") + >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get segmentation mask + >>> outputs = model(**inputs) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor of shape (`batch_size`, `point_batch_size`, `nb_points_per_image`, `2`),", + f" got {input_points.shape}.", + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_boxes must be a 3D tensor of shape (`batch_size`, `nb_boxes`, `4`),", + f" got {input_boxes.shape}.", + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
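# A hedged sketch that complements the example above: the image is encoded once with
# get_image_embeddings and the cached embeddings are reused for prompting, then the mask with the
# highest predicted IoU is kept. Checkpoint name and output shapes are assumptions, not guarantees.
import requests
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# Encode the image once ...
pixel_inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_embeddings = model.get_image_embeddings(pixel_inputs["pixel_values"])

# ... then prompt repeatedly without re-running the vision encoder.
inputs = processor(images=image, input_points=[[[400, 650]]], return_tensors="pt")
with torch.no_grad():
    outputs = model(image_embeddings=image_embeddings, input_points=inputs["input_points"])

masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
best = outputs.iou_scores[0, 0].argmax()  # index of the highest-scoring of the three masks
print(masks[0][0, best].shape)            # mask at the original image resolution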
+ ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs) + image_embeddings = vision_outputs.last_hidden_state + vision_hidden_states = vision_outputs.hidden_states + vision_attentions = vision_outputs.attentions + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. ", + f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.", + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + ) + + return SamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + +__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_tf_sam.py b/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_tf_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..ac81288fa182b027752e16f0ff2525803f58a8b7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam/modeling_tf_sam.py @@ -0,0 +1,1723 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a +discrepancy, the original file should be regarded as the 'reference' version. 
+""" + +from __future__ import annotations + +import collections +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import TFBaseModelOutput +from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs +from ...tf_utils import flatten, functional_layernorm +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SamConfig" +_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" + + +@dataclass +class TFSamVisionEncoderOutput(ModelOutput): + """ + Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: tf.Tensor | None = None + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFSamImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. Needs to be post-processed by the processor + vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: tf.Tensor | None = None + pred_masks: tf.Tensor | None = None + vision_hidden_states: tuple[tf.Tensor, ...] | None = None + vision_attentions: tuple[tf.Tensor, ...] | None = None + mask_decoder_attentions: tuple[tf.Tensor, ...] | None = None + + +class TFSamPatchEmbeddings(keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = keras.layers.Conv2D( + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" + ) + + def call(self, pixel_values): + batch_size, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFSamMLPBlock(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1") + self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2") + self.act = ACT2FN[config.hidden_act] + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "lin1", None) is not None: + with tf.name_scope(self.lin1.name): + self.lin1.build([None, None, self.config.hidden_size]) + if getattr(self, "lin2", None) is not None: + with tf.name_scope(self.lin2.name): + self.lin2.build([None, None, self.config.mlp_dim]) + + +class TFSamLayerNorm(keras.layers.Layer): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.data_format = data_format + self.normalized_shape = normalized_shape + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + + def build(self, input_shape): + self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") + self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") + super().build(input_shape) + + def call(self, x: tf.Tensor) -> tf.Tensor: + if self.data_format == "channels_last": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) + elif self.data_format == "channels_first": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) + return x + + +class TFSamAttention(keras.layers.Layer): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
+ """ + + def __init__(self, config, downsample_rate=None, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj") + self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj") + self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj") + self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj") + + def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: + batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) + c_per_head = channel // num_attention_heads + hidden_states = tf.reshape( + hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + ) + return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + + def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: + batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + return tf.reshape( + hidden_states, + (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head), + ) + + def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = shape_list(query)[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = shape_list(query) + attn = tf.matmul( + query, tf.transpose(key, perm=[0, 1, 3, 2]) + ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / tf.math.sqrt(float(c_per_head)) + attn = tf.nn.softmax(attn, axis=-1) + + # Get output + out = tf.matmul(attn, value) + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.hidden_size]) + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.hidden_size]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.internal_dim]) + + +class TFSamTwoWayAttentionBlock(keras.layers.Layer): + def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamMaskDecoderConfig`): + The configuration file used to instantiate the block + 
attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") + + self.cross_attn_token_to_image = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" + ) + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") + + self.mlp = TFSamMLPBlock(config, name="mlp") + self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") + + self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") + self.cross_attn_image_to_token = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def call( + self, + queries: tf.Tensor, + keys: tf.Tensor, + query_point_embedding: tf.Tensor, + key_point_embedding: tf.Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, None, self.hidden_size]) + if getattr(self, "cross_attn_token_to_image", None) is not None: + with tf.name_scope(self.cross_attn_token_to_image.name): + self.cross_attn_token_to_image.build(None) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, None, self.hidden_size]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "layer_norm3", None) is not None: + with tf.name_scope(self.layer_norm3.name): + self.layer_norm3.build([None, 
None, None, self.hidden_size]) + if getattr(self, "layer_norm4", None) is not None: + with tf.name_scope(self.layer_norm4.name): + self.layer_norm4.build([None, None, None, self.hidden_size]) + if getattr(self, "cross_attn_image_to_token", None) is not None: + with tf.name_scope(self.cross_attn_image_to_token.name): + self.cross_attn_image_to_token.build(None) + + +class TFSamTwoWayTransformer(keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = [] + + for i in range(self.num_hidden_layers): + self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) + + self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") + self.layer_norm_final_attn = keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layer_norm_final_attn" + ) + + def call( + self, + point_embeddings: tf.Tensor, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | TFBaseModelOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] + image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "final_attn_token_to_image", None) is not None: + with tf.name_scope(self.final_attn_token_to_image.name): + self.final_attn_token_to_image.build(None) + if getattr(self, "layer_norm_final_attn", None) is not None: + with tf.name_scope(self.layer_norm_final_attn.name): + self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size]) + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSamFeedForward(keras.layers.Layer): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.activation = keras.layers.ReLU() + self.proj_in = 
keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") + self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") + self.layers = [ + keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") + for i in range(num_layers - 2) + ] + self.sigmoid_output = sigmoid_output + self.hidden_dim = hidden_dim + self.input_dim = input_dim + + def call(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = tf.sigmoid(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "proj_in", None) is not None: + with tf.name_scope(self.proj_in.name): + self.proj_in.build([None, None, self.input_dim]) + if getattr(self, "proj_out", None) is not None: + with tf.name_scope(self.proj_out.name): + self.proj_out.build([None, None, self.hidden_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build([None, None, self.hidden_dim]) + + +class TFSamMaskDecoder(keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.transformer = TFSamTwoWayTransformer(config, name="transformer") + + self.upscale_conv1 = keras.layers.Conv2DTranspose( + self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" + ) + self.upscale_conv2 = keras.layers.Conv2DTranspose( + self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" + ) + self.upscale_layer_norm = TFSamLayerNorm( + self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" + ) + self.activation = tf.nn.gelu + + mlps_list = [] + for i in range(self.num_mask_tokens): + mlps_list += [ + TFSamFeedForward( + self.hidden_size, + self.hidden_size, + self.hidden_size // 8, + 3, + name=f"output_hypernetworks_mlps_._{i}", + ) + ] + self.output_hypernetworks_mlps = mlps_list + + self.iou_prediction_head = TFSamFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + name="iou_prediction_head", + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) + self.mask_tokens = self.add_weight( + shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True + ) + + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "upscale_conv1", None) is not None: + with tf.name_scope(self.upscale_conv1.name): + self.upscale_conv1.build([None, self.hidden_size, None, None]) + if getattr(self, "upscale_conv2", None) is not None: + with tf.name_scope(self.upscale_conv2.name): + self.upscale_conv2.build([None, self.hidden_size // 4, None, None]) + if getattr(self, "upscale_layer_norm", None) is not None: + with tf.name_scope(self.upscale_layer_norm.name): + self.upscale_layer_norm.build(None) + if getattr(self, 
"iou_prediction_head", None) is not None: + with tf.name_scope(self.iou_prediction_head.name): + self.iou_prediction_head.build(None) + for mlp in self.output_hypernetworks_mlps: + with tf.name_scope(mlp.name): + mlp.build(None) + + def call( + self, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + sparse_prompt_embeddings: tf.Tensor, + dense_prompt_embeddings: tf.Tensor, + multimask_output: bool, + output_attentions: bool | None = None, + ) -> tuple[tf.Tensor, tf.Tensor]: + batch_size, num_channels, height, width = shape_list(image_embeddings) + point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) + + output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) + output_tokens = tf.tile( + output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] + ) # Should be (batch_size, point_size, 5, 32) + + # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only + # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced + # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. + if shape_list(sparse_prompt_embeddings)[1] != 0: + tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) + else: + tokens = output_tokens + point_embeddings = tf.cast(tokens, self.iou_token.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0) + image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) + image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = tf.stack(hyper_in_list, axis=2) + + _, num_channels, height, width = shape_list(upscaled_embedding) + upscaled_embedding = tf.reshape( + upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] + ) + masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class TFSamPositionalEmbedding(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.scale = config.hidden_size // 2 + self.config = config + + def build(self, 
input_shape): + # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized? + self.positional_embedding = self.add_weight( + name="positional_embedding", + shape=(2, self.config.num_pos_feats), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), + trainable=False, + ) + super().build(input_shape) + + def call(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = tf.identity(input_coords) + + if input_shape is not None: + coordinates = tf.stack( + [ + tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], + tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], + ], + axis=-1, + ) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = tf.cast(coordinates, self.positional_embedding.dtype) + coordinates = tf.matmul(coordinates, self.positional_embedding) + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) + + +class TFSamMaskEmbedding(keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, **kwargs): + super().__init__(**kwargs) + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") + self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") + self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") + self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") + self.config = config + + def call(self, masks): + masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first + return dense_embeddings + + def build(self, input_shape=None): + # This class needs an explicit build method because it isn't called with the standard dummy inputs + if self.built: + return + self.built = True + with tf.name_scope("conv1"): + self.conv1.build([None, None, None, 1]) + with tf.name_scope("conv2"): + self.conv2.build([None, None, None, self.mask_input_channels]) + with tf.name_scope("conv3"): + self.conv3.build([None, None, None, self.mask_input_channels * 4]) + with tf.name_scope("layer_norm1"): + self.layer_norm1.build([None, None, None, self.mask_input_channels]) + with tf.name_scope("layer_norm2"): + self.layer_norm2.build([None, None, None, self.mask_input_channels * 4]) + + +class TFSamPromptEncoder(keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): + super().__init__(**kwargs) + self.shared_embedding = shared_patch_embedding + self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") + self.no_mask_embed = None + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + 
self.input_image_size = config.image_size + + self.point_embed = [] + self.hidden_size = config.hidden_size + self.not_a_point_embed = None + self.config = config + + def build(self, input_shape=None): + self.no_mask_embed = self.add_weight( + name="no_mask_embed.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + self.point_embed = [ + self.add_weight( + name=f"point_embed_._{i}.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + for i in range(self.config.num_point_embeddings) + ] + self.not_a_point_embed = self.add_weight( + name="not_a_point_embed.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + with tf.name_scope("mask_embed"): + # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs + self.mask_embed.build( + (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size) + ) + + if self.built: + return + self.built = True + if getattr(self, "mask_embed", None) is not None: + with tf.name_scope(self.mask_embed.name): + self.mask_embed.build(None) + + def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1]) + target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1) + padding_point = tf.zeros(target_point_shape, dtype=points.dtype) + padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) + points = tf.concat([points, padding_point], axis=2) + labels = tf.concat([labels, padding_label], axis=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) + + point_embedding = tf.where( + labels[..., None] != -10, + point_embedding, + tf.zeros_like(point_embedding), + ) + point_embedding = tf.where( + (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding + ) + point_embedding = tf.where( + (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding + ) + return point_embedding + + def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = shape_list(boxes)[:2] + coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding += tf.where( + tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0, + self.point_embed[2][0], + self.point_embed[3][0], + ) + return corner_embedding + + def call( + self, + batch_size: int | None, + input_points: tuple[tf.Tensor, tf.Tensor] | None, + input_labels: tf.Tensor | None, + input_boxes: tf.Tensor | None, + input_masks: tf.Tensor | None, + ) -> tuple[tf.Tensor, tf.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`tf.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`tf.Tensor`, *optional*): + boxes to embed + masks (`tf.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + if input_points is not None: + batch_size, point_batch_size = shape_list(input_points)[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = tf.zeros( + (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype + ) + sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) + if input_boxes is not None: + batch_size = shape_list(input_boxes)[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed[0] + dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) + dense_embeddings = tf.tile( + dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) + if sparse_embeddings is None: + sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) + + return sparse_embeddings, dense_embeddings + + +class TFSamVisionAttention(keras.layers.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + self.input_size = input_size + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") + self.proj = keras.layers.Dense(config.hidden_size, name="proj") + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + self.config = config + + def build(self, input_shape=None): + if self.input_size is not None: + # initialize relative positional embeddings + self.rel_pos_h = self.add_weight( + shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" + ) + self.rel_pos_w = self.add_weight( + shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" + ) + + if self.built: + return + self.built = True + if getattr(self, "qkv", None) is not None: + with tf.name_scope(self.qkv.name): + self.qkv.build([None, None, self.config.hidden_size]) + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build([None, None, self.config.hidden_size]) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`tf.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. 
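+
+        Illustrative sketch (not part of the original notes): with `q_size == k_size == 3` the
+        lookup table needs `max_rel_dist = 2 * 3 - 1 = 5` rows, so a `rel_pos` of shape
+        `(5, channel)` is used without resizing and the gathered indices reduce to
+        `i - j + (k_size - 1)`:
+
+        ```python
+        # Worked example for q_size == k_size == 3 (max_rel_dist == 5, no interpolation needed):
+        # relative_coords[i, j] = i - j + (k_size - 1)
+        relative_coords = [[2, 1, 0],
+                           [3, 2, 1],
+                           [4, 3, 2]]
+        # tf.gather(rel_pos, relative_coords) then has shape (3, 3, channel)
+        ```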
+ """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = tf.image.resize( + tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), + size=(max_rel_dist, rel_pos.shape[1]), + method="bilinear", + ) + rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) + k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) + + def get_decomposed_rel_pos( + self, + query: tf.Tensor, + rel_pos_h: tf.Tensor, + rel_pos_w: tf.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], + ) -> tf.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`tf.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`tf.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`tf.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. 
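+
+        Shape sketch (illustrative, inferred from the implementation below rather than taken from
+        the original docstring):
+
+        ```python
+        # query reshaped to (batch, q_h, q_w, channel)
+        # rel_h = einsum("bhwc,hkc->bhwk", query, rel_pos_h)  -> (batch, q_h, q_w, k_h)
+        # rel_w = einsum("bhwc,wkc->bhwk", query, rel_pos_w)  -> (batch, q_h, q_w, k_w)
+        # rel_h[..., :, None] + rel_w[..., None, :]           -> (batch, q_h, q_w, k_h, k_w)
+        # which the caller reshapes to match the (batch, q_h * q_w, k_h * k_w) attention logits
+        ```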
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = shape_list(query) + reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) + rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + rel_h = tf.expand_dims(rel_h, axis=-1) + rel_w = tf.expand_dims(rel_w, axis=-2) + decomposed_rel_pos = rel_h + rel_w + + return decomposed_rel_pos + + def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: + batch_size, height, width, _ = shape_list(hidden_states) + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) + qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = tf.unstack( + tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 + ) + attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights)) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + if training: + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + else: + attn_probs = attn_weights + + attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) + attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) + attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size)) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class TFSamVisionLayer(keras.layers.Layer): + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.attn = TFSamVisionAttention(config, window_size, name="attn") + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + self.mlp = TFSamMLPBlock(config, name="mlp") + self.window_size = window_size + self.config = config + + def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> tuple[tf.Tensor, tuple[int, int]]: + batch_size, height, width, channel = shape_list(hidden_states) + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + if pad_h > 0 or pad_w > 0: + hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = tf.reshape( + hidden_states, + [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], + ) + windows = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] + ) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, 
windows: tf.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] + ) -> tf.Tensor: + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = tf.reshape( + windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] + ) + hidden_states = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] + ) + + if pad_height > height or pad_width > width: + hidden_states = hidden_states[:, :height, :width, :] + return hidden_states + + def call( + self, + hidden_states: tf.Tensor, + output_attentions: bool | None = False, + training: bool | None = False, + ) -> tuple[tf.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + training=training, + ) + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, None, self.config.hidden_size]) + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, None, self.config.hidden_size]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + + +class TFSamVisionNeck(keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.conv1 = keras.layers.Conv2D( + config.output_channels, + kernel_size=1, + use_bias=False, + name="conv1", + ) + self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") + self.conv2 = keras.layers.Conv2D( + config.output_channels, + kernel_size=3, + padding="same", + use_bias=False, + name="conv2", + ) + self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") + + def call(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build([None, None, None, self.config.hidden_size]) + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build(None) + if getattr(self, "conv2", None) is not 
None: + with tf.name_scope(self.conv2.name): + self.conv2.build([None, None, None, self.config.output_channels]) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build(None) + + +class TFSamVisionEncoder(keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.image_size = config.image_size + + self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") + + self.pos_embed = None + + self.layers = [] + for i in range(config.num_hidden_layers): + layer = TFSamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + name=f"layers_._{i}", + ) + self.layers.append(layer) + + self.neck = TFSamVisionNeck(config, name="neck") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = self.add_weight( + shape=[ + 1, + self.config.image_size // self.config.patch_size, + self.config.image_size // self.config.patch_size, + self.config.hidden_size, + ], + initializer="zeros", + trainable=True, + name="pos_embed", + ) + + if getattr(self, "patch_embed", None) is not None: + with tf.name_scope(self.patch_embed.name): + self.patch_embed.build(None) + if getattr(self, "neck", None) is not None: + with tf.name_scope(self.neck.name): + self.neck.build(None) + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + def get_input_embeddings(self): + return self.patch_embed + + def call( + self, + pixel_values: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFSamVisionEncoderOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return TFSamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class TFSamPreTrainedModel(TFPreTrainedModel): + config_class = SamConfig + base_model_prefix = "sam" + main_input_name = 
"pixel_values" + + +SAM_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`SamConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SAM_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second + dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per + input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, + the number of boxes per image and the coordinates of the top left and bottom right point of the box. 
In the + order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + The SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, which will later be fed to the mask decoder. These masks need to be + fed manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `call` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask that corresponds to the + "best" mask, by specifying `multimask_output=False`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +SAM_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
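+
+    Example (an illustrative sketch, not from the original docstring; the checkpoint name and the
+    output shape are assumptions based on the base SAM checkpoint):
+
+    ```python
+    >>> import requests
+    >>> from PIL import Image
+    >>> from transformers import SamProcessor, TFSamVisionModel
+
+    >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+    >>> model = TFSamVisionModel.from_pretrained("facebook/sam-vit-base")
+
+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+    >>> inputs = processor(images=image, return_tensors="tf")
+
+    >>> outputs = model(pixel_values=inputs["pixel_values"])
+    >>> outputs.last_hidden_state.shape  # doctest: +SKIP
+    TensorShape([1, 256, 64, 64])
+    ```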
+""" + + +@add_start_docstrings( + """The vision model from Sam without any head or projection on top.""", + SAM_START_DOCSTRING, +) +class TFSamVisionModel(TFSamPreTrainedModel): + config_class = SamVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(config, **kwargs) + self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vision_encoder", None) is not None: + with tf.name_scope(self.vision_encoder.name): + self.vision_encoder.build(None) + + def get_input_embeddings(self): + return self.vision_encoder.patch_embed + + @unpack_inputs + @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig) + def call( + self, + pixel_values: TFModelInputType | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> TFSamVisionEncoderOutput | tuple[tf.Tensor]: + r""" + Returns: + + """ + return self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class TFSamModel(TFSamPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") + + self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") + self.prompt_encoder = TFSamPromptEncoder( + config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" + ) + self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") + self.config = config + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + grid = tf.ones((size, size)) + y_embed = tf.math.cumsum(grid, axis=0) - 0.5 + x_embed = tf.math.cumsum(grid, axis=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) + return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple. 
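+
+        Example (a minimal illustrative sketch, assuming a `SamProcessor` / `TFSamModel` pair has
+        already been loaded from a SAM checkpoint; not part of the original docstring):
+
+        ```python
+        >>> inputs = processor(images=image, input_points=[[[450, 600]]], return_tensors="tf")
+        >>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+        >>> # the embeddings can be reused for any number of prompts on the same image,
+        >>> # skipping the vision encoder on every call:
+        >>> outputs = model(
+        ...     input_points=inputs["input_points"],
+        ...     image_embeddings=image_embeddings,
+        ... )
+        ```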
+ + """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + return image_embeddings + + def get_prompt_embeddings( + self, + input_points: tf.Tensor | None = None, + input_labels: tf.Tensor | None = None, + input_boxes: tf.Tensor | None = None, + input_masks: tf.Tensor | None = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @unpack_inputs + @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) + def call( + self, + pixel_values: TFModelInputType | None = None, + input_points: tf.Tensor | None = None, + input_labels: tf.Tensor | None = None, + input_boxes: tf.Tensor | None = None, + input_masks: tf.Tensor | None = None, + image_embeddings: tf.Tensor | None = None, + multimask_output: bool = True, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> TFSamImageSegmentationOutput | tuple[tf.Tensor]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", + f" got {input_points.shape}.", + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. 
Of shape `batch_size`, `nb_boxes`, `4`.", + f" got {input_boxes.shape}.", + ) + if input_points is not None and input_boxes is not None: + point_batch_size = shape_list(input_points)[1] + box_batch_size = shape_list(input_boxes)[1] + if point_batch_size != box_batch_size: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}." + ) + if pixel_values is not None: + # Ensures that later checks pass even with an all-None shape from the serving signature + pixel_values = tf.ensure_shape( + pixel_values, + [ + None, + self.config.vision_config.num_channels, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ], + ) + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] + image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + training=training, + ) + image_embeddings = vision_outputs["last_hidden_state"] + + if output_hidden_states: + vision_hidden_states = vision_outputs["hidden_states"] + if output_attentions: + vision_attentions = vision_outputs["attentions"] + + if input_points is not None and input_labels is None: + input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. 
", + f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.", + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + batch_size=shape_list(image_embeddings)[0], + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return TFSamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) + + def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: + hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None + + return TFSamImageSegmentationOutput( + iou_scores=output.iou_scores, + pred_masks=output.pred_masks, + vision_hidden_states=hs if self.config.output_hidden_states else None, + vision_attentions=attns if self.config.output_attentions else None, + mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "shared_image_embedding", None) is not None: + with tf.name_scope(self.shared_image_embedding.name): + self.shared_image_embedding.build(None) + if getattr(self, "vision_encoder", None) is not None: + with tf.name_scope(self.vision_encoder.name): + self.vision_encoder.build(None) + if getattr(self, "prompt_encoder", None) is not None: + with tf.name_scope(self.prompt_encoder.name): + self.prompt_encoder.build(None) + if getattr(self, "mask_decoder", None) is not None: + with tf.name_scope(self.mask_decoder.name): + self.mask_decoder.build(None) + + +__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam/processing_sam.py b/venv/lib/python3.13/site-packages/transformers/models/sam/processing_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b3728fe273efc92adf57f119bce1eb899534b7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam/processing_sam.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for SAM. +""" + +from copy import deepcopy +from typing import Optional, Union + +import numpy as np + +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput +from ...utils import is_tf_available, is_torch_available +from ...video_utils import VideoInput + + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + + +class SamImagesKwargs(ImagesKwargs): + segmentation_maps: Optional[ImageInput] + input_points: Optional[list[list[float]]] + input_labels: Optional[list[list[int]]] + input_boxes: Optional[list[list[list[float]]]] + point_pad_value: Optional[int] + + +class SamProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: SamImagesKwargs + _defaults = { + "images_kwargs": { + "point_pad_value": -10, + } + } + + +class SamProcessor(ProcessorMixin): + r""" + Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a + single processor. + + [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of + [`~SamImageProcessor.__call__`] for more information. + + Args: + image_processor (`SamImageProcessor`): + An instance of [`SamImageProcessor`]. The image processor is a required input. + """ + + attributes = ["image_processor"] + image_processor_class = "SamImageProcessor" + + def __init__(self, image_processor): + super().__init__(image_processor) + self.target_size = self.image_processor.size["longest_edge"] + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + audio: Optional[AudioInput] = None, + video: Optional[VideoInput] = None, + **kwargs, + ) -> BatchEncoding: + """ + This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D + points and bounding boxes for the model if they are provided. 
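+
+        Example (an illustrative sketch; the checkpoint name is an assumption for demonstration):
+
+        ```python
+        >>> from transformers import SamProcessor
+
+        >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+        >>> # one image with a single point prompt at pixel coordinates (x=450, y=600)
+        >>> inputs = processor(images=image, input_points=[[[450, 600]]], return_tensors="pt")
+        >>> sorted(inputs.keys())  # doctest: +SKIP
+        ['input_points', 'original_sizes', 'pixel_values', 'reshaped_input_sizes']
+        ```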
+ """ + output_kwargs = self._merge_kwargs( + SamProcessorKwargs, + tokenizer_init_kwargs={}, + **kwargs, + ) + input_points = output_kwargs["images_kwargs"].pop("input_points", None) + input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) + input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) + point_pad_value = output_kwargs["images_kwargs"].pop("point_pad_value", None) + + encoding_image_processor = self.image_processor( + images, + **output_kwargs["images_kwargs"], + ) + + # pop arguments that are not used in the forward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + + if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor + original_sizes = original_sizes.numpy() + + input_points, input_labels, input_boxes = self._check_and_preprocess_points( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + ) + + encoding_image_processor = self._normalize_and_convert( + encoding_image_processor, + original_sizes, + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), + point_pad_value=point_pad_value, + ) + + return encoding_image_processor + + def _normalize_and_convert( + self, + encoding_image_processor, + original_sizes, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors="pt", + point_pad_value=-10, + ): + if input_points is not None: + if len(original_sizes) != len(input_points): + input_points = [ + self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points + ] + else: + input_points = [ + self._normalize_coordinates(self.target_size, point, original_size) + for point, original_size in zip(input_points, original_sizes) + ] + # check that all arrays have the same shape + if not all(point.shape == input_points[0].shape for point in input_points): + if input_labels is not None: + input_points, input_labels = self._pad_points_and_labels( + input_points, input_labels, point_pad_value + ) + + input_points = np.array(input_points) + + if input_labels is not None: + input_labels = np.array(input_labels) + + if input_boxes is not None: + if len(original_sizes) != len(input_boxes): + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True) + for box in input_boxes + ] + else: + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True) + for box, original_size in zip(input_boxes, original_sizes) + ] + input_boxes = np.array(input_boxes) + + if input_boxes is not None: + if return_tensors == "pt": + input_boxes = torch.from_numpy(input_boxes) + # boxes batch size of 1 by default + input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + elif return_tensors == "tf": + input_boxes = tf.convert_to_tensor(input_boxes) + # boxes batch size of 1 by default + input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes + encoding_image_processor.update({"input_boxes": input_boxes}) + if input_points is not None: + if return_tensors == "pt": + input_points = torch.from_numpy(input_points) + # point batch size of 1 by default + input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + elif return_tensors == "tf": + input_points = tf.convert_to_tensor(input_points) + # point batch size of 1 by default + input_points = 
tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points + encoding_image_processor.update({"input_points": input_points}) + if input_labels is not None: + if return_tensors == "pt": + input_labels = torch.from_numpy(input_labels) + # point batch size of 1 by default + input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels + elif return_tensors == "tf": + input_labels = tf.convert_to_tensor(input_labels) + # point batch size of 1 by default + input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels + encoding_image_processor.update({"input_labels": input_labels}) + + return encoding_image_processor + + def _pad_points_and_labels(self, input_points, input_labels, point_pad_value): + r""" + The method pads the 2D points and labels to the maximum number of points in the batch. + """ + expected_nb_points = max(point.shape[0] for point in input_points) + processed_input_points = [] + for i, point in enumerate(input_points): + if point.shape[0] != expected_nb_points: + point = np.concatenate( + [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0 + ) + input_labels[i] = np.append(input_labels[i], [point_pad_value]) + processed_input_points.append(point) + input_points = processed_input_points + return input_points, input_labels + + def _normalize_coordinates( + self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _check_and_preprocess_points( + self, + input_points=None, + input_labels=None, + input_boxes=None, + ): + r""" + Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they + are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, + it is converted to a `numpy.ndarray` and then to a `list`. 
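+
+        Example of the accepted nesting (illustrative; the coordinate values are arbitrary):
+
+        ```python
+        # one image with two point prompts -> converted to [np.array of shape (2, 2)]
+        input_points = [[[450, 600], [100, 200]]]
+        # one image with one box prompt in (x1, y1, x2, y2) order -> converted to [np.array of shape (1, 4)]
+        input_boxes = [[[75, 275, 1725, 850]]]
+        ```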
+ """ + if input_points is not None: + if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor + input_points = input_points.numpy().tolist() + + if not isinstance(input_points, list) or not isinstance(input_points[0], list): + raise ValueError("Input points must be a list of list of floating points.") + input_points = [np.array(input_point) for input_point in input_points] + else: + input_points = None + + if input_labels is not None: + if hasattr(input_labels, "numpy"): + input_labels = input_labels.numpy().tolist() + + if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): + raise ValueError("Input labels must be a list of list integers.") + input_labels = [np.array(label) for label in input_labels] + else: + input_labels = None + + if input_boxes is not None: + if hasattr(input_boxes, "numpy"): + input_boxes = input_boxes.numpy().tolist() + + if ( + not isinstance(input_boxes, list) + or not isinstance(input_boxes[0], list) + or not isinstance(input_boxes[0][0], list) + ): + raise ValueError("Input boxes must be a list of list of list of floating points.") + input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] + else: + input_boxes = None + + return input_points, input_labels, input_boxes + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + return list(image_processor_input_names + ["original_sizes", "reshaped_input_sizes"]) + + def post_process_masks(self, *args, **kwargs): + return self.image_processor.post_process_masks(*args, **kwargs) + + +__all__ = ["SamProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38e6621b063af7a706186a1bf810e3f4709dc1b8 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sam2 import * + from .image_processing_sam2_fast import * + from .modeling_sam2 import * + from .processing_sam2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/configuration_sam2.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/configuration_sam2.py new file mode 100644 index 0000000000000000000000000000000000000000..e14583181d38abfff63209160ab2a6eb5219a05f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/configuration_sam2.py @@ -0,0 +1,453 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAM2 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class Sam2HieraDetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Sam2HieraDetModel`]. It is used to instantiate + a HieraDet model as defined in the original sam2 repo according to the specified arguments, defining the model architecture. + Instantiating a configuration defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny + [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 96): + The hidden dimension of the image encoder. + num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of channels in the image. + image_size (`list[int]`, *optional*, defaults to `[1024, 1024]`): + The size of the image. + patch_kernel_size (`list[int]`, *optional*, defaults to `[7, 7]`): + The kernel size of the patch. + patch_stride (`list[int]`, *optional*, defaults to `[4, 4]`): + The stride of the patch. + patch_padding (`list[int]`, *optional*, defaults to `[3, 3]`): + The padding of the patch. + query_stride (`list[int]`, *optional*, defaults to `[2, 2]`): + The downsample stride between stages. + window_positional_embedding_background_size (`list[int]`, *optional*, defaults to `[7, 7]`): + The window size per stage when not using global attention. + num_query_pool_stages (`int`, *optional*, defaults to 3): + The number of query pool stages. + blocks_per_stage (`list[int]`, *optional*, defaults to `[1, 2, 7, 2]`): + The number of blocks per stage. + embed_dim_per_stage (`list[int]`, *optional*, defaults to `[96, 192, 384, 768]`): + The embedding dimension per stage. + num_attention_heads_per_stage (`list[int]`, *optional*, defaults to `[1, 2, 4, 8]`): + The number of attention heads per stage. + window_size_per_stage (`list[int]`, *optional*, defaults to `[8, 4, 14, 7]`): + The window size per stage. + global_attention_blocks (`list[int]`, *optional*, defaults to `[5, 7, 9]`): + The blocks where global attention is used. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the MLP hidden dimension to the embedding dimension. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + """ + + base_config_key = "backbone_config" + model_type = "sam2_hiera_det_model" + + def __init__( + self, + hidden_size=96, + num_attention_heads=1, + num_channels=3, + image_size=None, + patch_kernel_size=None, + patch_stride=None, + patch_padding=None, + query_stride=None, + window_positional_embedding_background_size=None, + num_query_pool_stages=3, + blocks_per_stage=None, + embed_dim_per_stage=None, + num_attention_heads_per_stage=None, + window_size_per_stage=None, + global_attention_blocks=None, + mlp_ratio=4.0, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + image_size = image_size if image_size is not None else [1024, 1024] + patch_kernel_size = patch_kernel_size if patch_kernel_size is not None else [7, 7] + patch_stride = patch_stride if patch_stride is not None else [4, 4] + patch_padding = patch_padding if patch_padding is not None else [3, 3] + query_stride = query_stride if query_stride is not None else [2, 2] + window_positional_embedding_background_size = ( + window_positional_embedding_background_size + if window_positional_embedding_background_size is not None + else [7, 7] + ) + blocks_per_stage = blocks_per_stage if blocks_per_stage is not None else [1, 2, 7, 2] + embed_dim_per_stage = embed_dim_per_stage if embed_dim_per_stage is not None else [96, 192, 384, 768] + num_attention_heads_per_stage = ( + num_attention_heads_per_stage if num_attention_heads_per_stage is not None else [1, 2, 4, 8] + ) + window_size_per_stage = window_size_per_stage if window_size_per_stage is not None else [8, 4, 14, 7] + global_attention_blocks = global_attention_blocks if global_attention_blocks is not None else [5, 7, 9] + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_kernel_size = patch_kernel_size + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.query_stride = query_stride + self.window_positional_embedding_background_size = window_positional_embedding_background_size + self.num_query_pool_stages = num_query_pool_stages + self.blocks_per_stage = blocks_per_stage + self.embed_dim_per_stage = embed_dim_per_stage + self.num_attention_heads_per_stage = num_attention_heads_per_stage + self.window_size_per_stage = window_size_per_stage + self.global_attention_blocks = global_attention_blocks + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class Sam2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Sam2VisionModel`]. It is used to instantiate a SAM + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny + [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*): + Configuration for the vision backbone. 
This is used to instantiate the backbone using + `AutoModel.from_config`. + backbone_channel_list (`List[int]`, *optional*, defaults to `[768, 384, 192, 96]`): + The list of channel dimensions for the backbone. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`): + The spatial sizes of the feature maps from the backbone. + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + fpn_kernel_size (`int`, *optional*, defaults to 1): + The kernel size for the convolutions in the neck. + fpn_stride (`int`, *optional*, defaults to 1): + The stride for the convolutions in the neck. + fpn_padding (`int`, *optional*, defaults to 0): + The padding for the convolutions in the neck. + fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`): + The levels for the top-down FPN connections. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels from the FPN to use. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + """ + + base_config_key = "vision_config" + model_type = "sam2_vision_model" + sub_configs = { + "backbone_config": AutoConfig, + } + + def __init__( + self, + backbone_config=None, + backbone_channel_list=None, + backbone_feature_sizes=None, + fpn_hidden_size=256, + fpn_kernel_size=1, + fpn_stride=1, + fpn_padding=0, + fpn_top_down_levels=None, + num_feature_levels=3, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + backbone_channel_list = [768, 384, 192, 96] if backbone_channel_list is None else backbone_channel_list + backbone_feature_sizes = ( + [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes + ) + fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels + + if isinstance(backbone_config, dict): + backbone_config["model_type"] = backbone_config.get("model_type", "sam2_hiera_det_model") + backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config) + elif isinstance(backbone_config, Sam2HieraDetConfig): + pass + elif backbone_config is None: + backbone_config = Sam2HieraDetConfig() + + self.backbone_config = backbone_config + + # Neck + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.fpn_kernel_size = fpn_kernel_size + self.fpn_stride = fpn_stride + self.fpn_padding = fpn_padding + self.fpn_top_down_levels = fpn_top_down_levels + self.num_feature_levels = num_feature_levels + + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class Sam2PromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Sam2PromptEncoder`]. The [`Sam2PromptEncoder`] + module is used to encode the input 2D points and bounding boxes. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
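# A minimal usage sketch for the vision config defined above (assuming these classes are
# exported from `transformers` as listed in this file's `__all__`): a `backbone_config`
# left as `None` falls back to the default Sam2HieraDetConfig, while an explicit instance
# (or a dict carrying a `model_type`) is used as-is.
from transformers import Sam2HieraDetConfig, Sam2VisionConfig

default_vision_config = Sam2VisionConfig()
custom_vision_config = Sam2VisionConfig(backbone_config=Sam2HieraDetConfig(hidden_size=144))

print(type(default_vision_config.backbone_config).__name__)  # Sam2HieraDetConfig
print(custom_vision_config.backbone_config.hidden_size)  # 144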
+ + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + scale (`float`, *optional*, defaults to 1): + The scale factor for the prompt encoder. + """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + scale=1, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.scale = scale + + +class Sam2MaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Sam2MaskDecoder`]. It is used to instantiate a SAM2 + mask decoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the SAM2 mask decoder. + mlp_dim (`int`, *optional*, defaults to 2048): + The dimension of the MLP in the two-way transformer. + num_hidden_layers (`int`, *optional*, defaults to 2): + The number of hidden layers in the two-way transformer. + num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in the two-way transformer. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsample rate for the attention layers. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of multimask outputs. + iou_head_depth (`int`, *optional*, defaults to 3): + The depth of the IoU head. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the IoU head. + dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`): + Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05): + The stability delta for the dynamic multimask. + dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98): + The stability threshold for the dynamic multimask.
+ + """ + + base_config_key = "mask_decoder_config" + + def __init__( + self, + hidden_size=256, + hidden_act="gelu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_multimask_outputs = num_multimask_outputs + self.hidden_act = hidden_act + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + # TwoWayTransformer configuration + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.mlp_dim = mlp_dim + self.attention_downsample_rate = attention_downsample_rate + + +class Sam2Config(PretrainedConfig): + r""" + [`Sam2Config`] is the configuration class to store the configuration of a [`Sam2Model`]. It is used to instantiate a + SAM2 model according to the specified arguments, defining the memory attention, memory encoder, and image encoder + configs. Instantiating a configuration defaults will yield a similar configuration to that of the SAM 2.1 Hiera-tiny + [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `Sam2VisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`Sam2VisionConfig`]. + prompt_encoder_config (Union[`dict`, `Sam2PromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`Sam2PromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `Sam2MaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`Sam2MaskDecoderConfig`]. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation for parameter initialization. + + Example: + + ```python + >>> from transformers import ( + ... Sam2VisionConfig, + ... Sam2PromptEncoderConfig, + ... Sam2MaskDecoderConfig, + ... Sam2Model, + ... 
) + + >>> # Initializing a Sam2Config with `"facebook/sam2.1_hiera_tiny"` style configuration + >>> configuration = Sam2Config() + + >>> # Initializing a Sam2Model (with random weights) from the `"facebook/sam2.1_hiera_tiny"` style configuration + >>> model = Sam2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Sam2Config from a Sam2VisionConfig, Sam2PromptEncoderConfig, and Sam2MaskDecoderConfig + + >>> # Initializing SAM2 vision encoder, prompt encoder, and mask decoder configurations + >>> vision_config = Sam2VisionConfig() + >>> prompt_encoder_config = Sam2PromptEncoderConfig() + >>> mask_decoder_config = Sam2MaskDecoderConfig() + + >>> config = Sam2Config(vision_config, prompt_encoder_config, mask_decoder_config) + ```""" + + model_type = "sam2" + sub_configs = { + "vision_config": AutoConfig, + "prompt_encoder_config": Sam2PromptEncoderConfig, + "mask_decoder_config": Sam2MaskDecoderConfig, + } + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + if isinstance(prompt_encoder_config, Sam2PromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, Sam2MaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = vision_config + self.prompt_encoder_config = Sam2PromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = Sam2MaskDecoderConfig(**mask_decoder_config) + + self.initializer_range = initializer_range + + +__all__ = [ + "Sam2Config", + "Sam2HieraDetConfig", + "Sam2VisionConfig", + "Sam2PromptEncoderConfig", + "Sam2MaskDecoderConfig", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/image_processing_sam2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/image_processing_sam2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a773e8ad54d7b640f39a17c0274621e1e4d815ee --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/image_processing_sam2_fast.py @@ -0,0 +1,730 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from copy import deepcopy +from itertools import product +from typing import Any, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.ops.boxes import batched_nms + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + pil_torch_interpolation_mapping, +) +from ...processing_utils import Unpack +from ...utils import TensorType, auto_docstring + + +class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + mask_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to resize the segmentation maps to. + """ + + mask_size: Optional[dict[str, int]] + + +def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + stability_scores = intersections / unions + return stability_scores + + +def _mask_to_rle(input_mask: "torch.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = input_mask.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()] + out.append({"size": [height, width], "counts": counts}) + return out + + +def _batched_mask_to_box(masks: "torch.Tensor"): + """ + Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... 
x 4. + + Args: + - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + # torch.max below raises an error on empty inputs, just skip in this case + + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to Cxheightxwidth + shape = masks.shape + height, width = shape[-2:] + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + out = out.reshape(*shape[:-2], 4) + return out + + +def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + + left, top, _, _ = crop_box + offset = torch.tensor([[left, top, left, top]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + boxes = (boxes + offset).float() + + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return torch.nn.functional.pad(masks, pad, value=0) + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, +) -> tuple[list[list[int]], list[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. 
In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sam2ple per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sam2pled in layer n is scaled down by crop_n_points_downscale_factor**n. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + original_size = image.shape[-2:] + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + crop_boxes = torch.tensor(crop_boxes) + crop_boxes = crop_boxes.float() + points_per_crop = torch.stack(point_grid_per_crop) + points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3) + cropped_images = torch.stack(cropped_images) + + input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _build_point_grid(n_per_side: int) -> torch.Tensor: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = torch.linspace(offset, 1 - offset, n_per_side) + points_x = torch.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = torch.tile(points_one_side[:, None], (1, n_per_side)) + points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2) + return points + + +def _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None +): + """ + Takes as an input bounding boxes that are used to crop the image. 
Based in the crops, the corresponding points are + also passed. + """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = cropped_im.shape[-2:] + points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0) + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _normalize_coordinates( + target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False +) -> torch.Tensor: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. + """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor: + """Compute a binary mask from an uncompressed RLE.""" + height, width = rle["size"] + mask = torch.empty(height * width, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity = not parity + mask = mask.reshape(width, height) + return mask.transpose(0, 1) # Reshape to original shape + + +def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. + + Args: + rle_masks (`torch.Tensor`): + binary masks in the RLE format + iou_scores (`torch.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`torch.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. 
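# Illustrative round-trip for the RLE helpers above (`_mask_to_rle` / `_rle_to_mask` are
# module-internal; this sketch only demonstrates the uncompressed, column-major encoding
# they implement).
import torch

from transformers.models.sam2.image_processing_sam2_fast import _mask_to_rle, _rle_to_mask

masks = torch.zeros(1, 4, 4, dtype=torch.bool)
masks[0, 1:3, 1:3] = True
rles = _mask_to_rle(masks)  # one {"size": [4, 4], "counts": [...]} dict per mask
decoded = _rle_to_mask(rles[0])

assert torch.equal(decoded, masks[0])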
+ """ + keep_by_nms = batched_nms( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=torch.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes + + +@auto_docstring +class Sam2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 1024, "width": 1024} + mask_size = {"height": 256, "width": 256} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + valid_kwargs = Sam2FastImageProcessorKwargs + + # modular artefacts + do_pad = None + pad_size = None + mask_pad_size = None + + def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + mask_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if mask_size is not None: + mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["mask_size"] = mask_size + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + + # torch resize uses interpolation instead of resample + # Check if resample is an int before checking if it's an instance of PILImageResampling + # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. + # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + return kwargs + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[Sam2FastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[Sam2FastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + original_sizes = [image.shape[-2:] for image in images] + images_kwargs = kwargs.copy() + pixel_values = self._preprocess(images, **images_kwargs) + reshaped_input_sizes = [image.shape[-2:] for image in images] + data = { + "pixel_values": pixel_values, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + "size": segmentation_maps_kwargs.pop("mask_size"), + } + ) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) + + def generate_crop_boxes( + self, + image: "torch.Tensor", + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + device: Optional["torch.device"] = None, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`torch.Tensor`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sam2ple from each crop. + crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1): + The number of points-per-side sam2pled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. 
+ """ + image = self._process_image(image) + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + ) + if device is None: + device = torch.device("cpu") + crop_boxes = crop_boxes.to(device) + points_per_crop = points_per_crop.to(device) + # cropped_images stays as torch.Tensor + input_labels = input_labels.to(device) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`torch.Tensor`): + Input masks. + iou_scores (`torch.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`torch.Tensor`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the sam2e batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle(masks) + + return masks, scores, converted_boxes + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + **kwargs, + ): + """ + Remove padding and upscale masks to the original image size. 
+ + Args: + masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. + """ + if isinstance(original_sizes, (torch.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + # TODO: add connected components kernel for postprocessing + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = torch.from_numpy(masks[i]) + elif not isinstance(masks[i], torch.Tensor): + raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") + interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False) + if apply_non_overlapping_constraints: + interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`torch.Tensor`): + List of all predicted segmentation masks + all_scores (`torch.Tensor`): + List of all predicted iou scores + all_boxes (`torch.Tensor`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + """ + return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh) + + def pad_image(self): + raise NotImplementedError("No pad_image for SAM 2.") + + def _preprocess( + self, + images: list["torch.Tensor"], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> "torch.Tensor": + return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values + + def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor: + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. 
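# Hedged sketch for `post_process_masks` (assumes `Sam2ImageProcessorFast` is importable
# from `transformers`): low-resolution mask logits are upscaled to the original image
# size and binarized by default.
import torch

from transformers import Sam2ImageProcessorFast

processor = Sam2ImageProcessorFast()
low_res_masks = [torch.randn(1, 3, 256, 256)]  # per image: (num_objects, num_masks, 256, 256)
original_sizes = [(480, 640)]  # (height, width) of each image before resizing

upscaled = processor.post_process_masks(low_res_masks, original_sizes)
print(upscaled[0].shape, upscaled[0].dtype)  # torch.Size([1, 3, 480, 640]) torch.bool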
+ """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks + + +__all__ = ["Sam2ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/modeling_sam2.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/modeling_sam2.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f5489fed0fd88264789ff4293a36f30074a422 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/modeling_sam2.py @@ -0,0 +1,1611 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from transformers.utils.generic import OutputRecorder + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, auto_docstring +from ...utils.generic import TransformersKwargs, check_model_inputs +from ..auto import AutoModel +from .configuration_sam2 import ( + Sam2Config, + Sam2HieraDetConfig, + Sam2MaskDecoderConfig, + Sam2PromptEncoderConfig, + Sam2VisionConfig, +) + + +@dataclass +@auto_docstring(custom_intro="Base class for the vision encoder's outputs.") +class Sam2VisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ fpn_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. + fpn_position_encoding (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + fpn_hidden_states: Optional[torch.FloatTensor] = None + fpn_position_encoding: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring(custom_intro="Base class for the Sam2 model's output.") +class Sam2ImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`tuple(torch.FloatTensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. 
+ """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + object_score_logits: Optional[torch.FloatTensor] = None + image_embeddings: tuple[torch.FloatTensor, ...] = None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class Sam2PatchEmbeddings(nn.Module): + r""" + Turns pixel values into patch embeddings for transformer consumption. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details. + + Returns: + embeddings (`torch.FloatTensor`): + Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding + """ + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__() + num_channels = config.num_channels + hidden_size = config.hidden_size + + self.projection = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=config.patch_kernel_size, + stride=config.patch_stride, + padding=config.patch_padding, + ) + + def forward(self, pixel_values): + _, num_channels, height, width = pixel_values.shape + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class Sam2SinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[Tensor] = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class Sam2VisionNeck(nn.Module): + def __init__(self, config: Sam2VisionConfig): + super().__init__() + self.config = config + + self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True) + self.convs = nn.ModuleList() + 
for in_channels in config.backbone_channel_list: + self.convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=config.fpn_hidden_size, + kernel_size=config.fpn_kernel_size, + stride=config.fpn_stride, + padding=config.fpn_padding, + ), + ) + self.fpn_top_down_levels = config.fpn_top_down_levels + + def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + fpn_hidden_states = () + fpn_position_encoding = () + + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + lateral_features = hidden_states[i].permute(0, 3, 1, 2) + lateral_features = self.convs[n - i](lateral_features) + if i not in self.fpn_top_down_levels or i == n: + prev_features = lateral_features + else: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode="nearest", + align_corners=None, + antialias=False, + ).to(lateral_features.dtype) + prev_features = lateral_features + top_down_features + + prev_position_encoding = self.position_encoding( + prev_features.shape, prev_features.device, prev_features.dtype + ).to(prev_features.dtype) + + fpn_hidden_states += (prev_features,) + fpn_position_encoding += (prev_position_encoding,) + + return fpn_hidden_states, fpn_position_encoding + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def do_pool(x: torch.Tensor, query_stride: Optional[int] = None) -> torch.Tensor: + if query_stride is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + return x + + +class Sam2MultiScaleAttention(nn.Module): + def __init__( + self, + config: Sam2HieraDetConfig, + dim: int, + dim_out: int, + num_attention_heads: int, + query_stride: Optional[tuple[int, int]] = None, + ): + super().__init__() + + self.config = config + + self.dim = dim + self.dim_out = dim_out + self.query_stride = query_stride + + self.num_attention_heads = num_attention_heads + head_dim = dim_out // num_attention_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + query, key, value = torch.unbind(qkv, 2) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + # Q pooling (for downsample at stage changes) 
+ if self.query_stride: + query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride) + height, width = query.shape[1:3] # downsampled shape + query = query.reshape(batch_size, height * width, self.num_attention_heads, -1) + + # transpose query, key, value to (B, nHead, H * W, C) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + is_causal=self.is_causal, + scaling=self.scale, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + return attn_output + + +class Sam2FeedForward(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +def window_partition(hidden_state, window_size): + """ + Partition into non-overlapping windows with padding if needed. + + Args: + hidden_state (`torch.Tensor`): + Input tokens with [batch_size, height, width, num_channels]. + window_size (`int`): + Window size. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements: + - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. + - (padded_height, padded_width): padded height and width before partition + """ + batch_size, height, width, num_channels = hidden_state.shape + + pad_height = (window_size - height % window_size) % window_size + pad_width = (window_size - width % window_size) % window_size + + # Noop in case pad_width == 0 and pad_height == 0. + hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) + + padded_height, padded_width = height + pad_height, width + pad_width + + hidden_state = hidden_state.view( + batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels + ) + windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows, (padded_height, padded_width) + + +def window_unpartition(windows, window_size, pad_height_width, height_width): + """ + Window unpartition into original sequences and removing padding. + + Args: + windows (`torch.Tensor`): + Input tokens with [batch_size * num_windows, window_size, window_size, num_channels]. + window_size (`int`): + Window size. + pad_height_width (`tuple[int]`): + Padded height and width (padded_height, padded_width). + height_width (`tuple[int]`): + Original height and width before padding. 
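# Round-trip sketch for the two window helpers above (shapes follow their docstrings:
# inputs are (batch_size, height, width, num_channels); the padding added by
# `window_partition` is stripped again by `window_unpartition`).
import torch

from transformers.models.sam2.modeling_sam2 import window_partition, window_unpartition

hidden_state = torch.randn(2, 14, 14, 96)
windows, pad_height_width = window_partition(hidden_state, window_size=8)
print(windows.shape)  # torch.Size([8, 8, 8, 96]) -> 2 images x 4 windows of size 8x8 each

restored = window_unpartition(windows, 8, pad_height_width, (14, 14))
assert torch.equal(restored, hidden_state)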
+ + Returns: + hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels]. + """ + padded_height, padded_width = pad_height_width + height, width = height_width + batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size) + hidden_state = windows.view( + batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1 + ) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous() + hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1) + + # We always have height <= padded_height and width <= padded_width + hidden_state = hidden_state[:, :height, :width, :].contiguous() + return hidden_state + + +class Sam2MultiScaleBlock(GradientCheckpointingLayer): + def __init__( + self, + config: Sam2HieraDetConfig, + stage_idx: int, + block_idx: int, + total_block_idx: int, + ): + super().__init__() + + # take embed dim from previous stage if first block of stage + self.dim = ( + config.embed_dim_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.embed_dim_per_stage[stage_idx] + ) + self.dim_out = config.embed_dim_per_stage[stage_idx] + self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) + # take window size from previous stage if first block of stage + self.window_size = ( + config.window_size_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.window_size_per_stage[stage_idx] + ) + self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size + # use query stride for first block of stage if stage is a query pool stage + self.query_stride = ( + config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None + ) + + self.attn = Sam2MultiScaleAttention( + config, + self.dim, + self.dim_out, + num_attention_heads=config.num_attention_heads_per_stage[stage_idx], + query_stride=self.query_stride, + ) + self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps) + self.mlp = Sam2FeedForward( + self.dim_out, + int(self.dim_out * config.mlp_ratio), + self.dim_out, + num_layers=2, + activation=config.hidden_act, + ) + if self.dim != self.dim_out: + self.proj = nn.Linear(self.dim, self.dim_out) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states # batch_size, height, width, channel + + hidden_states = self.layer_norm1(hidden_states) + + # Skip connection + if self.dim != self.dim_out: + residual = do_pool(self.proj(hidden_states), self.query_stride) + + # Window partition + window_size = self.window_size + if self.window_size > 0: + H, W = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, pad_hw = window_partition(hidden_states, window_size) + + # Window Attention + Q Pooling (if stage change) + attn_output = self.attn( + hidden_states=hidden_states, + **kwargs, + ) + hidden_states = attn_output + if self.query_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.query_stride[0] + H, W = residual.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + 
hidden_states = hidden_states + self.mlp(layernorm_output) + + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Hiera model's outputs that also contains a pooling of the last hidden states. + """ +) +class Sam2HieraDetModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the intermediate layers of the model. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@auto_docstring +class Sam2PreTrainedModel(PreTrainedModel): + config_class = Sam2Config + base_model_prefix = "sam2" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, Sam2HieraDetModel): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + if module.pos_embed_window is not None: + module.pos_embed_window.data.zero_() + if isinstance(module, Sam2Model): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +class Sam2HieraDetModel(Sam2PreTrainedModel): + config_class = Sam2HieraDetConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Sam2MultiScaleBlock, + "attentions": Sam2MultiScaleAttention, + } + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__(config) + + self.patch_embed = Sam2PatchEmbeddings(config) + # Windowed positional embedding (https://huggingface.co/papers/2311.05613) + self.pos_embed = nn.Parameter( + torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0]) + ) + self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist() + self.blocks = nn.ModuleList() + total_block_idx = 0 + for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage): + for block_idx in range(blocks_per_stage): + block = Sam2MultiScaleBlock( + config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx + ) + self.blocks.append(block) + total_block_idx += 1 + + def get_input_embeddings(self): + return self.patch_embed + + def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + @check_model_inputs() + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: 
Unpack[TransformersKwargs],
+    ) -> Union[tuple, Sam2HieraDetModelOutput]:
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        hidden_states = self.patch_embed(pixel_values)
+        hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
+
+        intermediate_hidden_states = ()
+        for i, block_module in enumerate(self.blocks):
+            hidden_states = block_module(hidden_states, **kwargs)
+
+            if i in self.stage_ends:
+                intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
+
+        return Sam2HieraDetModelOutput(
+            last_hidden_state=hidden_states,
+            intermediate_hidden_states=intermediate_hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The vision model from SAM 2 without any head or projection on top.
+    """
+)
+class Sam2VisionModel(Sam2PreTrainedModel):
+    config_class = Sam2VisionConfig
+    main_input_name = "pixel_values"
+    _can_record_outputs = {
+        "hidden_states": Sam2MultiScaleBlock,
+        "attentions": Sam2MultiScaleAttention,
+    }
+
+    def __init__(self, config: Sam2VisionConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.backbone = AutoModel.from_config(config.backbone_config)
+
+        self.neck = Sam2VisionNeck(config)
+        self.num_feature_levels = config.num_feature_levels
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.backbone.get_input_embeddings()
+
+    @check_model_inputs()
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, Sam2VisionEncoderOutput]:
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Forward through backbone
+        backbone_output = self.backbone(pixel_values, **kwargs)
+        hidden_states = backbone_output.last_hidden_state
+        intermediate_hidden_states = backbone_output.intermediate_hidden_states
+
+        fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+        # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
+        fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
+        fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+
+        return Sam2VisionEncoderOutput(
+            last_hidden_state=hidden_states,
+            fpn_hidden_states=fpn_hidden_states,
+            fpn_position_encoding=fpn_position_encoding,
+        )
+
+
+class Sam2PositionalEmbedding(nn.Module):
+    def __init__(self, config: Sam2PromptEncoderConfig):
+        super().__init__()
+        self.scale = config.scale
+        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+        self.register_buffer("positional_embedding", positional_embedding)
+
+    def forward(self, input_coords, input_shape=None):
+        """Positionally encode points that are normalized to [0,1]."""
+        coordinates = input_coords.clone()
+
+        if input_shape is not None:
+            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+            coordinates = coordinates.to(torch.float32)
+
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coordinates = 2 * coordinates - 1
+        coordinates = coordinates.to(self.positional_embedding.dtype)
+        coordinates = coordinates @ self.positional_embedding
+        coordinates = 2 * np.pi * coordinates
+        # outputs d_1 x ...
x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class Sam2MaskEmbedding(nn.Module): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = Sam2LayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = Sam2LayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class Sam2PromptEncoder(nn.Module): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.shared_embedding = Sam2PositionalEmbedding(config) + self.mask_embed = Sam2MaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. 
The dtype, device need to be explicitly
+        # specified, as otherwise torch.onnx.export interprets them as double
+        point_embedding = torch.where(
+            labels[..., None] != -10,
+            point_embedding,
+            torch.zeros_like(point_embedding),
+        )
+
+        # Add point embeddings for labels >= 0
+        point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel, without mutating the caller's tensor
+        coords = boxes.view(*boxes.shape[:2], 2, 2)
+        # add padding point for consistency with the original implementation
+        coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+        corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+        corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+        corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+        corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+        return corner_embedding
+
+    def forward(
+        self,
+        input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
+        input_labels: Optional[torch.Tensor],
+        input_boxes: Optional[torch.Tensor],
+        input_masks: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense embeddings.
+
+        Args:
+            input_points (`torch.Tensor`, *optional*):
+                point coordinates to embed.
+            input_labels (`torch.Tensor`, *optional*):
+                labels for the point prompts.
+            input_boxes (`torch.Tensor`, *optional*):
+                boxes to embed
+            input_masks (`torch.Tensor`, *optional*):
+                masks to embed
+        """
+        sparse_embeddings = None
+        batch_size = 1
+        if input_points is not None:
+            batch_size = input_points.shape[0]
+            if input_labels is None:
+                raise ValueError("If points are provided, labels must also be provided.")
+            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+            sparse_embeddings = point_embeddings
+        if input_boxes is not None:
+            batch_size = input_boxes.shape[0]
+            box_embeddings = self._embed_boxes(input_boxes)
+            if sparse_embeddings is None:
+                sparse_embeddings = box_embeddings
+            else:
+                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+        if input_masks is not None:
+            dense_embeddings = self.mask_embed(input_masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class Sam2Attention(nn.Module):
+    """
+    SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+    values.
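+
+    For example, with `hidden_size=256` and `downsample_rate=2`, queries, keys and values are projected to an
+    internal dimension of 128 (split across the attention heads) before being projected back to 256 on output.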
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_similarity: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class Sam2TwoWayAttentionBlock(nn.Module): + def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`Sam2MaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = Sam2Attention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = Sam2Attention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = Sam2FeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = Sam2Attention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +class Sam2TwoWayTransformer(nn.Module): + def __init__(self, config: Sam2MaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = Sam2Attention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, 
key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class Sam2LayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class Sam2MaskDecoder(nn.Module): + def __init__(self, config: Sam2MaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = Sam2TwoWayTransformer(config) + + # should we create a new class for this? + self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + self.iou_prediction_head = Sam2FeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + high_resolution_features: list[torch.Tensor], + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 
+ """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[torch.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`torch.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`torch.Tensor`, *optional*): + The target embedding. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = 
iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) + ) + best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +@auto_docstring( + custom_intro=""" + Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and + input points and labels, boxes, or masks. 
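+    It combines a Hiera-based vision encoder with an FPN neck, a prompt encoder and a two-way transformer mask
+    decoder, mirroring the modules instantiated in `__init__` below.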
+ """ +) +class Sam2Model(Sam2PreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def __init__(self, config: Sam2Config): + super().__init__(config) + self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self) -> torch.Tensor: + size = self.prompt_encoder.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones(size, device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> list[torch.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. 
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Input pixel values
+        """
+        batch_size = pixel_values.shape[0]
+        feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+        # add no memory embedding to the last feature map
+        feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+        # reshape feature maps to the same shape as the backbone feature sizes
+        image_embeddings = [
+            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+            for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+        ]
+
+        return image_embeddings
+
+    @torch.no_grad()
+    def get_prompt_embeddings(
+        self,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+    ):
+        r"""
+        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+        Args:
+            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+                Optional input points for the prompt encoder. The padding of the point is automatically done by the
+                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+                point. The model will output `point_batch_size` times 3 masks in total.
+            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+                processor, or can be fed by the user.
+            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+                processor. Users can also pass the input boxes manually.
+            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+                Optional input masks for the prompt encoder.
+        """
+        prompt_output = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+        return prompt_output
+
+    @check_model_inputs()
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Sam2ImageSegmentationOutput:
+        r"""
+        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+            Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
+            better results. The points can be obtained by passing a list of list of list to the processor that will
+            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+            multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
+            coordinates of the point.
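+            For example, a single click on a single image is passed as a tensor of shape `(1, 1, 1, 2)`.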
+            If a different number of points is passed either for each image, or for each
+            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+            computation of the embedding will be skipped for these points using the labels.
+        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+            Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
+            official implementation, there are 3 types of labels
+
+            - `1`: the point is a point that contains the object of interest
+            - `0`: the point is a point that does not contain the object of interest
+            - `-1`: the point corresponds to the background
+
+            We added the label:
+
+            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+            The padding labels should be automatically done by the processor.
+        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+            Input boxes for the points; these are used by the prompt encoder to encode the prompt. Generally yields
+            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+            In the order (`x1`, `y1`, `x2`, `y2`):
+
+            - `x1`: the x coordinate of the top left point of the input box
+            - `y1`: the y coordinate of the top left point of the input box
+            - `x2`: the x coordinate of the bottom right point of the input box
+            - `y2`: the y coordinate of the bottom right point of the input box
+        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+            The model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
+            fed manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+            Image embeddings; these are used by the mask decoder to generate masks and iou scores. For more memory
+            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+        multimask_output (`bool`, *optional*):
+            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+            "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoModel, AutoProcessor
+
+        >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+        >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+
+        >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
+        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+        >>> # Get segmentation mask
+        >>> outputs = model(**inputs)

+        >>> # Postprocess masks (post_process_masks only needs the original sizes)
+        >>> masks = processor.post_process_masks(
+        ...     outputs.pred_masks, inputs["original_sizes"]
+        ... )
+        ```
+        """
+        if not ((pixel_values is None) ^ (image_embeddings is None)):
+            raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+        if input_points is not None and input_boxes is not None:
+            if input_points.shape[1] != input_boxes.shape[1]:
+                raise ValueError(
+                    f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+                )
+
+        image_positional_embeddings = self.get_image_wide_positional_embeddings()
+        # repeat with batch size
+        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+        vision_attentions = None
+        vision_hidden_states = None
+
+        if pixel_values is not None:
+            feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+                pixel_values,
+                **kwargs,
+            )
+
+            # add no memory embedding to the last feature map
+            feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+            # reshape feature maps to the same shape as the backbone feature sizes
+            image_embeddings = [
+                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+                for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+            ]
+
+        if input_points is not None and input_labels is None:
+            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+        if input_points is None and input_boxes is None:
+            # If no points are provided, pad with an empty point (with label -1)
+            input_points = torch.zeros(
+                batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+            )
+            input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+        if input_masks is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+                input_masks = F.interpolate(
+                    input_masks.float(),
+                    size=self.prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                ).to(input_masks.dtype)
+
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+        low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
+            image_embeddings=image_embeddings[-1],
+            image_positional_embeddings=image_positional_embeddings,
+            sparse_prompt_embeddings=sparse_embeddings,
+
dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + return Sam2ImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_multimasks, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ + list[torch.Tensor], + list[torch.Tensor], + Optional[tuple[torch.FloatTensor, ...]], + Optional[tuple[torch.FloatTensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`torch.FloatTensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. + - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder. + """ + vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder( + pixel_values, + **kwargs, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + +__all__ = ["Sam2Model", "Sam2VisionModel", "Sam2PreTrainedModel", "Sam2HieraDetModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/modular_sam2.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/modular_sam2.py new file mode 100644 index 0000000000000000000000000000000000000000..790ddc9439ec59b44c3168a98d816d9d190674ad --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/modular_sam2.py @@ -0,0 +1,1463 @@ +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SAM 2 model.""" + +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + pil_torch_interpolation_mapping, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + TensorType, + auto_docstring, + logging, +) +from ...utils.generic import TransformersKwargs, check_model_inputs +from ..auto import AutoModel +from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding +from ..sam.image_processing_sam_fast import SamImageProcessorFast +from ..sam.modeling_sam import ( + SamLayerNorm, + SamMaskDecoder, + SamMaskEmbedding, + SamModel, + SamPromptEncoder, + SamTwoWayAttentionBlock, + SamTwoWayTransformer, + eager_attention_forward, +) +from ..vitdet.modeling_vitdet import window_partition, window_unpartition +from .configuration_sam2 import ( + Sam2Config, + Sam2HieraDetConfig, + Sam2MaskDecoderConfig, + Sam2PromptEncoderConfig, + Sam2VisionConfig, +) + + +logger = logging.get_logger(__name__) + + +class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + mask_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to resize the segmentation maps to. + """ + + mask_size: Optional[dict[str, int]] + + +@auto_docstring +class Sam2ImageProcessorFast(SamImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 1024, "width": 1024} + mask_size = {"height": 256, "width": 256} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + valid_kwargs = Sam2FastImageProcessorKwargs + + # modular artefacts + do_pad = None + pad_size = None + mask_pad_size = None + + def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]): + BaseImageProcessorFast.__init__(self, **kwargs) + + def pad_image(self): + raise NotImplementedError("No pad_image for SAM 2.") + + def _get_preprocess_shape(self): + raise NotImplementedError("No _get_preprocess_shape for SAM 2.") + + def resize(self): + raise NotImplementedError("No need to override resize for SAM 2.") + + def _preprocess( + self, + images: list["torch.Tensor"], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> "torch.Tensor": + return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[Sam2FastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + original_sizes = [image.shape[-2:] for image in images] + images_kwargs = kwargs.copy() + pixel_values = self._preprocess(images, **images_kwargs) + reshaped_input_sizes = [image.shape[-2:] for image in images] + data = { + "pixel_values": pixel_values, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + "size": segmentation_maps_kwargs.pop("mask_size"), + } + ) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + mask_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if mask_size is not None: + mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["mask_size"] = mask_size + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + + # torch resize uses interpolation instead of resample + # Check if resample is an int before checking if it's an instance of PILImageResampling + # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. + # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + return kwargs + + def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor: + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. 
+ """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + **kwargs, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. + """ + if isinstance(original_sizes, (torch.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + # TODO: add connected components kernel for postprocessing + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = torch.from_numpy(masks[i]) + elif not isinstance(masks[i], torch.Tensor): + raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") + interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False) + if apply_non_overlapping_constraints: + interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + +@dataclass +@auto_docstring(custom_intro="Base class for the vision encoder's outputs.") +class Sam2VisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. 
+ fpn_position_encoding (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + fpn_hidden_states: Optional[torch.FloatTensor] = None + fpn_position_encoding: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring(custom_intro="Base class for the Sam2 model's output.") +class Sam2ImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`tuple(torch.FloatTensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + object_score_logits: Optional[torch.FloatTensor] = None + image_embeddings: tuple[torch.FloatTensor, ...] 
= None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class Sam2PatchEmbeddings(nn.Module): + r""" + Turns pixel values into patch embeddings for transformer consumption. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details. + + Returns: + embeddings (`torch.FloatTensor`): + Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding + """ + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__() + num_channels = config.num_channels + hidden_size = config.hidden_size + + self.projection = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=config.patch_kernel_size, + stride=config.patch_stride, + padding=config.patch_padding, + ) + + def forward(self, pixel_values): + _, num_channels, height, width = pixel_values.shape + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding): + pass + + +class Sam2VisionNeck(nn.Module): + def __init__(self, config: Sam2VisionConfig): + super().__init__() + self.config = config + + self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True) + self.convs = nn.ModuleList() + for in_channels in config.backbone_channel_list: + self.convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=config.fpn_hidden_size, + kernel_size=config.fpn_kernel_size, + stride=config.fpn_stride, + padding=config.fpn_padding, + ), + ) + self.fpn_top_down_levels = config.fpn_top_down_levels + + def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + fpn_hidden_states = () + fpn_position_encoding = () + + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + lateral_features = hidden_states[i].permute(0, 3, 1, 2) + lateral_features = self.convs[n - i](lateral_features) + if i not in self.fpn_top_down_levels or i == n: + prev_features = lateral_features + else: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode="nearest", + align_corners=None, + antialias=False, + ).to(lateral_features.dtype) + prev_features = lateral_features + top_down_features + + prev_position_encoding = self.position_encoding( + prev_features.shape, prev_features.device, prev_features.dtype + ).to(prev_features.dtype) + + fpn_hidden_states += (prev_features,) + fpn_position_encoding += (prev_position_encoding,) + + return fpn_hidden_states, fpn_position_encoding + + +def do_pool(x: torch.Tensor, query_stride: Optional[int] = None) -> torch.Tensor: + if query_stride is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + return x + + +class Sam2MultiScaleAttention(nn.Module): + def __init__( + self, + config: Sam2HieraDetConfig, + dim: int, + dim_out: int, + num_attention_heads: int, + query_stride: Optional[tuple[int, int]] = None, + ): + super().__init__() + + self.config = config + + self.dim = dim + self.dim_out 
= dim_out + self.query_stride = query_stride + + self.num_attention_heads = num_attention_heads + head_dim = dim_out // num_attention_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + self.is_causal = False + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + query, key, value = torch.unbind(qkv, 2) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + # Q pooling (for downsample at stage changes) + if self.query_stride: + query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride) + height, width = query.shape[1:3] # downsampled shape + query = query.reshape(batch_size, height * width, self.num_attention_heads, -1) + + # transpose query, key, value to (B, nHead, H * W, C) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + is_causal=self.is_causal, + scaling=self.scale, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + return attn_output + + +class Sam2FeedForward(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +class Sam2MultiScaleBlock(GradientCheckpointingLayer): + def __init__( + self, + config: Sam2HieraDetConfig, + stage_idx: int, + block_idx: int, + total_block_idx: int, + ): + super().__init__() + + # take embed dim from previous stage if first block of stage + self.dim = ( + config.embed_dim_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.embed_dim_per_stage[stage_idx] + ) + self.dim_out = config.embed_dim_per_stage[stage_idx] + self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) + # take window size from previous stage if first block of stage + self.window_size = ( + config.window_size_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.window_size_per_stage[stage_idx] + ) + self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size + # use query stride for first block of stage if stage is a query 
pool stage + self.query_stride = ( + config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None + ) + + self.attn = Sam2MultiScaleAttention( + config, + self.dim, + self.dim_out, + num_attention_heads=config.num_attention_heads_per_stage[stage_idx], + query_stride=self.query_stride, + ) + self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps) + self.mlp = Sam2FeedForward( + self.dim_out, + int(self.dim_out * config.mlp_ratio), + self.dim_out, + num_layers=2, + activation=config.hidden_act, + ) + if self.dim != self.dim_out: + self.proj = nn.Linear(self.dim, self.dim_out) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states # batch_size, height, width, channel + + hidden_states = self.layer_norm1(hidden_states) + + # Skip connection + if self.dim != self.dim_out: + residual = do_pool(self.proj(hidden_states), self.query_stride) + + # Window partition + window_size = self.window_size + if self.window_size > 0: + H, W = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, pad_hw = window_partition(hidden_states, window_size) + + # Window Attention + Q Pooling (if stage change) + attn_output = self.attn( + hidden_states=hidden_states, + **kwargs, + ) + hidden_states = attn_output + if self.query_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.query_stride[0] + H, W = residual.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Hiera model's outputs that also contains a pooling of the last hidden states. + """ +) +class Sam2HieraDetModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the intermediate layers of the model. 
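+        One hidden state is kept per stage (taken at the output of the last block of each stage), ordered from the
+        highest-resolution stage to the lowest.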
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@auto_docstring +class Sam2PreTrainedModel(PreTrainedModel): + config_class = Sam2Config + base_model_prefix = "sam2" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, Sam2HieraDetModel): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + if module.pos_embed_window is not None: + module.pos_embed_window.data.zero_() + if isinstance(module, Sam2Model): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +class Sam2HieraDetModel(Sam2PreTrainedModel): + config_class = Sam2HieraDetConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Sam2MultiScaleBlock, + "attentions": Sam2MultiScaleAttention, + } + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__(config) + + self.patch_embed = Sam2PatchEmbeddings(config) + # Windowed positional embedding (https://huggingface.co/papers/2311.05613) + self.pos_embed = nn.Parameter( + torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0]) + ) + self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist() + self.blocks = nn.ModuleList() + total_block_idx = 0 + for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage): + for block_idx in range(blocks_per_stage): + block = Sam2MultiScaleBlock( + config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx + ) + self.blocks.append(block) + total_block_idx += 1 + + def get_input_embeddings(self): + return self.patch_embed + + def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + @check_model_inputs() + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Sam2HieraDetModelOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3]) + + intermediate_hidden_states = () + for i, block_module in enumerate(self.blocks): + hidden_states = block_module(hidden_states, **kwargs) + + if i in self.stage_ends: + intermediate_hidden_states = intermediate_hidden_states + (hidden_states,) + + return Sam2HieraDetModelOutput( + last_hidden_state=hidden_states, + 
intermediate_hidden_states=intermediate_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The vision model from Sam without any head or projection on top. + """ +) +class Sam2VisionModel(Sam2PreTrainedModel): + config_class = Sam2VisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Sam2MultiScaleBlock, + "attentions": Sam2MultiScaleAttention, + } + + def __init__(self, config: Sam2VisionConfig): + super().__init__(config) + self.config = config + + self.backbone = AutoModel.from_config(config.backbone_config) + + self.neck = Sam2VisionNeck(config) + self.num_feature_levels = config.num_feature_levels + + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + @check_model_inputs() + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Sam2VisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values, **kwargs) + hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = backbone_output.intermediate_hidden_states + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return Sam2VisionEncoderOutput( + last_hidden_state=hidden_states, + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +class Sam2PositionalEmbedding(nn.Module): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.scale = config.scale + positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2)) + self.register_buffer("positional_embedding", positional_embedding) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(torch.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class Sam2MaskEmbedding(SamMaskEmbedding): + pass + + +class Sam2PromptEncoder(SamPromptEncoder): + def __init__(self, config: Sam2PromptEncoderConfig): + nn.Module.__init__(self) + self.shared_embedding = Sam2PositionalEmbedding(config) + self.mask_embed = Sam2MaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes += 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + +class Sam2Attention(nn.Module): + """ + SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
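+    For example, with `hidden_size=256` and `downsample_rate=2` (illustrative values), queries, keys and values are
+    projected down to an internal dimension of 128 for the attention computation, and the output projection maps the
+    result back to 256.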
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_similarity: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer): + def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False): + nn.Module.__init__(self) + self.self_attn = Sam2Attention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = Sam2Attention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = Sam2FeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = Sam2Attention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + +class Sam2TwoWayTransformer(SamTwoWayTransformer): + pass + + +class Sam2LayerNorm(SamLayerNorm): + pass + + +class Sam2MaskDecoder(SamMaskDecoder): + def __init__(self, config: Sam2MaskDecoderConfig): + super().__init__(config) + del self.iou_prediction_head + self.iou_prediction_head = Sam2FeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = 
config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) + ) + best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + high_resolution_features: list[torch.Tensor], + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. 
+ high_resolution_features (`list[torch.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`torch.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`torch.Tensor`, *optional*): + The target embedding. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + +@auto_docstring( + custom_intro=""" + Segment 
Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and + input points and labels, boxes, or masks. + """ +) +class Sam2Model(SamModel): + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def __init__(self, config: Sam2Config): + PreTrainedModel.__init__(self, config) + self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + + self.post_init() + + def get_image_wide_positional_embeddings(self) -> torch.Tensor: + size = self.prompt_encoder.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones(size, device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> list[torch.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ + list[torch.Tensor], + list[torch.Tensor], + Optional[tuple[torch.FloatTensor, ...]], + Optional[tuple[torch.FloatTensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`torch.FloatTensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. 
+            - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+            - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+            - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+        """
+        vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(
+            pixel_values,
+            **kwargs,
+        )
+
+        feature_maps = vision_outputs.fpn_hidden_states
+        feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+        # precompute projected level 0 and level 1 features in SAM decoder
+        # to avoid running it again on every SAM click
+        feature_maps = list(feature_maps)
+        feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+        feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+        # flatten NxCxHxW to HWxNxC
+        feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+        feature_maps_position_embeddings = [
+            feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+            for feature_map_position_embedding in feature_maps_position_embeddings
+        ]
+
+        return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+    @check_model_inputs()
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Sam2ImageSegmentationOutput:
+        r"""
+        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+            Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
+            better results. The points can be obtained by passing a list of list of list to the processor that will
+            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+            multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
+            coordinates of the point. If a different number of points is passed either for each image, or for each
+            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+            computation of the embedding will be skipped for these points using the labels.
+        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+            Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
+            official implementation, there are 3 types of labels:
+
+            - `1`: the point is a point that contains the object of interest
+            - `0`: the point is a point that does not contain the object of interest
+            - `-1`: the point corresponds to the background
+
+            We added the label:
+
+            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+            The padding labels are added automatically by the processor.
+        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+            Input boxes for the points; these are used by the prompt encoder to encode the prompt. Generally yields
+            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+            In the order (`x1`, `y1`, `x2`, `y2`):
+
+            - `x1`: the x coordinate of the top left point of the input box
+            - `y1`: the y coordinate of the top left point of the input box
+            - `x2`: the x coordinate of the bottom right point of the input box
+            - `y2`: the y coordinate of the bottom right point of the input box
+        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+            The SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+            generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be
+            fed manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+            Image embeddings; these are used by the mask decoder to generate masks and iou scores. For more memory
+            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+        multimask_output (`bool`, *optional*):
+            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+            bounding box if relevant). However, it is possible to output just a single mask, corresponding to the
+            "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoModel, AutoProcessor
+
+        >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+        >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+
+        >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+        >>> input_points = [[[[400, 650]]]]  # 2D location of a window on the car (image, object, point, coordinates)
+        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+        >>> # Get segmentation mask
+        >>> outputs = model(**inputs)
+
+        >>> # Postprocess masks
+        >>> masks = processor.post_process_masks(
+        ...     outputs.pred_masks, inputs["original_sizes"]
+        ...
) + ``` + """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + **kwargs, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = torch.zeros( + batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device + ) + input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + return Sam2ImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_multimasks, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + +__all__ = [ + "Sam2Model", + "Sam2VisionModel", + "Sam2PreTrainedModel", + "Sam2ImageProcessorFast", + "Sam2HieraDetModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sam2/processing_sam2.py b/venv/lib/python3.13/site-packages/transformers/models/sam2/processing_sam2.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f147aab8dfadaeaa6790246e2275a76b0dd67be --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/sam2/processing_sam2.py @@ -0,0 +1,526 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for SAM2. +""" + +from copy import deepcopy +from typing import Optional, Union + +import numpy as np + +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, is_torch_available, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + + +@requires(backends=("torch",)) +class Sam2Processor(ProcessorMixin): + r""" + Constructs a SAM2 processor which wraps a SAM2 image processor and an 2D points & Bounding boxes processor into a + single processor. + + [`Sam2Processor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoProcessor`]. See the docstring of + [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoProcessor.__call__`] for more information. + + Args: + image_processor (`Sam2ImageProcessorFast`): + An instance of [`Sam2ImageProcessorFast`]. + target_size (`int`, *optional*): + The target size (target_size, target_size) to which the image will be resized. + point_pad_value (`int`, *optional*, defaults to -10): + The value used for padding input points. + """ + + attributes = ["image_processor"] + image_processor_class = "Sam2ImageProcessorFast" + + def __init__(self, image_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs): + super().__init__(image_processor, **kwargs) + self.point_pad_value = point_pad_value + self.target_size = target_size if target_size is not None else self.image_processor.size["height"] + + def __call__( + self, + images: Optional[ImageInput] = None, + segmentation_maps: Optional[ImageInput] = None, + input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None, + input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None, + input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None, + original_sizes: Optional[Union[list[list[float]], torch.Tensor]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchEncoding: + r""" + This method uses [`Sam2ImageProcessorFast.__call__`] method to prepare image(s) for the model. It also prepares 2D + points and bounding boxes for the model if they are provided. + + Args: + images (`ImageInput`, *optional*): + The image(s) to process. + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to process. + input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*): + The points to add to the frame. + input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*): + The labels for the points. 
+ input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*): + The bounding boxes to add to the frame. + original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*): + The original sizes of the images. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. + **kwargs: + Additional keyword arguments to pass to the image processor. + + Returns: + A [`BatchEncoding`] with the following fields: + - `pixel_values` (`torch.Tensor`): The processed image(s). + - `original_sizes` (`list[list[float]]`): The original sizes of the images. + - `reshaped_input_sizes` (`torch.Tensor`): The reshaped input sizes of the images. + - `labels` (`torch.Tensor`): The processed segmentation maps (if provided). + - `input_points` (`torch.Tensor`): The processed points. + - `input_labels` (`torch.Tensor`): The processed labels. + - `input_boxes` (`torch.Tensor`): The processed bounding boxes. + """ + if images is not None: + encoding_image_processor = self.image_processor( + images, + segmentation_maps=segmentation_maps, + return_tensors=return_tensors, + **kwargs, + ) + elif original_sizes is not None: + if isinstance(original_sizes, torch.Tensor): + original_sizes = original_sizes.cpu().tolist() + encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors) + else: + raise ValueError("Either images or original_sizes must be provided") + + # pop arguments that are not used in the forward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + # Check original_sizes is of length 1 or len(images) + if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images): + raise ValueError( + "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size." + ) + + # Process input points, labels, and boxes if provided + if input_points is not None or input_labels is not None or input_boxes is not None: + # Validate and convert inputs to standardized format + processed_points = self._validate_single_input( + input_points, + expected_depth=4, + input_name="points", + expected_format="[image level, object level, point level, point coordinates]", + expected_coord_size=2, + ) + processed_labels = self._validate_single_input( + input_labels, + expected_depth=3, + input_name="labels", + expected_format="[image level, object level, point level]", + ) + processed_boxes = self._validate_single_input( + input_boxes, + expected_depth=3, + input_name="boxes", + expected_format="[image level, box level, box coordinates]", + expected_coord_size=4, + ) + + # Get padding requirements for all inputs + if processed_points is not None: + points_max_dims = self._get_nested_dimensions(processed_points)[:3] + if processed_labels is not None: + labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] + if processed_boxes is not None: + boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] + + # Ensure points and labels have consistent dimensions + if processed_points is not None and processed_labels is not None: + if points_max_dims != labels_max_dims: + raise ValueError( + "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions." 
+ ) + + # Check that boxes don't need padding (model limitation) + if processed_boxes is not None and len(processed_boxes) >= 2: + if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes): + raise ValueError( + "Input boxes have inconsistent dimensions that would require padding, " + "but boxes cannot be padded due to model limitations. " + "Please ensure all images have the same number of boxes." + ) + + # Pad and normalize all inputs to final tensor format + if processed_points is not None: + padded_points = self._pad_nested_list(processed_points, points_max_dims + [2]) + final_points = torch.tensor(padded_points, dtype=torch.float32) + self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True) + encoding_image_processor.update({"input_points": final_points}) + + if processed_labels is not None: + padded_labels = self._pad_nested_list(processed_labels, labels_max_dims) + final_labels = torch.tensor(padded_labels, dtype=torch.int64) + encoding_image_processor.update({"input_labels": final_labels}) + + if processed_boxes is not None: + final_boxes = torch.tensor(processed_boxes, dtype=torch.float32) + self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True) + encoding_image_processor.update({"input_boxes": final_boxes}) + + return encoding_image_processor + + def _normalize_coordinates( + self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False + ) -> "torch.Tensor": + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + + Args: + target_size (`int`): + The target size of the image. + coords (`torch.Tensor`): + The coordinates to be normalized. + original_size (`tuple`): + The original size of the image. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether the coordinates are bounding boxes. + """ + old_h, old_w = original_size + new_h, new_w = target_size, target_size + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _convert_to_nested_list(self, data, expected_depth, current_depth=0): + """ + Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists. 
+ + Args: + data: Input data in any format + expected_depth: Expected nesting depth + current_depth: Current depth in recursion + + Returns: + Nested list representation of the data + """ + if data is None: + return None + + # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array + if isinstance(data, torch.Tensor): # PyTorch tensor + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor + return data.numpy().tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, np.ndarray): # NumPy array + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array + return data.tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, list): + if current_depth == expected_depth: + # We've reached the expected depth, return as is + return data + else: + # Continue recursion + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, (int, float)): + return data + else: + raise ValueError(f"Unsupported data type: {type(data)}") + + def _get_nested_dimensions(self, nested_list, max_dims=None): + """ + Get the maximum dimensions at each level of nesting. + + Args: + nested_list (`list`): + Nested list structure. + max_dims (`list`, *optional*): + Current maximum dimensions (for recursion). + + Returns: + `list`: A list of maximum dimensions for each nesting level. + """ + if max_dims is None: + max_dims = [] + + if not isinstance(nested_list, list): + return max_dims + + if len(max_dims) == 0: + max_dims.append(len(nested_list)) + else: + max_dims[0] = max(max_dims[0], len(nested_list)) + + if len(nested_list) > 0: + for item in nested_list: + if isinstance(item, list): + sub_dims = self._get_nested_dimensions(item) + # Merge sub_dims into max_dims + for i, dim in enumerate(sub_dims): + if i + 1 >= len(max_dims): + max_dims.append(dim) + else: + max_dims[i + 1] = max(max_dims[i + 1], dim) + + return max_dims + + def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None): + """ + Recursively pad a nested list to match target dimensions. + + Args: + nested_list (`list`): + Nested list to pad. + target_dims (`list`): + Target dimensions for each level. + current_level (`int`, *optional*, defaults to 0): + Current nesting level. + pad_value (`int`, *optional*): + Value to use for padding. + + Returns: + `list`: The padded nested list. 
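+
+        For example, `_pad_nested_list([[1, 2]], [2, 3])` with the default pad value (-10) returns
+        `[[1, 2, -10], [-10, -10, -10]]`.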
+ """ + if pad_value is None: + pad_value = self.point_pad_value + + if current_level >= len(target_dims): + return nested_list + + # Ensure we have a list + if not isinstance(nested_list, list): + nested_list = [nested_list] + + # Pad current level + current_size = len(nested_list) + target_size = target_dims[current_level] + + # Pad with appropriate values + if current_level == len(target_dims) - 1: + # At the coordinate level, pad with pad_value + nested_list.extend([pad_value] * (target_size - current_size)) + else: + # At higher levels, pad with nested structures + if current_size > 0: + # Create appropriately sized template + if current_level < len(target_dims) - 2: + # For non-coordinate levels, create empty nested structure + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + else: + # For coordinate level, create list of pad_values + template = [pad_value] * target_dims[current_level + 1] + + nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)]) + else: + # Create from scratch + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + nested_list.extend([deepcopy(template) for _ in range(target_size)]) + + # Recursively pad sublists + if current_level < len(target_dims) - 1: + for i in range(len(nested_list)): + if isinstance(nested_list[i], list): + nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value) + + return nested_list + + def _create_empty_nested_structure(self, dims, pad_value): + """ + Create an empty nested structure with given dimensions filled with pad_value. + + Args: + dims (`list`): + The dimensions of the nested structure. + pad_value (`int`): + The value to fill the structure with. + """ + if len(dims) == 1: + return [pad_value] * dims[0] + else: + return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])] + + def _get_nesting_level(self, input_list): + """ + Get the nesting level of a list structure. + + Args: + input_list (`list`): + The list to get the nesting level of. + """ + if isinstance(input_list, list): + if len(input_list) == 0: + return 1 + return 1 + self._get_nesting_level(input_list[0]) + elif isinstance(input_list, (np.ndarray, torch.Tensor)): + # For arrays/tensors, the nesting level is the number of dimensions + return len(input_list.shape) + return 0 + + def _validate_single_input( + self, + data: Union[torch.Tensor, np.ndarray, list], + expected_depth: int, + input_name: str, + expected_format: str, + expected_coord_size: Optional[int] = None, + ) -> list: + """ + Validate a single input by ensuring proper nesting and raising an error if the input is not valid. + + Args: + data (`torch.Tensor`, `np.ndarray`, or `list`): + Input data to process. + expected_depth (`int`): + Expected nesting depth. + input_name (`str`): + Name of the input for error messages. + expected_format (`str`): + The expected format of the input. + expected_coord_size (`int`, *optional*): + Expected coordinate size (2 for points, 4 for boxes, None for labels). + . + """ + if data is None: + return None + + # Handle tensors and numpy arrays first + if isinstance(data, (torch.Tensor, np.ndarray)): + # For tensors/arrays, we can directly check the number of dimensions + if data.ndim != expected_depth: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. 
The expected nesting format is {expected_format}. Got {data.ndim} dimensions." + ) + elif expected_coord_size is not None: + if data.shape[-1] != expected_coord_size: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}." + ) + return self._convert_to_nested_list(data, expected_depth) + + # Handle nested lists + if isinstance(data, list): + current_depth = self._get_nesting_level(data) + if current_depth != expected_depth: + raise ValueError( + f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels." + ) + return self._convert_to_nested_list(data, expected_depth) + + def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False): + """ + Helper method to normalize coordinates in a tensor across multiple images. + + Args: + tensor (`torch.Tensor`): + Input tensor with coordinates. + original_sizes (`list`): + Original image sizes. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether coordinates are bounding boxes. + preserve_padding (`bool`, *optional*, defaults to `False`): + Whether to preserve padding values (for points). + """ + if preserve_padding: + # For points: avoid normalizing pad values + mask = tensor != self.point_pad_value + coord_mask = mask.all(dim=-1, keepdim=True) + + for img_idx in range(len(original_sizes)): + if img_idx < tensor.shape[0]: + original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0] + normalized_coords = self._normalize_coordinates( + self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box + ) + + if preserve_padding: + # Only update non-padded values + img_mask = coord_mask[img_idx] + tensor[img_idx] = torch.where( + img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx] + ) + else: + tensor[img_idx] = normalized_coords + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + **kwargs, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
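+
+        A usage sketch (illustrative only; `model`, `inputs` and the `pred_masks` output field are assumptions about
+        the surrounding SAM-2 style pipeline, not guaranteed API):
+
+        ```python
+        >>> outputs = model(**inputs)  # hypothetical forward pass producing low-resolution mask logits
+        >>> masks = processor.post_process_masks(outputs.pred_masks, original_sizes=[(1080, 1920)])
+        ```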
+ """ + return self.image_processor.post_process_masks( + masks, + original_sizes, + mask_threshold, + binarize, + max_hole_area, + max_sprinkle_area, + apply_non_overlapping_constraints, + **kwargs, + ) + + +__all__ = ["Sam2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d289de0474acfadeeee08019dd57f40f9a1a4b4 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_seamless_m4t import * + from .feature_extraction_seamless_m4t import * + from .modeling_seamless_m4t import * + from .processing_seamless_m4t import * + from .tokenization_seamless_m4t import * + from .tokenization_seamless_m4t_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/configuration_seamless_m4t.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/configuration_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4f911957fcb83f91113da96aeeb9280a0638fd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/configuration_seamless_m4t.py @@ -0,0 +1,416 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SeamlessM4T model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SeamlessM4TConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate an + SeamlessM4T model according to the specified arguments, defining the model architecture. 
Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the SeamlessM4T
+    ["facebook/hf-seamless-m4t-medium"](https://huggingface.co/facebook/hf-seamless-m4t-medium) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 256102):
+            Vocabulary size of the SeamlessM4T model. Defines the number of different tokens that can be represented by
+            the `input_ids` passed when calling [`~SeamlessM4TModel`], [`~SeamlessM4TForTextToSpeech`] or
+            [`~SeamlessM4TForTextToText`].
+        t2u_vocab_size (`int`, *optional*, defaults to 10082):
+            Unit vocabulary size of the SeamlessM4T model. Defines the number of different unit tokens that can be
+            represented by the `input_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4TModel`],
+            [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
+
+        > Parameters shared across sub-models
+
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the "intermediate" layers in the architecture.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model's text encoder and decoder might ever be used with. Typically set
+            this to something large just in case (e.g., 512 or 1024 or 2048).
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the encoders. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the decoders. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+            for more details.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
+            `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all attention layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all activation layers in the model.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Scale embeddings by dividing by sqrt(d_model).
+
+        > Text encoder and text decoder specific parameters
+
+        encoder_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer text encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
+ encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text encoder. + decoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer text decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text decoder. + decoder_start_token_id (`int`, *optional*, defaults to 3): + If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only + applied in the text decoder. + max_new_tokens (`int`, *optional*, defaults to 256): + The maximum numbers of text tokens to generate, ignoring the number of tokens in the prompt. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the _padding_ text token. Only applied to the text-decoder model. + bos_token_id (`int`, *optional*, defaults to 2): + The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model. + eos_token_id (`int`, *optional*, defaults to 3): + The id of the _end-of-stream_ text token. Only applied to the text-decoder model. + + > Speech encoder specific parameters + + speech_encoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer speech encoder. + speech_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer speech encoder. + speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder. + speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`, + `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported. + speech_encoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all layers in the speech encoder. + add_adapter (`bool`, *optional*, defaults to `True`): + Add an adapter layer on top of the speech encoder. + speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the speech encoder. See the [LayerDrop paper](see + https://huggingface.co/papers/1909.11556) for more details. + feature_projection_input_dim (`int`, *optional*, defaults to 160): + Input dimension of the input feature projection of the speech encoder, i.e the dimension after processing + input audios with [`SeamlessM4TFeatureExtractor`]. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer of the speech encoder. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer of the speech encoder. + adaptor_kernel_size (`int`, *optional*, defaults to 8): + Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_stride (`int`, *optional*, defaults to 8): + Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all layers in the speech adapter. 
+ num_adapter_layers (`int`, *optional*, defaults to 1): + Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is + True`. + position_embeddings_type (`str`, *optional*, defaults to `"relative"`): + Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left + `None` no relative position embedding is applied. Only applied to the speech encoder. + rotary_embedding_base (`int`, *optional*, defaults to 10000): + If `"rotary"` position embeddings are used, defines the size of the embedding base. Only applied to the + speech encoder. + max_source_positions (`int`, *optional*, defaults to 4096): + if `"relative"` position embeddings are used, defines the maximum source input positions. Only applied to + the speech encoder. + conv_depthwise_kernel_size (`int`, *optional*, defaults to 31): + Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder. + + > Text-To-Unit (t2u) model specific parameters + + t2u_bos_token_id (`int`, *optional*, defaults to 0): + The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_pad_token_id (`int`, *optional*, defaults to 1): + The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_eos_token_id (`int`, *optional*, defaults to 2): + The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_decoder_start_token_id (`int`, *optional*, defaults to 2): + If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only + applied to the text-to-unit seq2seq model. + t2u_max_new_tokens (`int`, *optional*, defaults to 1024): + The maximum numbers of unit tokens to generate, ignoring the number of tokens in the prompt. Only applied + to the text-to-unit seq2seq model. + t2u_encoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit encoder. + t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder. + t2u_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit encoder. + t2u_decoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit decoder. + t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder. + t2u_decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit decoder. + t2u_max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model text-to-unit component might ever be used with. Typically set + this to something large just in case (e.g., 512 or 1024 or 2048). + + > Hifi-Gan Vocoder specific parameters + + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the output audio will be generated, expressed in hertz (Hz). + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only. 
+ upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`): + A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network. + The length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. Applies to the vocoder only. + upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling + network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match + the length of *upsample_rates*. Applies to the vocoder only. + resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive + field fusion (MRF) module. Applies to the vocoder only. + resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in + the multi-receptive field fusion (MRF) module. Applies to the vocoder only. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder + only. + unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000): + Vocabulary size of the SeamlessM4T vocoder. Defines the number of different unit tokens that can be + represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4TModel`], + [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`]. + unit_embed_dim (`int`, *optional*, defaults to 1280): + The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only. + lang_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only. + spkr_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only. + vocoder_num_langs (`int`, *optional*, defaults to 36): + Number of langs supported by the vocoder. Might be different from `t2u_num_langs`. + vocoder_num_spkrs (`int`, *optional*, defaults to 200): + Number of speakers supported by the vocoder. + variance_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the duration predictor. Applies to the vocoder only. + var_pred_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability of the duration predictor. Applies to the vocoder only. + vocoder_offset (`int`, *optional*, defaults to 4): + Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only. 
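+
+    A smaller configuration can be obtained by overriding individual arguments (an illustrative sketch; the reduced
+    values below are arbitrary and do not correspond to any released checkpoint):
+
+    ```python
+    >>> from transformers import SeamlessM4TConfig
+
+    >>> tiny_config = SeamlessM4TConfig(encoder_layers=2, decoder_layers=2, speech_encoder_layers=2)
+    >>> tiny_config.encoder_layers
+    2
+    ```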
+ + ```python + >>> from transformers import SeamlessM4TModel, SeamlessM4TConfig + + >>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration + >>> configuration = SeamlessM4TConfig() + + >>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration + >>> model = SeamlessM4TModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "seamless_m4t" + + def __init__( + self, + vocab_size=256102, + t2u_vocab_size=10082, + # shared config + hidden_size=1024, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + max_position_embeddings=1024, + is_encoder_decoder=True, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + activation_function="relu", + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + scale_embedding=True, + # text encoder|decoder + encoder_layers=24, + encoder_ffn_dim=8192, + encoder_attention_heads=16, + decoder_layers=24, + decoder_ffn_dim=8192, + decoder_attention_heads=16, + decoder_start_token_id=3, + max_new_tokens=256, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + # speech_encoder + speech_encoder_layers=24, + speech_encoder_attention_heads=16, + speech_encoder_intermediate_size=4096, + speech_encoder_hidden_act="swish", + speech_encoder_dropout=0.0, + add_adapter=True, + speech_encoder_layerdrop=0.1, + feature_projection_input_dim=160, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + adaptor_kernel_size=8, + adaptor_stride=8, + adaptor_dropout=0.1, + num_adapter_layers=1, + position_embeddings_type="relative", + rotary_embedding_base=10000, + max_source_positions=4096, + conv_depthwise_kernel_size=31, + # t2u config + t2u_bos_token_id=0, + t2u_pad_token_id=1, + t2u_eos_token_id=2, + t2u_decoder_start_token_id=2, + t2u_max_new_tokens=1024, + t2u_encoder_layers=6, + t2u_encoder_ffn_dim=8192, + t2u_encoder_attention_heads=16, + t2u_decoder_layers=6, + t2u_decoder_ffn_dim=8192, + t2u_decoder_attention_heads=16, + t2u_max_position_embeddings=2048, + # hifi-gan vocoder config + sampling_rate=16000, + upsample_initial_channel=512, + upsample_rates=[5, 4, 4, 2, 2], + upsample_kernel_sizes=[11, 8, 8, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_slope=0.1, + # specific to Code Hifi-Gan + unit_hifi_gan_vocab_size=10000, + unit_embed_dim=1280, + lang_embed_dim=256, + spkr_embed_dim=256, + vocoder_num_langs=36, + vocoder_num_spkrs=200, + variance_predictor_kernel_size=3, + var_pred_dropout=0.5, + vocoder_offset=4, + **kwargs, + ): + # overall_config + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.max_new_tokens = max_new_tokens + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.activation_function = activation_function + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.scale_embedding = scale_embedding + # for proper config init + self.num_attention_heads = decoder_attention_heads + self.num_hidden_layers = decoder_layers + + # text|unit encoder|decoder + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads 
+ self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + + # speech_encoder + self.speech_encoder_layers = speech_encoder_layers + self.speech_encoder_hidden_act = speech_encoder_hidden_act + self.speech_encoder_dropout = speech_encoder_dropout + self.speech_encoder_attention_heads = speech_encoder_attention_heads + self.speech_encoder_layerdrop = speech_encoder_layerdrop + self.speech_encoder_intermediate_size = speech_encoder_intermediate_size + self.feature_projection_input_dim = feature_projection_input_dim + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.adaptor_kernel_size = adaptor_kernel_size + self.adaptor_stride = adaptor_stride + self.adaptor_dropout = adaptor_dropout + self.num_adapter_layers = num_adapter_layers + self.position_embeddings_type = position_embeddings_type + self.rotary_embedding_base = rotary_embedding_base + self.max_source_positions = max_source_positions + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + self.add_adapter = add_adapter + + # t2u config + self.t2u_bos_token_id = t2u_bos_token_id + self.t2u_pad_token_id = t2u_pad_token_id + self.t2u_eos_token_id = t2u_eos_token_id + self.t2u_decoder_start_token_id = t2u_decoder_start_token_id + self.t2u_max_new_tokens = t2u_max_new_tokens + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_encoder_attention_heads = t2u_encoder_attention_heads + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + self.t2u_decoder_attention_heads = t2u_decoder_attention_heads + self.t2u_max_position_embeddings = t2u_max_position_embeddings + + # hifi-gan vocoder config + # original parameters specific to Hifi-Gan + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.leaky_relu_slope = leaky_relu_slope + + # specific to Code Hifi-Gan + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.unit_embed_dim = unit_embed_dim + self.lang_embed_dim = lang_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.vocoder_num_langs = vocoder_num_langs + self.vocoder_num_spkrs = vocoder_num_spkrs + self.variance_predictor_kernel_size = variance_predictor_kernel_size + self.var_pred_dropout = var_pred_dropout + self.vocoder_offset = vocoder_offset + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + is_encoder_decoder=is_encoder_decoder, + max_position_embeddings=max_position_embeddings, + **kwargs, + ) + + +__all__ = ["SeamlessM4TConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..59a20059f869ecca947d0999d3239573fbf10194 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for SeamlessM4T +""" + +from typing import Optional, Union + +import numpy as np + +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a SeamlessM4T feature extractor. + + This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + num_mel_bins (`int`, *optional*, defaults to 80): + Number of Mel-frequency bins. + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding vectors. + stride (`int`, *optional*, defaults to 2): + Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to + (batch_size,num_frames//stride,num_mel_bins*stride). 
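+
+    A usage sketch (illustrative only; the random waveform merely stands in for one second of real mono 16 kHz audio,
+    and the frame count shown assumes the default `stride=2` and `num_mel_bins=80`):
+
+    ```python
+    >>> import numpy as np
+    >>> feature_extractor = SeamlessM4TFeatureExtractor()
+    >>> waveform = np.random.randn(16000).astype(np.float32)
+    >>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")
+    >>> inputs["input_features"].shape  # (batch_size, num_frames // stride, num_mel_bins * stride)
+    (1, 49, 160)
+    ```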
+ """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + num_mel_bins=80, + padding_value=0.0, + stride=2, + **kwargs, + ): + self.num_mel_bins = num_mel_bins + self.return_attention_mask = True + self.stride = stride + + mel_filters = mel_filter_bank( + num_frequency_bins=257, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = mel_filters + self.window = window_function(400, "povey", periodic=False) + + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + + @staticmethod + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 + ) -> list[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _extract_fbank_features( + self, + waveform: np.ndarray, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + # by default, it extracts the left channel if stereo + if len(waveform.shape) == 2: + waveform = waveform[0] + + waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + do_normalize_per_mel_bins: Optional[bool] = True, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `torch.Tensor`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`, + `list[list[float]]`, `list[list[list[float]]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, + a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors, + a list of list of float values or a list of a list of list of float values. 
+ If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `list[float]`, `raw_speech` is + considered a single-channel, single-sample sound. In all other cases, the first dimension of + `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `list[...]`, + corresponds to the number of samples in the batch, and the number of channels + (i.e. mono or stereo character) is derived from the other dimensions + (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*, defaults to 2): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + + + For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle + bugs. + + + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`): + Whether or not to zero-mean unit-variance normalize the input per mel-channel. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature + extractor. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." 
+ ) + + return_attention_mask = ( + return_attention_mask if return_attention_mask is not None else self.return_attention_mask + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 3: + raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") + + acceptable_types = ( + (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list) + ) + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types)) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_fbank_features(waveform) for waveform in raw_speech] + + if do_normalize_per_mel_bins: + # torch defaults to ddof=1, and numpy defaults to ddof=0 + features = [ + (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7) + for x in features + ] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=True, + return_tensors="np", + ) + + # SeamlessM4T needs to process extracted features + input_features = padded_inputs.get("input_features") + attention_mask = padded_inputs.pop("attention_mask") + + batch_size, num_frames, num_channels = input_features.shape + + remainder = num_frames % self.stride + if remainder != 0: + input_features = input_features[:, : num_frames - remainder, :] + attention_mask = attention_mask[:, : num_frames - remainder] + + input_features = np.reshape( + input_features, (batch_size, num_frames // self.stride, num_channels * self.stride) + ) + + indices = np.arange(0, num_frames - remainder) + attention_mask = attention_mask[:, indices % self.stride == 1] + + padded_inputs["input_features"] = input_features + if return_attention_mask: + padded_inputs["attention_mask"] = attention_mask + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["SeamlessM4TFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/modeling_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe1d2fb714dae73684d3a4a5f77395ec1837144 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -0,0 +1,4070 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SeamlessM4T model.""" + +import copy +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Wav2Vec2BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_seamless_m4t import SeamlessM4TConfig + + +logger = logging.get_logger(__name__) + +SEAMLESS_M4T_COMMON_CUSTOM_ARGS = r""" + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
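+
+    As an illustrative sketch (the `processor` variable, the `audios` keyword and the waveform are assumptions used
+    only for illustration), `input_features` are normally produced by the feature extractor or processor rather than
+    constructed by hand:
+
+    ```python
+    >>> inputs = processor(audios=waveform, sampling_rate=16000, return_tensors="pt")
+    >>> features = inputs["input_features"]  # shape (batch_size, sequence_length, num_banks)
+    ```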
+""" + + +@dataclass +@auto_docstring( + custom_intro=""" + Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`], + [`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForTextToSpeech`]. + """ +) +class SeamlessM4TGenerationOutput(ModelOutput): + r""" + waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + The final audio waveform predicted by the model. + waveform_lengths (`torch.IntTensor` of shape `(batch_size,)`, *optional*): + The length in samples of each element in the `waveform` batch. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + The generated translated sequences. This is the output of the text-to-text or the speech-to-text models. + The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished + early due to the `eos_token_id`. + unit_sequences (`torch.LongTensor` of shape `(batch_size, unit_sequence_length)`, *optional*): + The generated translated unit sequences. This is the output of the text-to-units model. The second + dimension (unit_sequence_length) is either equal to `t2u_max_length` or shorter if all batches finished + early due to the `t2u_eos_token_id`. + """ + + waveform: Optional[torch.FloatTensor] = None + waveform_lengths: Optional[torch.IntTensor] = None + sequences: Optional[tuple[torch.FloatTensor]] = None + unit_sequences: Optional[tuple[torch.FloatTensor]] = None + + +############ UTILS ################ + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. 
+ seq_lens (`torch.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + + Returns: + `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + +def format_speech_generation_kwargs(kwargs): + """ + Format kwargs for SeamlessM4T models that generate speech, attribute kwargs to either the text generation or the + speech generation models. + + Args: + kwargs (`dict`)`: + Keyword arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + """ + # attribute kwargs to models + kwargs_text = {} + kwargs_speech = {} + for key, value in kwargs.items(): + if key.startswith("text_"): + key = key[len("text_") :] + kwargs_text[key] = value + elif key.startswith("speech_"): + key = key[len("speech_") :] + kwargs_speech[key] = value + elif key == "generation_config": + kwargs_text[key] = value + else: + # If the key is already in a specific config, then it's been set with a + # submodules specific value and we don't override + if key not in kwargs_text: + kwargs_text[key] = value + if key not in kwargs_speech: + kwargs_speech[key] = value + return kwargs_text, kwargs_speech + + +############ SPEECH ENCODER related code ################ + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SeamlessM4TConformer, feat_extract_activation->speech_encoder_hidden_act +class SeamlessM4TConformerPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SeamlessM4TConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + 
hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2->SeamlessM4T, num_attention_heads->speech_encoder_attention_heads +class SeamlessM4TConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://huggingface.co/papers/2104.09864 + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.speech_encoder_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) + return self.cached_rotary_positional_embedding + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2->SeamlessM4T +class SeamlessM4TConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. 
We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (iSeamlessM4T +class SeamlessM4TConformerSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class SeamlessM4TConformerFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerFeedForward(nn.Module): + def __init__(self, config, act_fn=None, dropout=None): + super().__init__() + dropout = dropout if dropout is not None else config.speech_encoder_dropout + act_fn = act_fn if act_fn is not None else config.speech_encoder_hidden_act + + self.intermediate_dropout = nn.Dropout(dropout) + self.intermediate_dense = nn.Linear(config.hidden_size, config.speech_encoder_intermediate_size) + self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn + + self.output_dense = nn.Linear(config.speech_encoder_intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding="same", + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + self.pointwise_conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def forward(self, hidden_states, attention_mask=None): + hidden_states = self.layer_norm(hidden_states) + + # Ensure that we do not leak padded positions in depthwise convolution. 
+ # Put 0 where necessary + if attention_mask is not None: + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SeamlessM4TConformerSelfAttention(nn.Module): + """Construct a SeamlessM4TConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. + """ + + def __init__(self, config, use_position_embeddings=True): + super().__init__() + + self.head_size = config.hidden_size // config.speech_encoder_attention_heads + self.num_heads = config.speech_encoder_attention_heads + self.position_embeddings_type = config.position_embeddings_type if use_position_embeddings else None + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.speech_encoder_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable biases are used in matrix c and matrix d + # as described in https://huggingface.co/papers/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'`" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + 
if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'`" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://huggingface.co/papers/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://huggingface.co/papers/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. 
shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class SeamlessM4TConformerEncoderLayer(GradientCheckpointingLayer): + """Conformer block based on https://huggingface.co/papers/2005.08100.""" + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.speech_encoder_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = SeamlessM4TConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = SeamlessM4TConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = SeamlessM4TConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = SeamlessM4TConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weights + + +class SeamlessM4TConformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = SeamlessM4TConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = SeamlessM4TConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.dropout = nn.Dropout(config.speech_encoder_dropout) + self.layers = nn.ModuleList( + [SeamlessM4TConformerEncoderLayer(config) for _ in range(config.speech_encoder_layers)] + ) + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = self.training and dropout_probability < self.config.speech_encoder_layerdrop + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SeamlessM4TConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.adaptor_dropout + + self.kernel_size = config.adaptor_kernel_size + 
self.stride = config.adaptor_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = SeamlessM4TConformerSelfAttention(config, use_position_embeddings=False) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim) + self.ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=dropout) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + pad = self.kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + + return seq_lens.floor() + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + hidden_states.device + ) + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _prepare_4d_attention_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. 
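+        # For intuition, a worked example of the sub-sampling above, assuming
+        # `adaptor_kernel_size=8` and `adaptor_stride=8` (assumed values; check your config):
+        # pad = 8 // 2 = 4, so an unpadded length of 100 maps to
+        # floor(((100 + 2 * 4 - 8) / 8) + 1) = 13 positions, matching the pooled
+        # sequence length of `hidden_states` at this point.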
+ hidden_states, attn_weights = self.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + residual + + return hidden_states + + +class SeamlessM4TConformerAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + self.layers = nn.ModuleList(SeamlessM4TConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + + def forward(self, hidden_states, attention_mask): + # down project hidden_states if necessary + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + return hidden_states + + +############ TEXT / UNITS related code ################ + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T +class SeamlessM4TScaledWordEmbedding(nn.Embedding): + """ + This module overrides `nn.Embedding`'s forward by multiplying the embeddings with `embed_scale`. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward, put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". 
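+
+        Concretely, with `half_dim = embedding_dim // 2`, position `pos` and channel `i` below
+        `half_dim`, the code fills `emb[pos, i] = sin(pos * 10000^(-i / (half_dim - 1)))` and places the
+        matching cosines in the second half of the channels (a sketch of the layout implemented below,
+        not the paper's interleaved variant).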
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +class SeamlessM4TAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.bart.modeling_bart.BartAttention.__init__ with Bart->SeamlessM4T + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[SeamlessM4TConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, 
tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size +class SeamlessM4TFeedForwardNetwork(nn.Module): + def __init__(self, config: SeamlessM4TConfig, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, config.hidden_size) + self.dropout = nn.Dropout(config.activation_dropout) + self.act = ACT2FN[config.activation_function] + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.fc2.weight, torch.Tensor) + and hidden_states.dtype != self.fc2.weight.dtype + and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8) + ): + hidden_states = hidden_states.to(self.fc2.weight.dtype) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SeamlessM4TEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): + super().__init__() + encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim + encoder_attention_heads = ( + config.encoder_attention_heads if encoder_attention_heads is None else encoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=encoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the 
layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None, layer_idx=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = SeamlessM4TAttention( + self.embed_dim, + decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=decoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + past_key_values (`Cache`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
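+
+        Returns:
+            A `(hidden_states, self_attn_weights, cross_attn_weights)` tuple; the attention weights are
+            `None` unless `output_attentions=True`, and cross-attention weights additionally require
+            `encoder_hidden_states`.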
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights, cross_attn_weights + + +############ SUB-MODELS related code ################ + + +@auto_docstring +class SeamlessM4TPreTrainedModel(PreTrainedModel): + config: SeamlessM4TConfig + base_model_prefix = "seamless_m4t" + supports_gradient_checkpointing = True + _no_split_modules = ["SeamlessM4TEncoderLayer", "SeamlessM4TDecoderLayer", "SeamlessM4TConformerEncoderLayer"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SeamlessM4TConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, SeamlessM4TConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SeamlessM4TConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride + pad = kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1 + + return seq_lens.floor() + + def compute_last_hidden_states_per_sample( + 
self, + hidden_states: tuple[tuple[torch.Tensor]], + beam_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Computes the last hidden states. + + Parameters: + hidden_states (`tuple[tuple[torch.Tensor]]`): + The generated hidden states. Tuple (one element for each generated token) of tuples (one element for + each layer of the decoder) of torch.FloatTensor of shape (batch_size*num_beams*num_return_sequences, + generated_length, hidden_size). + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if `num_beams>1` at + generate-time. + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length, hidden_size)` + containing + the last hidden states. + """ + # 1. First, let's compute last_hidden_states from hidden_states. + # For each generation step, take the hidden state from the last layer. + # shape: (batch_size*num_beams*num_return_sequences, generation_steps, hidden_dim) + last_hidden_states = torch.concat([hidden_states[-1] for hidden_states in hidden_states], dim=1) + + # 2. In the absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent + # to a beam search approach where the first (and only) beam is always selected + # in that case, directly return last_hidden_states + if beam_indices is None: + return last_hidden_states + + # 3. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices.clone()[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards anyway + beam_indices[beam_indices_mask] = 0 + + # 5. expand beam_indices to last_hidden_states dim + beam_indices = beam_indices.unsqueeze(-1) + beam_indices = beam_indices.expand(-1, -1, last_hidden_states.shape[-1]) + + # 6. select the right candidate for each beam + # in other words, new_last_hidden_states[i,j,k] = last_hidden_states[beam_indices[i,j,k], j, k] for all i, j, k + last_hidden_states = torch.gather(last_hidden_states, 0, beam_indices) + + return last_hidden_states + + +@auto_docstring( + custom_intro=""" + Transformer speech encoder consisting of *config.speech_encoder_layers* conformer self attention layers. + Each layer is a [`SeamlessM4TConformerEncoderLayer`]. 
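+
+    A minimal usage sketch (the checkpoint name is an illustrative assumption, not fixed by this code):
+
+    ```python
+    >>> from transformers import AutoProcessor, SeamlessM4TModel
+
+    >>> processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+    >>> model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")
+    >>> # the processor turns raw audio into log-mel `input_features`; this encoder is
+    >>> # exposed as `model.speech_encoder`
+    ```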
+ """ +) +class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel): + main_input_name = "input_features" + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.feature_projection = SeamlessM4TConformerFeatureProjection(config) + self.encoder = SeamlessM4TConformerEncoder(config) + self.intermediate_ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=0.0) + self.adapter = SeamlessM4TConformerAdapter(config) if config.add_adapter else None + self.inner_layer_norm = nn.LayerNorm(config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError( + """Both `input_features` and `inputs_embeds` are `None` in `SeamlessM4TSpeechEncoder.forward`. + Make sure one of them is not `None`.""" + ) + + hidden_states = self.feature_projection(input_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + hidden_states = self.inner_layer_norm(hidden_states) + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# inspired from MBart and NllbMoe +@auto_docstring( + custom_intro=""" + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`SeamlessM4TEncoderLayer`]. 
+ """ +) +class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + is_t2u_encoder: bool = False, + ): + r""" + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + is_t2u_encoder (`bool`, *optional*, defaults to `False`): + indicates if it belongs to the text-to-units model, in which case it won't have input embeddings + """ + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + embed_dim = config.hidden_size + + self.is_t2u_encoder = is_t2u_encoder + self.max_source_positions = config.max_position_embeddings + + if not self.is_t2u_encoder: + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + + layers = [] + for _ in range(config.encoder_layers): + layers.append( + SeamlessM4TEncoderLayer( + config, + encoder_attention_heads=config.encoder_attention_heads, + encoder_ffn_dim=config.encoder_ffn_dim, + ) + ) + + self.layers = nn.ModuleList(layers) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and self.is_t2u_encoder: + raise ValueError( + "You cannot pass input_ids to the encoder of the text_to_units model. Pass inputs_embeds instead." + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.is_t2u_encoder: + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device) + else: + hidden_states = inputs_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +@auto_docstring( + custom_intro=""" + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SeamlessM4TDecoderLayer`]. 
+ """ +) +class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + ): + r""" + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + """ + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for i in range(config.decoder_layers): + layers.append( + SeamlessM4TDecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + layer_idx=i, + ) + ) + self.layers = nn.ModuleList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Transformer bare text-to-unit encoder-decoder. The encoder is a [`SeamlessM4TEncoder`] without embeddings and the decoder is a [`SeamlessM4TDecoder`]. + """ +) +class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + r""" + embed_tokens_decoder (`nn.Embedding`, *optional*): + input embedding of the decoder. 
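+            When provided, the decoder reuses this embedding's weight instead of allocating its own
+            matrix (see `SeamlessM4TDecoder.__init__`).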
+ """ + super().__init__(config) + + self.encoder = SeamlessM4TEncoder(config, is_t2u_encoder=True) + self.decoder = SeamlessM4TDecoder(config, embed_tokens_decoder) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Transformer text-to-unit encoder-decoder with a language model head. The base encoder-decoder model is a [`SeamlessM4TTextToUnit`]. 
+    """
+)
+class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin):
+    _keys_to_ignore_on_load_missing = [
+        "vocoder",
+        "speech_encoder",
+        "text_encoder",
+        "text_decoder",
+    ]
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(
+        self,
+        config: SeamlessM4TConfig,
+        embed_tokens_decoder: Optional[nn.Embedding] = None,
+    ):
+        r"""
+        embed_tokens_decoder (`nn.Embedding`, *optional*):
+            input embedding of the decoder.
+        """
+        # update config - used principally for bos_token_id etc.
+        config = copy.deepcopy(config)
+        for param, val in config.to_dict().items():
+            if param.startswith("t2u_"):
+                config.__setattr__(param[4:], val)
+        super().__init__(config)
+
+        self.model = SeamlessM4TTextToUnitModel(config, embed_tokens_decoder)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.t2u_vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss.
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) + + def _tie_weights(self) -> None: + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + +############ VOCODER related code ################ + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + 
weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class SeamlessM4TVariancePredictor(nn.Module): + def __init__(self, config): + super().__init__() + + embed_dim = config.unit_embed_dim + kernel_size = config.variance_predictor_kernel_size + var_pred_dropout = config.var_pred_dropout + + self.conv1 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + self.activation_function = nn.ReLU() + self.ln1 = nn.LayerNorm(embed_dim) + self.dropout_module = nn.Dropout(p=var_pred_dropout) + self.conv2 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=1, + ) + self.ln2 = nn.LayerNorm(embed_dim) + self.proj = nn.Linear(embed_dim, 1) + + def forward(self, hidden_states: Tensor) -> Tensor: + # Input: B x T x C; Output: B x T + hidden_states = self.conv1(hidden_states.transpose(1, 2)) + hidden_states = self.activation_function(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln1(hidden_states)) + hidden_states = self.conv2(hidden_states.transpose(1, 2)) + hidden_states = self.activation_function(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln2(hidden_states)) + return self.proj(hidden_states).squeeze(dim=2) + + +class SeamlessM4THifiGan(nn.Module): + def __init__(self, config: SeamlessM4TConfig): + super().__init__() + model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim + self.leaky_relu_slope = config.leaky_relu_slope + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + def forward(self, input_embeds: torch.FloatTensor) -> torch.FloatTensor: + r""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + + Args: + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. 
Can be batched and of shape `(batch_size, sequence_length, + model_in_dim)`, or un-batched and of shape `(sequence_length, model_in_dim)`. Note that `model_in_dim` + is the sum of `config.unit_embed_dim`, `config.lang_embed_dim` and `config.spkr_embed_dim`. + + Returns: + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. + """ + + hidden_states = self.conv_pre(input_embeds) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +@auto_docstring( + custom_intro=""" + Code HiFi-GAN vocoder as described in this [repository](https://github.com/facebookresearch/speech-resynthesis). + """ +) +class SeamlessM4TCodeHifiGan(PreTrainedModel): + config: SeamlessM4TConfig + main_input_name = "input_embeds" + _no_split_modules = [] + + def __init__(self, config): + super().__init__(config) + + self.pad_token_id = config.t2u_pad_token_id + self.dur_predictor = SeamlessM4TVariancePredictor(config) + + self.unit_embedding = nn.Embedding(config.unit_hifi_gan_vocab_size, config.unit_embed_dim) + self.speaker_embedding = nn.Embedding(config.vocoder_num_spkrs, config.spkr_embed_dim) + self.language_embedding = nn.Embedding(config.vocoder_num_langs, config.lang_embed_dim) + + self.hifi_gan = SeamlessM4THifiGan(config) + + # Initialize weights and apply final processing + self.post_init() + + def _get_dur_output_lengths(self, input_ids, dur_out): + """ + Computes the output length after the duration layer. 
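+
+        Concretely, this is the total number of upsampled frames per sample: the predicted durations in
+        `dur_out` are cumulatively summed and read out at the index given by the number of non-padding
+        units in `input_ids` (clamped to the length of `dur_out`).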
+        """
+        unit_lengths = (input_ids != self.pad_token_id).sum(1)
+
+        # take care of edge cases where no padding or too many padding
+        unit_lengths = torch.clamp(unit_lengths, 0, dur_out.shape[1] - 1)
+
+        cumulative_dur_out = torch.cumsum(dur_out, dim=1)
+        unit_lengths = cumulative_dur_out.gather(dim=1, index=unit_lengths.unsqueeze(1)).squeeze()
+
+        return unit_lengths
+
+    def _get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the hifigan convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (
+                torch.div(input_length + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + 1
+            )
+
+        def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1):
+            return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1
+
+        # conv_pre
+        input_lengths = _conv_out_length(input_lengths, 7, 1, 3)
+
+        # upsampler
+        for i, (upsample_rate, kernel_size) in enumerate(
+            zip(self.config.upsample_rates, self.config.upsample_kernel_sizes)
+        ):
+            input_lengths = _transpose_conv_out_length(
+                input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2
+            )
+
+        # resblock
+        for i in range(len(self.config.upsample_rates)):
+            for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes):
+                for dil in dilation:
+                    input_lengths = _conv_out_length(
+                        input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil
+                    )
+
+                for dil in dilation:
+                    input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1)
+
+        # conv_post
+        input_lengths = _conv_out_length(input_lengths, 7, 1, 3)
+
+        return input_lengths
+
+    def forward(
+        self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor
+    ) -> tuple[torch.Tensor]:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTextToUnitForConditionalGeneration`]. [What are input
+                IDs?](../glossary#input-ids)
+            spkr_id (`torch.Tensor` of shape `(batch_size, 1)`):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            lang_id (`torch.Tensor` of shape `(batch_size, 1)`):
+                The vocoder language id to use for speech synthesis.
+        """
+        hidden_states = self.unit_embedding(input_ids).transpose(1, 2)
+        spkr = self.speaker_embedding(spkr_id).transpose(1, 2)
+        lang = self.language_embedding(lang_id).transpose(1, 2)
+
+        log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
+        dur_out = torch.clamp(torch.round(torch.expm1(log_dur_pred)).long(), min=1)
+        # B x C x T
+        if hidden_states.size(0) == 1:
+            hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
+        else:
+            # if batched sample, need to interleave per sample, and pad -> loss of parallelism
+            if hidden_states.shape[0] > 1 and self.training:
+                logger.warning(
+                    """`self.training=True` and you use batching.
You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + torch.repeat_interleave(hidden_state, duration, dim=-1).transpose(0, 1) + for (hidden_state, duration) in zip(hidden_states, dur_out) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True).transpose(1, 2) + + spkr = spkr.repeat(1, 1, hidden_states.shape[-1]) + lang = lang.repeat(1, 1, hidden_states.shape[-1]) + hidden_states = torch.cat([lang, hidden_states, spkr], dim=1) + + hidden_states = self.hifi_gan(hidden_states) + + unit_lengths = self._get_dur_output_lengths(input_ids, dur_out) + lengths = self._get_output_hifigan_lengths(unit_lengths) + + return hidden_states, lengths + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.apply_weight_norm() + weight_norm(self.hifi_gan.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.hifi_gan.conv_post) + + +############ WHOLE MODEL related code ################ + + +@auto_docstring( + custom_intro=""" + The text-to-text SeamlessM4T Model transformer which can be used for T2TT. 
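+
+    A minimal usage sketch (the checkpoint name, the `AutoProcessor` helper and the language codes below are
+    illustrative assumptions, not defined in this file):
+
+    ```python
+    from transformers import AutoProcessor, SeamlessM4TForTextToText
+
+    # hypothetical checkpoint id; any SeamlessM4T checkpoint with a matching generation config should work
+    processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+    model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+    inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
+    # `tgt_lang` selects the target-language token used as `decoder_input_ids` (see `generate` below)
+    output_tokens = model.generate(**inputs, tgt_lang="fra", num_beams=4)
+    print(processor.decode(output_tokens[0], skip_special_tokens=True))
+    ```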
+ """ +) +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_ids=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. 
You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Parameters:
+            input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criterion is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://huggingface.co/papers/2010.00904).
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+            kwargs (`dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # prepare text_decoder_input_ids + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + + return super().generate( + input_ids, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + +@auto_docstring( + custom_intro=""" + The speech-to-text SeamlessM4T Model transformer which can be used for S2TT. 
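+
+    A minimal usage sketch (checkpoint name, `AutoProcessor` usage and the silent dummy audio are illustrative
+    assumptions, not defined in this file):
+
+    ```python
+    import torch
+    from transformers import AutoProcessor, SeamlessM4TForSpeechToText
+
+    processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")  # hypothetical checkpoint id
+    model = SeamlessM4TForSpeechToText.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+    # two seconds of silence at 16 kHz stand in for a real utterance
+    audio = torch.zeros(1, 32000)
+    inputs = processor(audios=audio, sampling_rate=16000, return_tensors="pt")
+    output_tokens = model.generate(**inputs, tgt_lang="eng")
+    print(processor.decode(output_tokens[0], skip_special_tokens=True))
+    ```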
+ """ +) +class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_encoder", "t2u_model", "vocoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_features=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + 
prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://huggingface.co/papers/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + kwargs (`dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. 
The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + input_features = input_features if input_features is not None else kwargs.pop("inputs") + if tgt_lang is not None: + inputs = kwargs.get("input_embeds") if input_features is None else input_features + inputs = ( + inputs + if inputs is not None + else kwargs.get("encoder_outputs", {"last_hidden_state": None})["last_hidden_state"] + ) + batch_size = len(inputs) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + return super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + +@auto_docstring( + custom_intro=""" + The text-to-speech SeamlessM4T Model transformer which can be used for T2ST. 
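+
+    A minimal usage sketch (checkpoint name, `AutoProcessor` usage and language codes are illustrative
+    assumptions, not defined in this file):
+
+    ```python
+    from transformers import AutoProcessor, SeamlessM4TForTextToSpeech
+
+    processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")  # hypothetical checkpoint id
+    model = SeamlessM4TForTextToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+    inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
+    # without `return_intermediate_token_ids`, `generate` returns the waveform and its per-sample lengths
+    waveform, waveform_lengths = model.generate(**inputs, tgt_lang="fra", spkr_id=0)
+    # prefixed kwargs target one pass only, e.g. `text_num_beams=4` (text decoder) or `speech_do_sample=True` (t2u)
+    ```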
+ """ +) +class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForTextToText`." + "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`." + "If you want to generate speech, use the `.generate` method." + ) + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: 
Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + + Returns: + `Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`: + - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size, + sequence_length)` and `waveform_lengths` which gives the length of each sample. + """ + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. 
Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + 
if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + +@auto_docstring( + custom_intro=""" + The speech-to-speech SeamlessM4T Model transformer which can be used for S2ST. + """ +) +class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_encoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForSpeechToText`. It doesn't use `self.t2u_model`." + "If you want to generate speech, use the `generate` method." + ) + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + 
encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_features, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + + Returns: + `Union[SeamlessM4TGenerationOutput, tuple[Tensor]]`: + - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size, + sequence_length)` and `waveform_lengths` which gives the length of each sample. + """ + batch_size = len(input_features) if input_features is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. 
Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_features, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get last_hidden_state from encoder + encoder_hidden_states = self.speech_encoder(input_features=input_features, attention_mask=attention_mask)[0] + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) 
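+        # Note: `sequences_scores` (used just below) is only returned for beam-search generation, so re-ranking with
+        # `num_return_sequences > 1` assumes the first generation ran beam search.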
+ if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + +@auto_docstring( + custom_intro=""" + The original SeamlessM4T Model transformer which can be used for every tasks available (S2ST, S2TT, T2TT, T2ST). + """ +) +class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config, current_modality="text"): + r""" + current_modality (`str`, *optional*, defaults to `"text"`): + Default modality. Used to initialize the model. 
+ """ + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.current_modality = current_modality + if current_modality == "speech": + self.main_input_name = "input_features" + + # these models already call post_init in their initialization + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def set_modality(self, modality="text"): + if modality == "text": + self.main_input_name = "input_ids" + self.current_modality = "text" + elif modality == "speech": + self.main_input_name = "input_features" + self.current_modality = "speech" + else: + raise ValueError(f"`modality={modality}` is not a valid modality. It must be `text` or `speech`.") + + def get_encoder(self): + if self.current_modality == "text": + return self.text_encoder + else: + return self.speech_encoder + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if input_ids is None and input_features is None and inputs_embeds is None and encoder_outputs is None: + raise ValueError( + "`input_ids`,`input_features`, `inputs_embeds` and `encoder_outputs` are all empty. Make sure at least one of them is not." + ) + elif input_features is not None: + if input_ids is not None: + logger.warning( + "`input_ids` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through the `speech_encoder`. " + "Make sure that `input_features` and `input_ids` are mutually exclusive." + ) + + if inputs_embeds is not None: + logger.warning( + "`inputs_embeds` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through `speech_encoder`. " + "`inputs_embeds` will be ignored." + ) + + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." + ) + + self.set_modality("speech") + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + elif input_ids is not None or inputs_embeds is not None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." 
+ ) + self.set_modality("text") + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + # input modality = speech so new attention mask + if self.current_modality == "speech" and attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + generate_speech: Optional[bool] = True, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated token ids and/or translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids=input_ids, num_beams=4, speech_do_sample=True)` will successively + perform beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). 
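+
+        Example (illustrative sketch; assumes the `"facebook/hf-seamless-m4t-medium"` checkpoint referenced in this
+        package's tokenizer docstrings is available, and exact outputs depend on the loaded weights):
+
+        ```python
+        >>> from transformers import SeamlessM4TModel, SeamlessM4TProcessor
+
+        >>> processor = SeamlessM4TProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+        >>> text_inputs = processor(text="Hello, how are you?", src_lang="eng", return_tensors="pt")
+        >>> # text-to-speech: returns a (waveform, waveform_lengths) tuple unless return_intermediate_token_ids=True
+        >>> waveform, waveform_lengths = model.generate(**text_inputs, tgt_lang="fra", spkr_id=0)
+        >>> # text-to-text only: speech generation is skipped, and `text_`-prefixed kwargs reach the text sub-model
+        >>> text_output = model.generate(**text_inputs, tgt_lang="fra", generate_speech=False, text_num_beams=4)
+        ```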
+ + + + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*): + Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + Note that if `generate_speech=False`, this parameter will be ignored and + the text tokens are returned. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + generate_speech (`bool`, *optional*, defaults to `True`): + If `False`, will only returns the text tokens and won't generate speech. + + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + Returns: + `Union[SeamlessM4TGenerationOutput, tuple[Tensor], ModelOutput]`: + - If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of + shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample. + - If `generate_speech=False`, it will returns `ModelOutput`. + """ + if input_ids is None and input_features is None and kwargs.get("inputs_embeds") is None: + raise ValueError( + "`input_ids`,`input_features` and `inputs_embeds` are all empty. Make sure at least one of them is not." + ) + + if generate_speech and tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + + if tgt_lang is not None: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. 
Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + batch_size = ( + len(input_features) + if input_features is not None + else (len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))) + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + if input_features is not None: + self.set_modality("speech") + if input_ids is not None: + logger.warning( + "`input_features` and `input_ids` are both non empty. `input_features` will be used in priority " + "through the speech encoder. Make sure `input_features=None` if you want to use the text encoder." + ) + text_generation_output = super().generate(input_features=input_features, **kwargs_text) + else: + self.set_modality("text") + text_generation_output = super().generate(input_ids=input_ids, input_features=None, **kwargs_text) + sequences = text_generation_output.sequences + + if not generate_speech: + return text_generation_output + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get encoder last hidden states + if self.current_modality == "speech": + # get last_hidden_state from encoder - must do a pass through the speech encoder + encoder_hidden_states = self.speech_encoder( + input_features=input_features, attention_mask=attention_mask + ).last_hidden_state + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + else: + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) 
+ if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + +__all__ = [ + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4TForSpeechToText", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TCodeHifiGan", + "SeamlessM4THifiGan", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/processing_seamless_m4t.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/processing_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..85e1968d7d43c006b5495b5729031abfd5026010 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/processing_seamless_m4t.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio/Text processor class for SeamlessM4T +""" + +from typing import Optional, Union + +from ...audio_utils import AudioInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...utils.deprecation import deprecate_kwarg + + +logger = logging.get_logger(__name__) + + +class SeamlessM4TTextKwargs(TextKwargs): + src_lang: Optional[str] + tgt_lang: Optional[str] + + +class SeamlessM4TProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: SeamlessM4TTextKwargs + _defaults = {} + + +class SeamlessM4TProcessor(ProcessorMixin): + r""" + Constructs a SeamlessM4T processor which wraps a SeamlessM4T feature extractor and a SeamlessM4T tokenizer into a + single processor. + + [`SeamlessM4TProcessor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and + [`SeamlessM4TTokenizerFast`]. See the [`~SeamlessM4TProcessor.__call__`] and [`~SeamlessM4TProcessor.decode`] for + more information. + + Args: + feature_extractor ([`SeamlessM4TFeatureExtractor`]): + The audio processor is a required input. + tokenizer ([`SeamlessM4TTokenizerFast`]): + The tokenizer is a required input. + """ + + feature_extractor_class = "SeamlessM4TFeatureExtractor" + tokenizer_class = ("SeamlessM4TTokenizer", "SeamlessM4TTokenizerFast") + valid_processor_kwargs = SeamlessM4TProcessorKwargs + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + @deprecate_kwarg("audios", version="v4.59.0", new_name="audio") + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + audios: Optional[AudioInput] = None, + audio: Optional[AudioInput] = None, + **kwargs: Unpack[ProcessingKwargs], + ): + """ + Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` + and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not + `None` to encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to + SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audios` is not `None`. Please refer + to the docstring of the above two methods for more information. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + audios (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case + of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, + and T the sample length of the audio. 
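+
+        Example (illustrative sketch; assumes the `"facebook/hf-seamless-m4t-medium"` checkpoint referenced in this
+        package's tokenizer docstrings, and `audio_array` stands for any 1D NumPy waveform at the expected sampling rate):
+
+        ```python
+        >>> from transformers import SeamlessM4TProcessor
+
+        >>> processor = SeamlessM4TProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> text_inputs = processor(text="Hello", src_lang="eng", return_tensors="pt")  # -> input_ids, attention_mask
+        >>> audio_inputs = processor(audio=audio_array, sampling_rate=16000, return_tensors="pt")  # -> input_features
+        ```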
+ Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`. + """ + if text is not None and audios is not None: + raise ValueError( + "Text and audios are mututally exclusive when passed to `SeamlessM4T`. Specify one or another." + ) + if audio is None and audios is not None: + logger.warning( + "Passing `audios` as keyword argument is deprecated and will be removed in v4.63, please pass `audio` instead." + ) + audio = audios + return super().__call__(text=text, audio=audio, **kwargs) + + +__all__ = ["SeamlessM4TProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..fd773316580c1dc3961a954dc2fd3a8e0d4face9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t.py @@ -0,0 +1,567 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for SeamlessM4T.""" + +import os +from shutil import copyfile +from typing import Any, Optional, Union + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + PreTrainedTokenizer, + TextInput, +) +from ...tokenization_utils_base import AddedToken +from ...utils import PaddingStrategy, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +@requires(backends=("sentencepiece",)) +class SeamlessM4TTokenizer(PreTrainedTokenizer): + """ + Construct a SeamlessM4T tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizer + + >>> tokenizer = SeamlessM4TTokenizer.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." 
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. + tgt_lang (`str`, *optional*, defaults to `"fra"`): + The language to use as target language for translation. + sp_model_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the model initialization. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. Can be used to specify the list of languages that will be + supported by the tokenizer. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: list[int] = [] + suffix_tokens: list[int] = [] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + tokenizer_file=None, + src_lang="eng", + tgt_lang="fra", + sp_model_kwargs: Optional[dict[str, Any]] = None, + additional_special_tokens=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + # Add this unused argument to keep some important Copied from statements + self.legacy = False + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ---- + # spm | '' | '' | '' | 'an' | 'en' | '_d' | 'er' | 'in' | '_s' | '_a' + # fairseq | '' | '' | '' | '' | 'an' | 'en' | '▁d' | 'er' | 'in' | '▁s' + + # Mimic fairseq token-to-id alignment for the first 4 token + self._added_tokens_decoder = { + 0: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token, + 1: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token, + 2: AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token, + 3: AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token, + } + + # The first "real" token "an" has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.sp_model_size = len(self.sp_model) + + self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang + self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + tokenizer_file=tokenizer_file, + src_lang=src_lang, + tgt_lang=tgt_lang, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + self.set_src_lang_special_tokens(self._src_lang) + self.set_tgt_lang_special_tokens(self._tgt_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + return len(self.sp_model) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] + ] = None, + padding: Union[bool, str, 
PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + **kwargs, + ): + """ + Args: + text (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + src_lang (`str`, *optional*): + A string representing the source language. If not specified, the last `src_lang` specified (either + during initialization or when calling this tokenizer) will be used. + tgt_lang (`str`, *optional*): + A string representing the target language. If not specified, the last `tgt_lang` specified (either + during initialization or when calling this tokenizer) will be used. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizer.__call__`]. 
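+
+        Example (illustrative sketch; assumes the same checkpoint as the class-level example above):
+
+        ```python
+        >>> tokenizer = SeamlessM4TTokenizer.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> # `src_lang`/`tgt_lang` passed here update the tokenizer state and remain in effect for later calls
+        >>> inputs = tokenizer(
+        ...     "Hello world", text_target="Bonjour le monde", src_lang="eng", tgt_lang="fra", return_tensors="pt"
+        ... )
+        ```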
+ """ + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + + output = super().__call__( + text=text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + padding=padding, + pad_to_multiple_of=pad_to_multiple_of, + **kwargs, + ) + + return BatchEncoding(output, tensor_type=kwargs.get("return_tensors")) + + @property + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + if "__" not in new_src_lang: + self._src_lang = f"__{new_src_lang}__" + else: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang: str) -> None: + if "__" not in new_tgt_lang: + self._tgt_lang = f"__{new_tgt_lang}__" + else: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(self._tgt_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model.") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = { + self.convert_ids_to_tokens(i): i for i in range(self.fairseq_offset, self.vocab_size + self.fairseq_offset) + } + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> list[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. 
+ """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.prepare_seq2seq_batch with eng_Latn->eng, fra_Latn->fra + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "eng", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from 
transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_input_mode + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_target_mode + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + self.init_kwargs["src_lang"] = src_lang + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`src_lang={src_lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + # https://github.com/facebookresearch/fairseq2/blob/c53f18e6be6b8b46b722f2249b8397b7eccd7ad3/src/fairseq2/models/nllb/tokenizer.py#L112-L116 + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + Prefix=[eos, tgt_lang_code] and suffix=[eos]. + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + self.init_kwargs["tgt_lang"] = lang + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.prefix_tokens = [self.eos_token_id, self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + +__all__ = ["SeamlessM4TTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..0318336332c3c9e4ff9dfad1a050cc7b0a3fe47d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py @@ -0,0 +1,446 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Tokenization class for SeamlessM4T.""" + +import os +from shutil import copyfile +from typing import Optional, Union + +from tokenizers import processors + +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + TextInput, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_seamless_m4t import SeamlessM4TTokenizer +else: + SeamlessM4TTokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + + +class SeamlessM4TTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" SeamlessM4T tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizerFast + + >>> tokenizer = SeamlessM4TTokenizerFast.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. 
+ tgt_lang (`str`, *optional*, defaults to `"fra"`): + The language to use as target language for translation. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = SeamlessM4TTokenizer + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: list[int] = [] + suffix_tokens: list[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + src_lang="eng", + tgt_lang="fra", + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + src_lang=src_lang, + tgt_lang=tgt_lang, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang + self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + self.set_tgt_lang_special_tokens(self._tgt_lang) + + @property + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + if "__" not in new_src_lang: + self._src_lang = f"__{new_src_lang}__" + else: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang: str) -> None: + if "__" not in new_tgt_lang: + self._tgt_lang = f"__{new_tgt_lang}__" + else: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(self._tgt_lang) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. The special tokens depend on calling set_lang. + + An SeamlessM4T sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `[src_lang_code] X [eos]` + - `decoder_input_ids`: (for decoder) `[eos, tgt_lang_code] X [eos]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.prepare_seq2seq_batch with "fra_Latn"->"fra", "eng_Latn"->"eng" + def prepare_seq2seq_batch( + self, + src_texts: list[str], + src_lang: str = "eng", + tgt_texts: Optional[list[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_input_mode + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_target_mode + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`src_lang={src_lang}` has not been found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
+ ) + + self.init_kwargs["src_lang"] = src_lang + + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + Prefix=[eos, tgt_lang_code] and suffix=[eos]. + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={lang}` has not been found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.init_kwargs["tgt_lang"] = lang + + self.prefix_tokens = [self.eos_token_id, self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer."
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @classmethod + def _from_pretrained( + cls, + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + **kwargs, + ): + tokenizer = super()._from_pretrained( + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + _is_local=_is_local, + **kwargs, + ) + + # ensure also set after from pretrained + tokenizer.set_src_lang_special_tokens(tokenizer._src_lang) + tokenizer.set_tgt_lang_special_tokens(tokenizer._tgt_lang) + + return tokenizer + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] + ] = None, + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + **kwargs, + ): + """ + Args: + text (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + src_lang (`str`, *optional*): + A string representing the source language. If not specified, the last `src_lang` specified (either + during initialization or when calling this tokenizer) will be used. + tgt_lang (`str`, *optional*): + A string representing the target language. If not specified, the last `tgt_lang` specified (either + during initialization or when calling this tokenizer) will be used. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizerFast.__call__`]. + """ + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + + output = super().__call__( + text=text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + padding=padding, + pad_to_multiple_of=pad_to_multiple_of, + **kwargs, + ) + + return output + + +__all__ = ["SeamlessM4TTokenizerFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seggpt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/seggpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec646f2f81a21d50d50b67177638d275427bc42b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seggpt/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
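Because `__call__` above accepts `src_lang` and `tgt_lang` and updates the tokenizer's language state before delegating to the parent class, the languages can be switched per call. A minimal usage sketch follows; the checkpoint name is taken from the docstring example earlier in this file, and downloading it requires network access, so treat this as illustrative rather than a guaranteed-runnable snippet in an offline environment.

```python
from transformers import SeamlessM4TTokenizerFast

tokenizer = SeamlessM4TTokenizerFast.from_pretrained("facebook/hf-seamless-m4t-medium")

# src_lang/tgt_lang passed here override the languages set at construction time
# for this call only; text_target is encoded with the target-language layout
# ([eos, tgt_lang_code] ... [eos]) described in set_tgt_lang_special_tokens.
inputs = tokenizer(
    "UN Chief Says There Is No Military Solution in Syria",
    text_target="Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
    src_lang="eng",
    tgt_lang="fra",
    return_tensors="pt",
)
print(inputs.keys())  # typically input_ids, attention_mask and labels
```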
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_seggpt import * + from .image_processing_seggpt import * + from .modeling_seggpt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/seggpt/configuration_seggpt.py b/venv/lib/python3.13/site-packages/transformers/models/seggpt/configuration_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..dc149adfa0337d89ae5c537a50ea0157bde7dd3e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seggpt/configuration_seggpt.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SegGpt model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SegGptConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SegGptModel`]. It is used to instantiate a SegGPT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the SegGPT + [BAAI/seggpt-vit-large](https://huggingface.co/BAAI/seggpt-vit-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`list[int]`, *optional*, defaults to `[896, 448]`): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. 
+ num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If unset, defaults to + `hidden_size` * 4. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The drop path rate for the dropout layers. + pretrain_image_size (`int`, *optional*, defaults to 224): + The pretrained size of the absolute position embeddings. + decoder_hidden_size (`int`, *optional*, defaults to 64): + Hidden size for decoder. + use_relative_position_embeddings (`bool`, *optional*, defaults to `True`): + Whether to use relative position embeddings in the attention layers. + merge_index (`int`, *optional*, defaults to 2): + The index of the encoder layer to merge the embeddings. + intermediate_hidden_state_indices (`list[int]`, *optional*, defaults to `[5, 11, 17, 23]`): + The indices of the encoder layers which we store as features for the decoder. + beta (`float`, *optional*, defaults to 0.01): + Regularization factor for SegGptLoss (smooth-l1 loss). + + Example: + + ```python + >>> from transformers import SegGptConfig, SegGptModel + + >>> # Initializing a SegGPT seggpt-vit-large style configuration + >>> configuration = SegGptConfig() + + >>> # Initializing a model (with random weights) from the seggpt-vit-large style configuration + >>> model = SegGptModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "seggpt" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=[896, 448], + patch_size=16, + num_channels=3, + qkv_bias=True, + mlp_dim=None, + drop_path_rate=0.1, + pretrain_image_size=224, + decoder_hidden_size=64, + use_relative_position_embeddings=True, + merge_index=2, + intermediate_hidden_state_indices=[5, 11, 17, 23], + beta=0.01, + **kwargs, + ): + super().__init__(**kwargs) + + if merge_index > min(intermediate_hidden_state_indices): + raise ValueError( + f"Merge index must be less than the minimum encoder output index, but got {merge_index=} and {intermediate_hidden_state_indices=}" + ) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.drop_path_rate = drop_path_rate + self.pretrain_image_size = pretrain_image_size + self.decoder_hidden_size = decoder_hidden_size + self.use_relative_position_embeddings = use_relative_position_embeddings + self.merge_index = merge_index + self.intermediate_hidden_state_indices = intermediate_hidden_state_indices + self.beta = beta + self.mlp_dim = int(hidden_size * 4) if mlp_dim is None else mlp_dim + + +__all__ = ["SegGptConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seggpt/image_processing_seggpt.py b/venv/lib/python3.13/site-packages/transformers/models/seggpt/image_processing_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..ffadfaf85edbe7ca4098d0ba73702e8117b7797c --- /dev/null +++ 
b/venv/lib/python3.13/site-packages/transformers/models/seggpt/image_processing_seggpt.py @@ -0,0 +1,615 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SegGPT.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_torch_available, logging, requires_backends + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +# See https://huggingface.co/papers/2212.02499 at 3.1 Redefining Output Spaces as "Images" - Semantic Segmentation from PAINTER paper +# Taken from https://github.com/Abdullah-Meda/Painter/blob/main/Painter/data/coco_semseg/gen_color_coco_panoptic_segm.py#L31 +def build_palette(num_labels: int) -> list[tuple[int, int]]: + base = int(num_labels ** (1 / 3)) + 1 + margin = 256 // base + + # we assume that class_idx 0 is the background which is mapped to black + color_list = [(0, 0, 0)] + for location in range(num_labels): + num_seq_r = location // base**2 + num_seq_g = (location % base**2) // base + num_seq_b = location % base + + R = 255 - num_seq_r * margin + G = 255 - num_seq_g * margin + B = 255 - num_seq_b * margin + + color_list.append((R, G, B)) + + return color_list + + +def mask_to_rgb( + mask: np.ndarray, palette: Optional[list[tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None +) -> np.ndarray: + data_format = data_format if data_format is not None else ChannelDimension.FIRST + + if palette is not None: + height, width = mask.shape + + rgb_mask = np.zeros((3, height, width), dtype=np.uint8) + + classes_in_mask = np.unique(mask) + + for class_idx in classes_in_mask: + rgb_value = palette[class_idx] + class_mask = (mask == class_idx).astype(np.uint8) + class_mask = np.expand_dims(class_mask, axis=-1) + class_rgb_mask = class_mask * np.array(rgb_value) + class_rgb_mask = np.moveaxis(class_rgb_mask, -1, 0) + rgb_mask += class_rgb_mask.astype(np.uint8) + + rgb_mask = np.clip(rgb_mask, 0, 255).astype(np.uint8) + + else: + rgb_mask = np.repeat(mask[None, ...], 3, axis=0) + + return to_channel_dimension_format(rgb_mask, data_format) + + +class SegGptImageProcessor(BaseImageProcessor): + r""" + Constructs a SegGpt image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. 
+ size (`dict`, *optional*, defaults to `{"height": 448, "width": 448}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 448, "width": 448} + size = get_size_dict(size) + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_convert_rgb = do_convert_rgb + + def get_palette(self, num_labels: int) -> list[tuple[int, int]]: + """Build a palette to map the prompt mask from a single channel to a 3 channel RGB. + + Args: + num_labels (`int`): + Number of classes in the segmentation task (excluding the background). + + Returns: + `list[tuple[int, int]]`: Palette to map the prompt mask from a single channel to a 3 channel RGB. + """ + return build_palette(num_labels) + + def mask_to_rgb( + self, + image: np.ndarray, + palette: Optional[list[tuple[int, int]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Converts a segmentation map to RGB format. + + Args: + image (`np.ndarray`): + Segmentation map with dimensions (height, width) where pixel values represent the class index. 
+ palette (`list[tuple[int, int]]`, *optional*, defaults to `None`): + Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel + dimension. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The mask in RGB format. + """ + return mask_to_rgb(image, palette=palette, data_format=data_format) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess_step( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + num_labels: Optional[int] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has + an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. + num_labels: (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx + channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed + through as is if it is already in RGB format or being duplicated across the channel dimension. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + resample = resample if resample is not None else self.resample + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size_dict = get_size_dict(size) + + # If segmentation map is passed we expect 2D images + images = make_flat_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None and not do_convert_rgb: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_convert_rgb: + palette = self.get_palette(num_labels) if num_labels is not None else None + # Since this is the input for the next transformations its format should be the same as the input_data_format + images = [ + self.mask_to_rgb(image=image, palette=palette, data_format=ChannelDimension.FIRST) for image in images + ] + input_data_format = ChannelDimension.FIRST + + if do_resize: + images = [ + self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + return images + + def preprocess( + self, + images: Optional[ImageInput] = None, + prompt_images: Optional[ImageInput] = None, + prompt_masks: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + num_labels: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_images (`ImageInput`): + Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_masks (`ImageInput`): + Prompt mask from prompt image to _preprocess that specify prompt_masks value in the preprocessed output. + Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of + RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps, `num_labels` + specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel to + a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel + dimension. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has + an effect if `do_resize` is set to `True`. Doesn't apply to prompt mask as it is resized using nearest. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. + num_labels: (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map + with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed + through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated + across the channel dimension. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ """ + if all(v is None for v in [images, prompt_images, prompt_masks]): + raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.") + + data = {} + + if images is not None: + images = self._preprocess_step( + images, + is_mask=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=False, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["pixel_values"] = images + + if prompt_images is not None: + prompt_images = self._preprocess_step( + prompt_images, + is_mask=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=False, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["prompt_pixel_values"] = prompt_images + + if prompt_masks is not None: + prompt_masks = self._preprocess_step( + prompt_masks, + do_resize=do_resize, + size=size, + resample=PILImageResampling.NEAREST, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + num_labels=num_labels, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["prompt_masks"] = prompt_masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None, num_labels: Optional[int] = None + ): + """ + Converts the output of [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`SegGptImageSegmentationOutput`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + num_labels (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map prediction masks from RGB values to class + indices. This value should be the same used when preprocessing inputs. + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. 
+ """ + requires_backends(self, ["torch"]) + # batch_size x num_channels x 2*height x width + masks = outputs.pred_masks + + # Predicted mask and prompt are concatenated in the height dimension + # batch_size x num_channels x height x width + masks = masks[:, :, masks.shape[2] // 2 :, :] + + # To unnormalize we need to permute to channel last + # batch_size x height x width x num_channels + std = torch.tensor(self.image_std).to(masks.device) + mean = torch.tensor(self.image_mean).to(masks.device) + + masks = masks.permute(0, 2, 3, 1) * std + mean + + # batch_size x num_channels x height x width + masks = masks.permute(0, 3, 1, 2) + + # Clip to match with palette if specified + masks = torch.clip(masks * 255, 0, 255) + + semantic_segmentation = [] + palette_tensor = None + palette = self.get_palette(num_labels) if num_labels is not None else None + if palette is not None: + palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float) + _, num_channels, _, _ = masks.shape + palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels) + + for idx, mask in enumerate(masks): + if target_sizes is not None: + mask = torch.nn.functional.interpolate( + mask.unsqueeze(0), + size=target_sizes[idx], + mode="nearest", + )[0] + + if num_labels is not None: + channels, height, width = mask.shape + dist = mask.permute(1, 2, 0).view(height, width, 1, channels) + dist = dist - palette_tensor + dist = torch.pow(dist, 2) + dist = torch.sum(dist, dim=-1) + pred = dist.argmin(dim=-1) + + else: + # If no palette is specified SegGpt will try to paint using the mask class idx as RGB + pred = mask.mean(dim=0).int() + + semantic_segmentation.append(pred) + + return semantic_segmentation + + +__all__ = ["SegGptImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/seggpt/modeling_seggpt.py b/venv/lib/python3.13/site-packages/transformers/models/seggpt/modeling_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..7e82d26c9e74279bfc5de27e404fc425c0908380 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/seggpt/modeling_seggpt.py @@ -0,0 +1,979 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SegGpt model.""" + +import collections.abc +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, logging, torch_int +from .configuration_seggpt import SegGptConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`SegGptEncoderOutput`]. 
+ """ +) +class SegGptEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape `(batch_size, patch_height, patch_width, hidden_size)`. + attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`): + Tuple of *torch.FloatTensor* (one for each layer) of shape + `(batch_size, num_heads, seq_len, seq_len)`. + intermediate_hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set): + Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`. + Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`. + Additionally, each feature passes through a LayerNorm. + """ + + last_hidden_state: torch.FloatTensor + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + intermediate_hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`SegGptImageSegmentationOutput`]. + """ +) +class SegGptImageSegmentationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + The loss value. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The predicted masks. + hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape `(batch_size, patch_height, patch_width, hidden_size)`. + attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, seq_len, seq_len)`. + """ + + loss: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SegGpt +class SegGptPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SegGptEmbeddings(nn.Module): + """ + Construct the embeddings from patch, position embeddings for input and prompt. + """ + + def __init__(self, config: SegGptConfig) -> None: + super().__init__() + + self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.segment_token_input = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.segment_token_prompt = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + # token for seg types + self.type_token_semantic = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.type_token_instance = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + + self.patch_embeddings = SegGptPatchEmbeddings(config) + + num_positions = (config.pretrain_image_size // config.patch_size) ** 2 + 1 + self.position_embeddings = nn.Parameter(torch.randn(1, num_positions, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor: + patch_pos_embed = self.position_embeddings[:, 1:] + num_patches = patch_pos_embed.shape[1] + pretrain_patch_size = torch_int(num_patches**0.5) + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width: + patch_pos_embed = F.interpolate( + patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2), + size=(height, width), + mode="bicubic", + align_corners=False, + ) + + return patch_pos_embed.permute(0, 2, 3, 1) + else: + return patch_pos_embed.reshape(1, height, width, -1) + + def forward( + self, + pixel_values: torch.Tensor, + prompt_pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + embedding_type: Optional[str] = None, + ) -> torch.Tensor: + input_embeddings = self.patch_embeddings(pixel_values) + prompt_embeddings = self.patch_embeddings(prompt_pixel_values) + + batch_size, patch_height, patch_width, _ = input_embeddings.shape + + mask_token = self.mask_token.expand(batch_size, patch_height, patch_width, -1) + # replace the masked visual tokens by mask_token + w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, patch_height, 
patch_width, 1) + prompt_embeddings = prompt_embeddings * (1 - w) + mask_token * w + + embedding_type = embedding_type if embedding_type is not None else "instance" + + # add positional encoding to each token + pos_embed = self.interpolate_pos_encoding(patch_height, patch_width) + + # add segment token + input_embeddings = input_embeddings + self.segment_token_input + prompt_embeddings = prompt_embeddings + self.segment_token_prompt + + # add position embedding skipping CLS + input_embeddings = input_embeddings + pos_embed + prompt_embeddings = prompt_embeddings + pos_embed + + # add type embedding to each token + if embedding_type == "semantic": + type_embedding = self.type_token_semantic + elif embedding_type == "instance": + type_embedding = self.type_token_instance + else: + raise ValueError(f"Embedding type should be either 'semantic' or 'instance', but got {embedding_type}") + + input_embeddings = input_embeddings + type_embedding + prompt_embeddings = prompt_embeddings + type_embedding + + embeddings = torch.cat((input_embeddings, prompt_embeddings), dim=0) + + return embeddings + + +class SegGptAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + + input_size = (image_size[0] // config.patch_size, image_size[1] // config.patch_size) + head_dim = config.hidden_size // config.num_attention_heads + + self.num_attention_heads = config.num_attention_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_relative_position_embeddings = config.use_relative_position_embeddings + if self.use_relative_position_embeddings: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. 
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) + return attn + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_relative_position_embeddings: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attention_heads, height * width, -1)
+            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attention_heads, height * width, -1)
+        else:
+            attn_weights_reshaped = None
+
+        attn_output = (attn_weights @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
+        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
+
+        attn_output = self.proj(attn_output)
+
+        return (attn_output, attn_weights_reshaped)
+
+
+# Copied from transformers.models.sam.modeling_sam.SamMLPBlock with SamMLPBlock->SegGptMlp
+class SegGptMlp(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
+        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
+        self.act = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.lin1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.lin2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->SegGpt +class SegGptDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class SegGptLayer(GradientCheckpointingLayer): + def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None: + super().__init__() + self.attention = SegGptAttention(config) + self.mlp = SegGptMlp(config) + self.drop_path = SegGptDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ensemble_cond: int, + feature_ensemble: bool = False, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in SegGpt, layernorm is applied before self-attention + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if feature_ensemble and attention_output.shape[0] // 2 >= ensemble_cond: + prompt, inputs = attention_output.split(attention_output.shape[1] // 2, dim=1) + if ensemble_cond == 2: + num_prompts = attention_output.shape[0] // 2 + inputs = inputs.reshape(2, num_prompts, -1) + inputs = inputs.mean(dim=1, keepdim=True).expand_as(inputs) + inputs = inputs.reshape(*prompt.shape) + else: + inputs = inputs.mean(dim=0, keepdim=True).expand_as(inputs) + attention_output = torch.cat([prompt, inputs], dim=1) + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + residual = hidden_states + + hidden_states = self.layernorm_after(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.drop_path(hidden_states) + + outputs = (hidden_states,) + outputs + + return outputs + + +class SegGptEncoder(nn.Module): + def __init__(self, config: SegGptConfig) -> None: + super().__init__() + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")] + self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)]) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + feature_ensemble: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, SegGptEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + 
all_self_attentions = () if output_attentions else None + intermediate_hidden_states = [] + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Condition to check if we have the appropriate number of prompts to ensemble + ensemble_cond = 2 if self.config.merge_index > i else 1 + + layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) + + hidden_states = layer_outputs[0] + + if i == self.config.merge_index: + hidden_states = ( + hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :] + ) * 0.5 + + if i in self.config.intermediate_hidden_state_indices: + intermediate_hidden_states.append(self.layernorm(hidden_states)) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, intermediate_hidden_states] + if v is not None + ) + return SegGptEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + intermediate_hidden_states=intermediate_hidden_states, + ) + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SegGpt +class SegGptLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class SegGptDecoderHead(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv2d( + config.decoder_hidden_size, + config.decoder_hidden_size, + kernel_size=3, + padding=1, + ) + self.layernorm = SegGptLayerNorm( + normalized_shape=config.decoder_hidden_size, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.act_fct = ACT2FN[config.hidden_act] + self.head = nn.Conv2d(config.decoder_hidden_size, 3, kernel_size=1, bias=True) # decoder to patch + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.conv(hidden_states) + hidden_states = self.layernorm(hidden_states) + hidden_states = self.act_fct(hidden_states) + hidden_states = self.head(hidden_states) + + return hidden_states + + +class SegGptDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.decoder_embed = nn.Linear( + config.hidden_size * len(config.intermediate_hidden_state_indices), + config.patch_size**2 * config.decoder_hidden_size, + bias=True, + ) + self.decoder_pred = SegGptDecoderHead(config) + self.patch_size = config.patch_size + self.decoder_hidden_size = config.decoder_hidden_size + self.config = config + + def _reshape_hidden_states(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + batch_size, patch_height, patch_width, _ = hidden_states.shape + hidden_states = hidden_states.reshape( + batch_size, patch_height, patch_width, self.patch_size, self.patch_size, self.decoder_hidden_size + ) + hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) + hidden_states = hidden_states.reshape( + shape=(batch_size, -1, patch_height * self.patch_size, patch_width * self.patch_size) + ) + + return hidden_states + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.decoder_embed(hidden_states) + hidden_states = self._reshape_hidden_states(hidden_states) + hidden_states = self.decoder_pred(hidden_states) + + return hidden_states + + +@auto_docstring +class SegGptPreTrainedModel(PreTrainedModel): + config: SegGptConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["SegGptEmbeddings", "SegGptLayer"] + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=std).to( + module.weight.dtype + ) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, SegGptLayerNorm)): + module.bias.data.zero_() + 
module.weight.data.fill_(1.0)
+        elif isinstance(module, SegGptAttention):
+            module.rel_pos_h.data = nn.init.trunc_normal_(
+                module.rel_pos_h.data.to(torch.float32),
+                mean=0.0,
+                std=std,
+            ).to(module.rel_pos_h.dtype)
+
+            module.rel_pos_w.data = nn.init.trunc_normal_(
+                module.rel_pos_w.data.to(torch.float32),
+                mean=0.0,
+                std=std,
+            ).to(module.rel_pos_w.dtype)
+
+        elif isinstance(module, SegGptEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=std,
+            ).to(module.position_embeddings.dtype)
+
+            torch.nn.init.normal_(module.mask_token, std=std)
+            torch.nn.init.normal_(module.segment_token_input, std=std)
+            torch.nn.init.normal_(module.segment_token_prompt, std=std)
+            torch.nn.init.normal_(module.type_token_semantic, std=std)
+            torch.nn.init.normal_(module.type_token_instance, std=std)
+
+
+@auto_docstring
+class SegGptModel(SegGptPreTrainedModel):
+    def __init__(self, config: SegGptConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = SegGptEmbeddings(config)
+        self.encoder = SegGptEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> SegGptPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layers[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        prompt_pixel_values: torch.Tensor,
+        prompt_masks: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        feature_ensemble: Optional[bool] = None,
+        embedding_type: Optional[str] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SegGptEncoderOutput]:
+        r"""
+        prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`SegGptImageProcessor.__call__`] for details.
+        prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
+            details.
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        feature_ensemble (`bool`, *optional*):
+            Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
+            if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
+            be considered when doing few-shot inference on an input image, i.e. more than one prompt for the same image.
+        embedding_type (`str`, *optional*):
+            Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
+            `"instance"` or `"semantic"`.
+        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
+            Ground truth mask for input images.
+
+        Examples:
+
+        ```python
+        >>> from transformers import SegGptImageProcessor, SegGptModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
+        >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
+        >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
+
+        >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
+        >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
+        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
+
+        >>> checkpoint = "BAAI/seggpt-vit-large"
+        >>> model = SegGptModel.from_pretrained(checkpoint)
+        >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
+
+        >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> list(outputs.last_hidden_state.shape)
+        [1, 56, 28, 1024]
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        feature_ensemble = feature_ensemble if feature_ensemble is not None else False
+
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        pixel_values = pixel_values.to(expected_dtype)
+        prompt_pixel_values = prompt_pixel_values.to(expected_dtype)
+
+        # Prepare inputs
+        pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
+        prompt_pixel_values = (
+            torch.cat((prompt_masks, prompt_masks), dim=2)
+            if labels is None
+            else torch.cat((prompt_masks, labels), dim=2)
+        )
+
+        if bool_masked_pos is None and labels is not None:
+            logger.warning_once(
+                "Labels were provided, but bool_masked_pos was not. It will be set to the default value. If you're training the model, make sure to provide bool_masked_pos."
+            )
+
+        # We concatenate on the height axis so SegGPT can process the pair as a single image; hence we need to mask
+        # the portion of the mask prompt pixels that is destined for the prediction, as it doesn't add any information.
+        # This is only the case for inference. In training, the concatenation of prompt mask and label is masked
+        # and reconstructed together (In-Context Painting).
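+        # e.g. with the default 56 * 28 = 1568 patches this keeps the first 784 patches (the top,
+        # prompt half) visible and masks the bottom 784 patches that the model has to predict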
+        if bool_masked_pos is None:
+            num_patches = self.embeddings.patch_embeddings.num_patches
+            bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
+            bool_masked_pos_ones = torch.ones(
+                num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
+            )
+            bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
+            bool_masked_pos = bool_masked_pos.unsqueeze(0)
+
+        embedding_output = self.embeddings(
+            pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            feature_ensemble=feature_ensemble,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        return encoder_outputs
+
+
+def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
+    batch_size, num_channels, height, width = tensor.shape
+    patch_height = height // patch_size
+    patch_width = width // patch_size
+
+    tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
+    tensor = tensor.permute(0, 2, 4, 3, 5, 1)
+    tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))
+
+    return tensor
+
+
+def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
+    batch_size = tensor.shape[0]
+    patch_size = int((tensor.shape[-1] / 3) ** 0.5)
+    if patch_height * patch_width != tensor.shape[1]:
+        raise ValueError(
+            f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
+        )
+
+    tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
+    tensor = tensor.permute(0, 5, 1, 3, 2, 4)
+    tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))
+
+    return tensor
+
+
+class SegGptLoss(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.beta = config.beta
+        self.patch_size = config.patch_size
+
+    def forward(
+        self,
+        prompt_masks: torch.FloatTensor,
+        pred_masks: torch.FloatTensor,
+        labels: torch.FloatTensor,
+        bool_masked_pos: torch.BoolTensor,
+    ):
+        """Computes the smooth L1 loss between the predicted masks and the ground truth masks.
+
+        Args:
+            prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Pixel values from mask prompt.
+
+            pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
+                Predicted masks.
+
+            labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Ground truth mask for input images.
+
+            bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+                Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+        Returns:
+            `torch.FloatTensor`: The mean smooth L1 loss between the predicted masks and the ground truth masks,
+            computed over the masked patches only.
+        """
+        ground_truth = torch.cat((prompt_masks, labels), dim=2)
+
+        mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
+        mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
+
+        loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+
+        return loss
+
+
+@auto_docstring(
+    custom_intro="""
+    SegGpt model with a decoder on top for one-shot image segmentation.
+ """ +) +class SegGptForImageSegmentation(SegGptPreTrainedModel): + def __init__(self, config: SegGptConfig): + super().__init__(config) + self.config = config + + self.model = SegGptModel(config) + self.decoder = SegGptDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + prompt_pixel_values: torch.Tensor, + prompt_masks: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + feature_ensemble: Optional[bool] = None, + embedding_type: Optional[str] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SegGptImageSegmentationOutput]: + r""" + prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See + [`SegGptImageProcessor.__call__`] for details. + prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for + details. + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + feature_ensemble (`bool`, *optional*): + Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble + if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should + be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image. + embedding_type (`str`, *optional*): + Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either + instance or semantic. + labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`): + Ground truth mask for input images. 
+ + Examples: + + ```python + >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation + >>> from PIL import Image + >>> import requests + + >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg" + >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg" + >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png" + + >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw) + >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw) + >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L") + + >>> checkpoint = "BAAI/seggpt-vit-large" + >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint) + >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint) + + >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt") + >>> outputs = model(**inputs) + >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0] + >>> print(list(result.shape)) + [170, 297] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is None: + num_patches = self.model.embeddings.patch_embeddings.num_patches + bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device) + bool_masked_pos_ones = torch.ones( + num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device + ) + bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones]) + bool_masked_pos = bool_masked_pos.unsqueeze(0) + + outputs = self.model( + pixel_values=pixel_values, + prompt_pixel_values=prompt_pixel_values, + prompt_masks=prompt_masks, + bool_masked_pos=bool_masked_pos, + feature_ensemble=feature_ensemble, + embedding_type=embedding_type, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1] + intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1) + pred_masks = self.decoder(intermediate_hidden_states) + + loss = None + if labels is not None: + loss_fn = SegGptLoss(self.config) + loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos) + + if not return_dict: + output = (pred_masks,) + if output_hidden_states: + output = output + (outputs[1],) + + if output_attentions: + idx = 2 if output_hidden_states else 1 + output = output + (outputs[idx],) + + if loss is not None: + output = (loss,) + output + return output + + return SegGptImageSegmentationOutput( + loss=loss, + pred_masks=pred_masks, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/sew_d/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/sew_d/__init__.py new file mode 
100644
index 0000000000000000000000000000000000000000..7902aa464a9cfd5b465c2448014e218f4540c9d3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/sew_d/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_sew_d import *
+    from .modeling_sew_d import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/sew_d/configuration_sew_d.py b/venv/lib/python3.13/site-packages/transformers/models/sew_d/configuration_sew_d.py
new file mode 100644
index 0000000000000000000000000000000000000000..2809fa0fb6ba1bc2accc9afb3f0de702760ef602
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/sew_d/configuration_sew_d.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SEW-D model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SEWDConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SEWDModel`]. It is used to instantiate a SEW-D
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the SEW-D
+    [asapp/sew-d-tiny-100k](https://huggingface.co/asapp/sew-d-tiny-100k) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the SEW-D model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`SEWDModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        squeeze_factor (`int`, *optional*, defaults to 2):
+            Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        position_buckets (`int`, *optional*, defaults to 256):
+            The maximum size of relative position embeddings.
+        share_att_key (`bool`, *optional*, defaults to `True`):
+            Whether to share attention key with c2p and p2c.
+        relative_attention (`bool`, *optional*, defaults to `True`):
+            Whether to use relative position encoding.
+        pos_att_type (`tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
+            The type of relative position attention. It can be any combination of `"p2c"` and `"c2p"`, e.g.
+            `("p2c",)`, `("c2p",)` or `("p2c", "c2p")`.
+        norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
+            Whether to use layer norm in the relative embedding (`"layer_norm"` if yes).
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_python"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            Deprecated. Not used by the model and will be removed in a future version.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`SEWDForCTC`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-7):
+            The epsilon used by the layer normalization layers in the transformer encoder.
+        feature_layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization after the feature encoder.
+        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+            The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group
+            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+            convolutional layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the output of the feature encoder.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+            Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+            embeddings layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of the 1D convolutional positional embeddings layer.
+        apply_spec_augment (`bool`, *optional*, defaults to `True`):
+            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+            Recognition](https://huggingface.co/papers/1904.08779).
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespective of `mask_time_prob`. Only relevant if
+            `mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks`.
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
+            the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
+            span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+            may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+            True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespective of `mask_feature_prob`.
Only relevant if
+            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`SEWDForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`SEWDForCTC`].
+        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of [`SEWDForSequenceClassification`].
+        classifier_proj_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the projection before token mean-pooling for classification.
+
+    Example:
+
+    ```python
+    >>> from transformers import SEWDConfig, SEWDModel
+
+    >>> # Initializing a SEW-D asapp/sew-d-tiny-100k style configuration
+    >>> configuration = SEWDConfig()
+
+    >>> # Initializing a model (with random weights) from the asapp/sew-d-tiny-100k style configuration
+    >>> model = SEWDModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "sew-d"
+
+    def __init__(
+        self,
+        vocab_size=32,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        squeeze_factor=2,
+        max_position_embeddings=512,
+        position_buckets=256,
+        share_att_key=True,
+        relative_attention=True,
+        pos_att_type=("p2c", "c2p"),
+        norm_rel_ebd="layer_norm",
+        hidden_act="gelu_python",
+        hidden_dropout=0.1,
+        activation_dropout=0.1,
+        attention_dropout=0.1,
+        feat_proj_dropout=0.0,
+        final_dropout=0.1,
+        initializer_range=0.02,
+        layer_norm_eps=1e-7,
+        feature_layer_norm_eps=1e-5,
+        feat_extract_norm="group",
+        feat_extract_activation="gelu",
+        conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
+        conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
+        conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
+        conv_bias=False,
+        num_conv_pos_embeddings=128,
+        num_conv_pos_embedding_groups=16,
+        apply_spec_augment=True,
+        mask_time_prob=0.05,
+        mask_time_length=10,
+        mask_time_min_masks=2,
+        mask_feature_prob=0.0,
+        mask_feature_length=10,
+        mask_feature_min_masks=0,
+        ctc_loss_reduction="mean",
+        ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=256,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+        self.hidden_size = hidden_size
+        self.feat_extract_norm = feat_extract_norm
+        self.feat_extract_activation = feat_extract_activation
+        self.conv_dim = list(conv_dim)
+        self.conv_stride = list(conv_stride)
+        self.conv_kernel = list(conv_kernel)
+        self.conv_bias = conv_bias
+        self.num_conv_pos_embeddings = num_conv_pos_embeddings
+        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.num_feat_extract_layers = len(self.conv_dim)
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
+        self.squeeze_factor = squeeze_factor
+        self.max_position_embeddings = max_position_embeddings
+        self.position_buckets = position_buckets
+        self.share_att_key = share_att_key
+        self.relative_attention = relative_attention
+        self.norm_rel_ebd = norm_rel_ebd
+        self.pos_att_type = list(pos_att_type)
+        self.hidden_act = hidden_act
+        self.num_attention_heads = num_attention_heads
+        self._hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.feat_proj_dropout = feat_proj_dropout
+        self.final_dropout = final_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.feature_layer_norm_eps = feature_layer_norm_eps
+        self.initializer_range = initializer_range
+        self.vocab_size = vocab_size
+
+        if (
+            (len(self.conv_stride) != self.num_feat_extract_layers)
+            or (len(self.conv_kernel) != self.num_feat_extract_layers)
+            or (len(self.conv_dim) != self.num_feat_extract_layers)
+        ):
+            raise ValueError(
+                "Configuration for convolutional layers is incorrect. "
+                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
+                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
+                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+            )
+
+        # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
+        self.apply_spec_augment = apply_spec_augment
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks
+
+        # ctc loss
+        self.ctc_loss_reduction = ctc_loss_reduction
+        self.ctc_zero_infinity = ctc_zero_infinity
+
+        # sequence classification
+        self.use_weighted_layer_sum = use_weighted_layer_sum
+        self.classifier_proj_size = classifier_proj_size
+
+    @property
+    def inputs_to_logits_ratio(self):
+        return functools.reduce(operator.mul, self.conv_stride, 1)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary.
+        """
+        output = super().to_dict()
+        output["hidden_dropout"] = output.pop("_hidden_dropout")
+        return output
+
+
+__all__ = ["SEWDConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/sew_d/modeling_sew_d.py b/venv/lib/python3.13/site-packages/transformers/models/sew_d/modeling_sew_d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0458e8764cb4560d38e207017c6dc9c26849e635
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/sew_d/modeling_sew_d.py
@@ -0,0 +1,1655 @@
+# coding=utf-8
+# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SEW model.""" + +import math +import warnings +from collections.abc import Sequence +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import softmax_backward_data +from ...utils import auto_docstring, logging +from .configuration_sew_d import SEWDConfig + + +logger = logging.get_logger(__name__) + +_HIDDEN_STATES_START_POSITION = 1 + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+            if len(spec_aug_mask_idx) == 0:
+                # this case can only happen if `input_length` is strictly smaller than
+                # `sequence_length` in which case the last token has to be a padding
+                # token which we can use as a dummy mask id
+                dummy_mask_idx = sequence_length - 1
+            else:
+                dummy_mask_idx = spec_aug_mask_idx[0]
+
+            spec_aug_mask_idx = np.concatenate(
+                [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+            )
+            spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
+    sign = torch.sign(relative_pos)
+    mid = bucket_size // 2
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
+    return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
+    """
+    Build relative positions according to the query and key.
+
+    We assume the absolute position of query \\(P_q\\) ranges from (0, query_size) and the absolute position of key
+    \\(P_k\\) ranges from (0, key_size). The relative position from query to key is \\(R_{q \\rightarrow k} = P_q -
+    P_k\\).
+
+    Args:
+        query_size (int): the length of query
+        key_size (int): the length of key
+        bucket_size (int): the size of position bucket
+        max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.
+ + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + """ + + q_ids = torch.arange(0, query_size, device=device) + k_ids = torch.arange(0, key_size, device=device) + rel_pos_ids = q_ids[:, None] - k_ids[None, :] + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = rel_pos_ids.to(torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD +class SEWDNoLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD +class SEWDLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return 
hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD +class SEWDGroupNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.sew.modeling_sew.SEWPositionalConvEmbedding with SEW->SEWD +class SEWDPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW +class SEWDSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.sew.modeling_sew.SEWUpsampling with SEW->SEWD +class SEWDUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * 
self.squeeze_factor + tgt_embed_dim = src_embed_dim // self.squeeze_factor + hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) + hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEWD +class SEWDFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SEWDGroupNormConvLayer(config, layer_id=0)] + [ + SEWDNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [SEWDLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class SEWDFeatureExtractor(SEWDFeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) + + +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. 
+ dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(ctx, input, mask, dim): + ctx.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, ctx.dim) + output.masked_fill_(rmask, 0) + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx, grad_output): + (output,) = ctx.saved_tensors + inputGrad = softmax_backward_data(ctx, grad_output, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Bool"], + ) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool))) + + +class DropoutContext: + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + from torch.onnx import symbolic_opset12 + + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. 
+ # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return symbolic_opset12.dropout(g, input, dropout_p, train) + + +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +class SEWDSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.activation_dropout) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. 
The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.activation_dropout) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_dropout) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.BoolTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, *optional*): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, *optional*): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. 
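A small worked example of the scaling described above (editor's illustration, not part of the upstream file; the head dimension is an arbitrary stand-in):

```python
import math

# Each enabled disentangled term ("c2p", "p2c") adds 1 to scale_factor, so the
# attention scores are divided by sqrt(head_dim * scale_factor) instead of the
# usual sqrt(head_dim).
head_dim = 64                      # illustrative value, not a SEW-D default
scale_factor = 1 + 1 + 1           # content-to-content + c2p + p2c
scale = math.sqrt(head_dim * scale_factor)
print(round(scale, 2))             # 13.86, versus 8.0 for plain sqrt(64)
```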
+ + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=query_layer.device, + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long) + + rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale.to(dtype=c2p_att.dtype) + + # position->content + if "p2c" in self.pos_att_type: + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=query_layer.device, + ) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale.to(dtype=p2c_att.dtype) + + return score + + +class SEWDAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = SEWDSelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->SEWD +class SEWDIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = 
ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SEWDOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.activation_dropout) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SEWDLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention = SEWDAttention(config) + self.intermediate = SEWDIntermediate(config) + self.output = SEWDOutput(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class SEWDTransformerEncoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([SEWDLayer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, 
"position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=hidden_states.device, + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = attention_mask.sum(-2) > 0 + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class SEWDEncoder(nn.Module): + def __init__(self, config): + 
super().__init__() + self.config = config + self.pos_conv_embed = SEWDPositionalConvEmbedding(config) + self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) + self.encoder = SEWDTransformerEncoder(config) + self.upsample = SEWDUpsampling(config) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor + if attention_mask is None: + attention_mask = torch.ones( + (hidden_states.shape[0], max_encoder_length), dtype=torch.long, device=hidden_states.device + ) + else: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask.bool()] = 0.0 + + input_lengths = (attention_mask.long()).sum(-1) + # apply pooling formula to get real output_lengths + output_lengths = input_lengths // self.config.squeeze_factor + attention_ids = ( + torch.arange(0, max_encoder_length, device=output_lengths.device) + .view(1, -1) + .expand(output_lengths.shape[0], -1) + ) + attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() + + n_input_timesteps = hidden_states.shape[1] + + hidden_states = hidden_states.transpose(1, 2) + position_embeddings = self.pos_conv_embed(hidden_states) + pooled_hidden_states = self.pool(hidden_states) + min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) + hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] + hidden_states = hidden_states.transpose(1, 2) + + encoder_outputs = self.encoder(hidden_states, attention_mask, output_hidden_states, output_attentions) + + hidden_states = self.upsample(encoder_outputs.last_hidden_state) + if hidden_states.shape[1] < n_input_timesteps: + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1])) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_outputs.hidden_states, encoder_outputs.attentions] if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class SEWDPreTrainedModel(PreTrainedModel): + config: SEWDConfig + base_model_prefix = "sew-d" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SEWDPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + 
else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +@auto_docstring +# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps +class SEWDModel(SEWDPreTrainedModel): + def __init__(self, config: SEWDConfig): + super().__init__(config) + self.config = config + self.feature_extractor = SEWDFeatureEncoder(config) + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps) + + self.project_features = config.conv_dim[-1] != config.hidden_size + if self.project_features: + self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.feature_dropout = nn.Dropout(config.feat_proj_dropout) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = SEWDEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://huggingface.co/papers/1904.08779). 
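As a quick illustration of the output-length formula in `_get_feat_extract_output_lengths` above, here is an editor's sketch (not upstream code); the kernel and stride values are illustrative stand-ins for `config.conv_kernel` and `config.conv_stride`:

```python
# Apply the same floor-division formula layer by layer.
def conv_out_length(input_length, kernel_size, stride):
    return (input_length - kernel_size) // stride + 1

length = 16000  # one second of 16 kHz audio
for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, kernel, stride)
print(length)  # number of feature frames produced by the convolutional encoder
```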
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. 
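The time masking performed by `_mask_hidden_states` above can be pictured with a hand-built mask (editor's sketch, not upstream code; `_compute_mask_indices` is replaced by an explicit boolean tensor):

```python
import torch

hidden_states = torch.randn(2, 10, 4)                  # (batch, time, hidden)
masked_spec_embed = torch.zeros(4)                      # stands in for the learned embedding
mask_time_indices = torch.zeros(2, 10, dtype=torch.bool)
mask_time_indices[:, 3:6] = True                        # mask frames 3..5 of every example
hidden_states[mask_time_indices] = masked_spec_embed    # broadcast over the hidden dimension
```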
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = self.layer_norm(extract_features) + + if self.project_features: + extract_features = self.feature_projection(extract_features) + hidden_states = self.feature_dropout(extract_features) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). + """ +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD +class SEWDForCTC(SEWDPreTrainedModel): + def __init__(self, config, target_lang: Optional[str] = None): + r""" + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`SEWDForCTC`] with adapters. Uses 'eng' by + default. + """ + super().__init__(config) + + self.sew_d = SEWDModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + self.target_lang = target_lang + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `SEWDForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for SEWD so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. 
+ target_lang = self.target_lang + + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang, force_load=True) + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.sew_d.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.sew_d.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. 
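To see how `-100`-padded labels become CTC targets, here is an editor's sketch of the masking and loss call used in `forward` below (illustrative shapes; `blank` is hard-coded to 0 here, whereas the model uses `config.pad_token_id`):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 50, 32)                         # (batch, frames, vocab)
labels = torch.tensor([[5, 9, 9, -100], [7, -100, -100, -100]])

labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)                    # tensor([3, 1])
flattened_targets = labels.masked_select(labels_mask)   # tensor([5, 9, 9, 7])

log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
input_lengths = torch.tensor([50, 50])
loss = F.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths,
                  blank=0, reduction="sum", zero_infinity=True)
```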
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.sew_d( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@auto_docstring( + custom_intro=""" + SEWD Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB + Keyword Spotting. + """ +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD +class SEWDForSequenceClassification(SEWDPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of SEWD adapters (config.add_adapter=True)" + ) + self.sew_d = SEWDModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
+ """ + self.sew_d.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.sew_d.parameters(): + param.requires_grad = False + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion + into a tensor of type `torch.FloatTensor`. See [`SEWDProcessor.__call__`] for details. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.sew_d( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SEWDForCTC", "SEWDForSequenceClassification", "SEWDModel", "SEWDPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..a5861f45f45b06bc01bbaa48f96fcc4f638fc590 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_siglip import * + from .image_processing_siglip import * + from .image_processing_siglip_fast import * + from .modeling_siglip import * + from .processing_siglip import * + from .tokenization_siglip import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/configuration_siglip.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/configuration_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..0c182014fa272bb53b9e597894daa104a864a2c9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/configuration_siglip.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Siglip model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. 
+ hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. + + Example: + + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + projection_size=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. 
It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. 
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import SiglipConfig, SiglipModel + + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + +__all__ = ["SiglipConfig", "SiglipTextConfig", "SiglipVisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffed5258de5a09d84d88b6800fba3e934495c7e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for SigLIP.""" + +from typing import Optional, Union + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class SiglipImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + height, width = size["height"], size["width"] + images = [ + resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["SiglipImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip_fast.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..de2426e6eb023fdf4ea9f28e1b14e0773e365917 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/image_processing_siglip_fast.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SigLIP.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + PILImageResampling, +) +from ...utils import auto_docstring + + +@auto_docstring +class SiglipImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + + +__all__ = ["SiglipImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/modeling_siglip.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/modeling_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..072ec9721f041ef05d06e8ac5c5ba037c86f1248 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/modeling_siglip.py @@ -0,0 +1,1095 @@ +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Siglip model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + filter_out_non_signature_kwargs, + torch_int, +) +from ...utils.generic import check_model_inputs +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
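+ For example (illustrative numbers), `trunc_normal_tf_(t, mean=1.0, std=2.0)` first fills `t` from a standard + normal truncated to `[-2.0, 2.0]` and then applies `t = t * 2.0 + 1.0`, so the final values lie in + `[mean + a * std, mean + b * std] = [-3.0, 5.0]` rather than in `[a, b]`.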
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. 
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights 
= torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, 
eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class SiglipPreTrainedModel(PreTrainedModel): + config: SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": SiglipEncoderLayer, + "attentions": SiglipAttention, + } + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, SiglipForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. 
+ + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + uses_flash_attention = "flash" in self.config._attn_implementation + if uses_flash_attention: + attention_mask = None + elif attention_mask is not None and not uses_flash_attention: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # The model uses the last token's hidden state, which may be padding. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +@auto_docstring( + custom_intro=""" + The text model from SigLIP without any head or projection on top. 
+ """ +) +class SiglipTextModel(SiglipPreTrainedModel): + config: SiglipTextConfig + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @auto_docstring + def forward( + self, + pixel_values, + interpolate_pos_encoding: Optional[bool] = False, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return 
hidden_state[:, 0] + + +@auto_docstring( + custom_intro=""" + The vision model from SigLIP without any head or projection on top. + """ +) +class SiglipVisionModel(SiglipPreTrainedModel): + config: SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + +@auto_docstring +class SiglipModel(SiglipPreTrainedModel): + config: SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. 
+ + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + pooled_output = text_outputs.pooler_output + + return pooled_output + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AutoModel + >>> from transformers.image_utils import load_image + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + pooled_output = vision_outputs.pooler_output + + return pooled_output + + # NOTE: SiglipModel uses Pretrained backbones, so we don't need to add `check_model_inputs` here + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> SiglipOutput: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@auto_docstring( + custom_intro=""" + SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. + """ +) +class SiglipForImageClassification(SiglipPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: SiglipConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + "SiglipForImageClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/processing_siglip.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/processing_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..f0284ae66670768ac399082c30eae997fbddf0cf --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/processing_siglip.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP. +""" + +from ...processing_utils import ProcessorMixin + + +class SiglipProcessor(ProcessorMixin): + r""" + Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. + + [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the + [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`]): + The image processor is a required input. + tokenizer ([`SiglipTokenizer`]): + The tokenizer is a required input. 
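+ + Example (a minimal sketch following the `SiglipModel` docstring example; the checkpoint name is illustrative): + + ```python + >>> from transformers import SiglipProcessor + >>> from PIL import Image + >>> import requests + + >>> processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # important: pad to max_length, as the model was trained with this setting + >>> inputs = processor(text=["a photo of 2 cats"], images=image, padding="max_length", return_tensors="pt") + >>> # `inputs` now holds `input_ids` for the text and `pixel_values` for the image + ```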
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + +__all__ = ["SiglipProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip/tokenization_siglip.py b/venv/lib/python3.13/site-packages/transformers/models/siglip/tokenization_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..7a82a7a411ab3e548d055c22e8bbe7d0bc2b3dd3 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip/tokenization_siglip.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for SigLIP model.""" + +import os +import re +import string +import warnings +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Optional + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import AddedToken + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput +from ...utils import logging, requires_backends +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class SiglipTokenizer(PreTrainedTokenizer): + """ + Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`list[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. 
+ - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + model_max_length (`int`, *optional*, defaults to 64): + The maximum length (in number of tokens) for model inputs. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="</s>", + unk_token="<unk>", + pad_token="</s>", + additional_special_tokens=None, + sp_model_kwargs: Optional[dict[str, Any]] = None, + model_max_length=64, + do_lower_case=True, + **kwargs, + ) -> None: + requires_backends(self, "protobuf") + + pad_token = ( + AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(pad_token, str) + else pad_token + ) + unk_token = ( + AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(unk_token, str) + else unk_token + ) + eos_token = ( + AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(eos_token, str) + else eos_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor() + self.vocab_file = vocab_file + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + model_max_length=model_max_length, + do_lower_case=do_lower_case, + **kwargs, + ) + + def get_spm_processor(self): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf() + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size + def vocab_size(self): + return self.sp_model.get_piece_size() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model.
+ + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present + def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
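+
+ For example, `build_inputs_with_special_tokens([3, 4, 5])` returns `[3, 4, 5, eos_token_id]`; the eos id is not
+ appended again if the sequence already ends with it.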
+ """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def remove_punctuation(self, text: str) -> str: + return text.translate(str.maketrans("", "", string.punctuation)) + + # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + def canonicalize_text(self, text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (puncuation removed). + + Args: + text (`str`): + String to be canonicalized. + keep_punctuation_exact_string (`str`, *optional*): + If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}' + (but will still remove '{' and '}' that appear separately). + """ + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string) + ) + else: + text = self.remove_punctuation(text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + + return text + + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> list[str]: + """ + Converts a string to a list of tokens. + """ + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. + + For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`. + + Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + text = self.canonicalize_text(text, keep_punctuation_exact_string=None) + tokens = self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["SiglipTokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5b732b951367cd130d7edd3299eda74f9dc267 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_siglip2 import * + from .image_processing_siglip2 import * + from .image_processing_siglip2_fast import * + from .modeling_siglip2 import * + from .processing_siglip2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/configuration_siglip2.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/configuration_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..67ef9df8f4f8f03608d7ef5c09b269c4cc74433b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/configuration_siglip2.py @@ -0,0 +1,265 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Siglip2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a + Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Siglip2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. + + Example: + + ```python + >>> from transformers import Siglip2TextConfig, Siglip2TextModel + + >>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2TextConfig() + + >>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip2 + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + projection_size=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size + + +class Siglip2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a + Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2 + [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). + The image is resized to fill maximum of this number of patches, and to preserve + the aspect ratio. In case the resulted number of patches is lower, the image is + padded in "patch" dimension. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel + + >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration + >>> configuration = Siglip2VisionConfig() + + >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration + >>> model = Siglip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + + +class Siglip2Config(PretrainedConfig): + r""" + [`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to + instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2VisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import Siglip2Config, Siglip2Model + + >>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2Config() + + >>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig + >>> from transformers import Siglip2TextConfig, Siglip2VisionConfig + + >>> # Initializing a Siglip2Text and Siglip2Vision configuration + >>> config_text = Siglip2TextConfig() + >>> config_vision = Siglip2VisionConfig() + + >>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip2" + sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.") + + self.text_config = Siglip2TextConfig(**text_config) + self.vision_config = Siglip2VisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + +__all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..30b5f1b958af2dadd2c7dcc79619f418857086fc --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for SigLIP2.""" + +import math +from functools import lru_cache +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + from PIL import Image + + +@lru_cache(maxsize=256) +def get_image_size_for_max_num_patches( + image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 +) -> tuple[int, int]: + """ + Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. + + Args: + image_height (`int`): + Original image height. + image_width (`int`): + Original image width. + patch_size (`int`): + Patch size for processing. + max_num_patches (`int`): + Maximum number of patches. + eps (`float`): + Small threshold for binary search. + + Returns: + Tuple: (target_height, target_width) + """ + + def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int: + scaled_size = size * scale + scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size + scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch + return int(scaled_size) + + # Binary search for optimal scale + scale_min, scale_max = eps / 10, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + num_patches = (target_height / patch_size) * (target_width / patch_size) + + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + return target_height, target_width + + +def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray: + """ + Convert 3D array image of shape (image_height, image_width, num_channels) into 2D array of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + image_height, image_width, num_channels = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels) + patched_image = patched_image.transpose(0, 2, 1, 3, 4) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim(array: np.ndarray, target_length: int, pad_value: int = 0) -> tuple[np.ndarray, np.ndarray]: + """ + Pad the array along the first dimension. 
+ """ + current_length = array.shape[0] + padding_length = target_length - current_length + mask = np.ones((target_length,), dtype=np.int32) + if padding_length > 0: + paddings = [(0, padding_length)] + [(0, 0)] * (array.ndim - 1) + array = np.pad(array, paddings, mode="constant", constant_values=pad_value) + mask[-padding_length:] = 0 + return array, mask + + +class Siglip2ImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`. + Can be overridden by `do_resize` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to 256): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. 
+ """ + + model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"] + + def __init__( + self, + do_resize: bool = True, + resample: "PILImageResampling" = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: int = 16, + max_num_patches: int = 256, + **kwargs, + ): + super().__init__(**kwargs) + + image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] + image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.patch_size = patch_size + self.max_num_patches = max_num_patches + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + resample: Optional["PILImageResampling"] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + ) -> "Image.Image": + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + Patch size for processing, same as the patch size used in the model. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + Maximum number of patches per image, the image will be resized to have at most this number of patches. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + + # Explicitly specify data format to be channels last for image preprocessing. + # Image processor does not support different output formats, because it returns patches. + data_format = ChannelDimension.LAST + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[0], + image_width=image.shape[1], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + image = resize(image=image, size=(height, width), resample=resample, input_data_format=data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=data_format) + + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + num_patches_height = image.shape[0] // patch_size + num_patches_width = image.shape[1] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + + return batch_feature + + +__all__ = ["Siglip2ImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..45261fab2cd0e7195d05bc0d2a8dafc71b6d2e19 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SigLIP2.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, +) +from ...image_utils import ( + ImageInput, + PILImageResampling, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + logging, +) +from .image_processing_siglip2 import get_image_size_for_max_num_patches + + +logger = logging.get_logger(__name__) + + +def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). 
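+
+ For example, a `(3, 224, 224)` image with `patch_size=16` becomes a `(196, 768)` tensor
+ (`14 * 14` patches, each flattened to `16 * 16 * 3` values).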
+ """ + num_channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size) + patched_image = patched_image.permute(1, 3, 2, 4, 0) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim( + tensor: "torch.Tensor", target_length: int, pad_value: int = 0 +) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the tensor along the first dimension. + """ + current_length = tensor.shape[0] + padding_length = target_length - current_length + mask = torch.ones((target_length,), dtype=torch.int32) + if padding_length > 0: + padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length] + tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value) + mask[-padding_length:] = 0 + return tensor, mask + + +class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to 256): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. + """ + + patch_size: Optional[int] + max_num_patches: Optional[int] + + +@auto_docstring +class Siglip2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_resize = True + do_rescale = True + do_normalize = True + patch_size = 16 + max_num_patches = 256 + valid_kwargs = Siglip2FastImageProcessorKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size"] + + def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _validate_preprocess_kwargs(self, **kwargs) -> tuple: + # Remove do_resize from kwargs to not raise an error as size is None + kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + patch_size: int, + max_num_patches: int, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[1], + image_width=image.shape[2], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + side_dict = SizeDict(height=height, width=width) + image = self.resize(image=image, size=side_dict, interpolation=interpolation) + + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + + # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels) + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + + num_patches_height = image.shape[1] // 
patch_size + num_patches_width = image.shape[2] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + pixel_values = torch.stack(pixel_values) + pixel_masks = torch.stack(pixel_masks) + spatial_shapes = torch.tensor(spatial_shapes) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + return batch_feature + + +__all__ = ["Siglip2ImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/modeling_siglip2.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/modeling_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbaf32dff901861d944419039db0f3566dfa65a --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/modeling_siglip2.py @@ -0,0 +1,1195 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs +from ...utils.generic import check_model_inputs +from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +class Siglip2VisionOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +class Siglip2TextOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring +class Siglip2Output(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2VisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
+ + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`list[tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + # Get positional resized and padded positional embeddings + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] + ) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + resized_positional_embeddings + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() 
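+ # attn_output is now (batch, seq_len, num_heads, head_dim); the caller merges the last two
+ # dimensions back into embed_dim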
+ + return attn_output, attn_weights + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Siglip2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = 
self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Siglip2EncoderLayer`]. + + Args: + config: Siglip2Config + """ + + def __init__(self, config: Siglip2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class Siglip2VisionTransformer(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = Siglip2MultiheadAttentionPoolingHead(config) + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and self.config._attn_implementation != "flash_attention_2": + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    """
+    with torch.no_grad():
+        _trunc_normal_(tensor, 0, 1.0, a, b)
+        tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == "fan_in":
+        denom = fan_in
+    elif mode == "fan_out":
+        denom = fan_out
+    elif mode == "fan_avg":
+        denom = (fan_in + fan_out) / 2
+    else:
+        raise ValueError(f"invalid mode {mode}")
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+    elif distribution == "normal":
+        with torch.no_grad():
+            tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        with torch.no_grad():
+            tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@auto_docstring
+class Siglip2PreTrainedModel(PreTrainedModel):
+    config: Siglip2Config
+    base_model_prefix = "siglip2"
+    supports_gradient_checkpointing = True
+
+    _no_split_modules = [
+        "Siglip2TextEmbeddings",
+        "Siglip2VisionEmbeddings",
+        "Siglip2EncoderLayer",
+        "Siglip2MultiheadAttentionPoolingHead",
+    ]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": Siglip2EncoderLayer,
+        "attentions": Siglip2Attention,
+    }
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, Siglip2VisionEmbeddings):
+            width = (
+                self.config.vision_config.hidden_size
+                if isinstance(self.config, Siglip2Config)
+                else self.config.hidden_size
+            )
+            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+        elif isinstance(module, nn.Embedding):
+            default_flax_embed_init(module.weight)
+        elif isinstance(module, Siglip2Attention):
+            nn.init.xavier_uniform_(module.q_proj.weight)
+            nn.init.xavier_uniform_(module.k_proj.weight)
+            nn.init.xavier_uniform_(module.v_proj.weight)
+            nn.init.xavier_uniform_(module.out_proj.weight)
+            nn.init.zeros_(module.q_proj.bias)
+            nn.init.zeros_(module.k_proj.bias)
+            nn.init.zeros_(module.v_proj.bias)
+            nn.init.zeros_(module.out_proj.bias)
+        elif isinstance(module, Siglip2MLP):
+            nn.init.xavier_uniform_(module.fc1.weight)
+            nn.init.xavier_uniform_(module.fc2.weight)
+            nn.init.normal_(module.fc1.bias, std=1e-6)
+            nn.init.normal_(module.fc2.bias, std=1e-6)
+        elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
+            nn.init.xavier_uniform_(module.probe.data)
+            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+            nn.init.zeros_(module.attention.in_proj_bias.data)
+        elif isinstance(module, Siglip2Model):
+            logit_scale_init = torch.log(torch.tensor(1.0))
+            module.logit_scale.data.fill_(logit_scale_init)
+            module.logit_bias.data.zero_()
+        elif isinstance(module, Siglip2ForImageClassification):
+            nn.init.normal_(
+                module.classifier.weight,
+                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
+            )
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            lecun_normal_(module.weight)
+            if 
module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class Siglip2TextEmbeddings(nn.Module):
+    def __init__(self, config: Siglip2TextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        max_position_embedding = self.position_embedding.weight.shape[0]
+
+        if seq_length > max_position_embedding:
+            raise ValueError(
+                f"Sequence length must be less than or equal to `max_position_embeddings` (got `sequence_length`: "
+                f"{seq_length} and `max_position_embeddings`: {max_position_embedding})."
+            )
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class Siglip2TextTransformer(nn.Module):
+    def __init__(self, config: Siglip2TextConfig):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = Siglip2TextEmbeddings(config)
+        self.encoder = Siglip2Encoder(config)
+        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+        self.head = nn.Linear(embed_dim, config.projection_size)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPooling:
+        if input_ids is None:
+            raise ValueError("You have to specify input_ids")
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+
+        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+        # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
+        # expand attention_mask
+        uses_flash_attention = "flash" in self.config._attn_implementation
+        if uses_flash_attention:
+            attention_mask = None
+        elif attention_mask is not None:
+            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+        encoder_outputs: BaseModelOutput = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+        last_hidden_state = encoder_outputs.last_hidden_state
+        last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+        # The model uses the last token's hidden state, which may be padding.
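+        # Pooling sketch: with the recommended `padding="max_length"` tokenization, every sequence
+        # has the same length, so position -1 is the same final slot for all batch items:
+        # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim) via [:, -1, :], then
+        # projected to (batch_size, projection_size) by `self.head`.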
+ pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +@auto_docstring( + custom_intro=""" + The text model from Siglip2 without any head or projection on top. + """ +) +class Siglip2TextModel(Siglip2PreTrainedModel): + config: Siglip2TextConfig + + def __init__(self, config: Siglip2TextConfig): + super().__init__(config) + self.text_model = Siglip2TextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from transformers import AutoTokenizer, Siglip2TextModel + + >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + +class Siglip2MultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@auto_docstring( + custom_intro=""" + The vision model from Siglip2 without any head or projection on top. 
+ """ +) +class Siglip2VisionModel(Siglip2PreTrainedModel): + config: Siglip2VisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + + self.vision_model = Siglip2VisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +@auto_docstring +class Siglip2Model(Siglip2PreTrainedModel): + config: Siglip2Config + + def __init__(self, config: Siglip2Config): + super().__init__(config) + + if not isinstance(config.text_config, Siglip2TextConfig): + raise TypeError( + "config.text_config is expected to be of type Siglip2TextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, Siglip2VisionConfig): + raise TypeError( + "config.vision_config is expected to be of type Siglip2VisionConfig but is of type" + f" {type(config.vision_config)}." 
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        # First, initialize the text and vision models with proper attention implementation
+        text_model = Siglip2TextModel._from_config(text_config)
+        vision_model = Siglip2VisionModel._from_config(vision_config)
+
+        # Second, get the text and vision submodules (for backward compatibility)
+        self.text_model = text_model.text_model
+        self.vision_model = vision_model.vision_model
+
+        self.logit_scale = nn.Parameter(torch.randn(1))
+        self.logit_bias = nn.Parameter(torch.randn(1))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_text_features(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`Siglip2TextModel`].
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModel
+        >>> import torch
+
+        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
+
+        >>> # important: make sure to set padding="max_length" as that's how the model was trained
+        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     text_features = model.get_text_features(**inputs)
+        ```"""
+        text_outputs: BaseModelOutputWithPooling = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
+        pooled_output = text_outputs.pooler_output
+
+        return pooled_output
+
+    @filter_out_non_signature_kwargs()
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
+        spatial_shapes: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+            Mask to avoid performing attention on padding pixel indices.
+        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            Tensor containing the spatial dimensions (height, width) of the input images.
+
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+            applying the projection layer to the pooled output of [`Siglip2VisionModel`].
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoProcessor, AutoModel
+        >>> from transformers.image_utils import load_image
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = load_image(url)
+
+        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     
image_features = model.get_image_features(**inputs) + ``` + """ + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + ) + pooled_output = vision_outputs.pooler_output + + return pooled_output + + # NOTE: Siglip2Model uses Pretrained backbones, so we don't need to add `check_model_inputs` here + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Siglip2Output: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ``` + """ + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. 
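+        # Loss note (informal sketch of the code below): with `return_loss=True`, the SigLIP
+        # sigmoid loss scores every text/image pair in the batch independently, with +1 targets
+        # on the diagonal and -1 off it:
+        #     targets = 2 * eye(n) - 1
+        #     loss = -mean_i( sum_j( log_sigmoid(targets[i, j] * logits_per_text[i, j]) ) )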
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@auto_docstring( + custom_intro=""" + Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. + """ +) +class Siglip2ForImageClassification(Siglip2PreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = Siglip2VisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> ImageClassifierOutput: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. 
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `Siglip2Model` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224") + >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/modular_siglip2.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/modular_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..260a82e5143e157301f4d590270504e6205ae9ea --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/modular_siglip2.py @@ -0,0 +1,605 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from transformers.models.siglip.modeling_siglip import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+    SiglipForImageClassification,
+    SiglipModel,
+    SiglipMultiheadAttentionPoolingHead,
+    SiglipOutput,
+    SiglipPreTrainedModel,
+    SiglipTextModel,
+    SiglipTextModelOutput,
+    SiglipVisionModel,
+    SiglipVisionModelOutput,
+    SiglipVisionTransformer,
+)
+
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...utils import auto_docstring, filter_out_non_signature_kwargs
+
+
+class Siglip2TextConfig(SiglipTextConfig):
+    pass
+
+
+class Siglip2VisionConfig(SiglipVisionConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
+    Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
+    [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
+        num_patches (`int`, *optional*, defaults to 256):
+            The maximum number of patches of size (`patch_size`, `patch_size`) per image. Images are resized to
+            use at most this many patches while preserving their aspect ratio; if an image yields fewer patches,
+            the sequence is padded along the patch dimension.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    Example:
+
+    ```python
+    >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
+
+    >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
+    >>> configuration = Siglip2VisionConfig()
+
+    >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
+    >>> model = Siglip2VisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        num_channels=3,
+        num_patches=256,
+        patch_size=16,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        # Forward the explicit arguments to the parent config; passing only `**kwargs` here
+        # would silently drop them and fall back to the SiglipVisionConfig defaults.
+        super().__init__(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_channels=num_channels,
+            patch_size=patch_size,
+            hidden_act=hidden_act,
+            layer_norm_eps=layer_norm_eps,
+            attention_dropout=attention_dropout,
+            **kwargs,
+        )
+        self.num_patches = num_patches
+        del self.image_size
+
+
+class Siglip2Config(SiglipConfig):
+    pass
+
+
+class Siglip2VisionOutput(SiglipVisionModelOutput):
+    pass
+
+
+class Siglip2TextOutput(SiglipTextModelOutput):
+    pass
+
+
+class Siglip2Output(SiglipOutput):
+    pass
+
+
+class Siglip2VisionEmbeddings(nn.Module):
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Linear(
+            in_features=config.num_channels * self.patch_size * self.patch_size,
+            out_features=self.embed_dim,
+        )
+
+        self.num_patches = config.num_patches
+        self.position_embedding_size = int(self.num_patches**0.5)
+        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+    @staticmethod
+    def resize_positional_embeddings(
+        positional_embeddings: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        max_length: int,
+    ) -> torch.Tensor:
+        """
+        Resize positional embeddings to image-specific size and pad to a fixed size.
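+
+        This is what lets the NaFlex variant handle native resolutions: the learned square grid of
+        position embeddings (`position_embedding_size` x `position_embedding_size`, i.e. 16x16 for
+        the default 256 patches) is bilinearly interpolated to each image's own patch grid, then
+        flattened and padded up to `max_length`.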
+
+        Args:
+            positional_embeddings (`torch.Tensor`):
+                Position embeddings of shape (height, width, embed_dim)
+            spatial_shapes (`torch.LongTensor`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+            max_length (`int`):
+                Maximum length of the positional embeddings to pad resized positional embeddings to
+
+        Returns:
+            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+        """
+        batch_size = spatial_shapes.shape[0]
+        embed_dim = positional_embeddings.shape[-1]
+        source_dtype = positional_embeddings.dtype
+
+        resulted_positional_embeddings = torch.empty(
+            (batch_size, max_length, embed_dim),
+            device=positional_embeddings.device,
+            dtype=source_dtype,
+        )
+
+        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+        if positional_embeddings.device.type == "cpu":
+            positional_embeddings = positional_embeddings.to(torch.float32)
+
+        for i in range(batch_size):
+            # (1, dim, height, width) -> (1, dim, target_height, target_width)
+            height, width = spatial_shapes[i]
+            resized_embeddings = F.interpolate(
+                positional_embeddings,
+                size=(height, width),
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+
+            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+            # Cast to original dtype
+            resized_embeddings = resized_embeddings.to(source_dtype)
+
+            resulted_positional_embeddings[i, : height * width] = resized_embeddings
+            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+        return resulted_positional_embeddings
+
+    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
+        """
+        Args:
+            pixel_values (`torch.FloatTensor`):
+                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+            spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+                Spatial (height, width) patch-grid shapes used to resize the positional embeddings
+        """
+
+        # Apply patch embeddings to already patchified pixel values
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+        # Get resized and padded positional embeddings
+        positional_embeddings = self.position_embedding.weight.reshape(
+            self.position_embedding_size, self.position_embedding_size, -1
+        )
+        resized_positional_embeddings = self.resize_positional_embeddings(
+            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+        )
+
+        # Add positional embeddings to patch embeddings
+        embeddings = patch_embeds + resized_positional_embeddings
+        return embeddings
+
+
+class Siglip2VisionTransformer(SiglipVisionTransformer):
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__(config)
+
+    # Update: add `spatial_shapes` and `attention_mask`
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        attention_mask: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            Tensor containing the spatial dimensions (height, width) of the input images.
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and self.config._attn_implementation != "flash_attention_2": + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Siglip2PreTrainedModel(SiglipPreTrainedModel): + pass + + +class Siglip2TextModel(SiglipTextModel): + pass + + +class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Siglip2VisionModel(SiglipVisionModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. 
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class Siglip2Model(SiglipModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + @filter_out_non_signature_kwargs() + @auto_docstring + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AutoModel + >>> from transformers.image_utils import load_image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ``` + """ + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + ) + pooled_output = vision_outputs.pooler_output + + return pooled_output + + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Siglip2Output: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. 
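+
+        Note: unlike CLIP, the similarity logits are trained for an elementwise sigmoid rather than
+        a softmax over the batch, which is why the example below applies `torch.sigmoid` and the
+        per-pair probabilities need not sum to 1 across candidate texts.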
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ``` + """ + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class Siglip2ForImageClassification(SiglipForImageClassification): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + ) -> ImageClassifierOutput: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `Siglip2Model` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224") + >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Config", + "Siglip2TextConfig", + "Siglip2VisionConfig", + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/siglip2/processing_siglip2.py b/venv/lib/python3.13/site-packages/transformers/models/siglip2/processing_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..8e177b237b10633cc6181f07d54d24fbd1600d5c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/siglip2/processing_siglip2.py @@ -0,0 +1,69 
@@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP2. +""" + +from typing import Optional + +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin + + +class Siglip2ImagesKwargs(ImagesKwargs, total=False): + max_num_patches: Optional[int] + patch_size: Optional[int] + + +class Siglip2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Siglip2ImagesKwargs + + _defaults = { + "text_kwargs": { + "padding": "max_length", + "truncation": True, + "max_length": 64, + }, + "images_kwargs": { + "max_num_patches": 256, + "patch_size": 16, + }, + } + + +class Siglip2Processor(ProcessorMixin): + r""" + Constructs a Siglip2 processor which wraps a Siglip2 image processor and a Gemma tokenizer into a single processor. + + [`Siglip2Processor`] offers all the functionalities of [`Siglip2ImageProcessor`] and [`GemmaTokenizerFast`]. See the + [`~Siglip2Processor.__call__`] and [`~Siglip2Processor.decode`] for more information. + + Args: + image_processor ([`Siglip2ImageProcessor`]): + The image processor is a required input. + tokenizer ([`GemmaTokenizerFast`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + valid_processor_kwargs = Siglip2ProcessorKwargs + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + +__all__ = ["Siglip2Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e07844d45c2d0f3c5bfba69b330c3beede0eab1 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
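+# Lazy-import shim: static type checkers see the real re-exports below, while at runtime
+# `_LazyModule` defers importing the submodules until an attribute is first accessed.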
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speech_encoder_decoder import * + from .modeling_flax_speech_encoder_decoder import * + from .modeling_speech_encoder_decoder import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..63e161fcbe74775fcfd8ab6dee18f664ca694194 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class SpeechEncoderDecoderConfig(PretrainedConfig): + r""" + [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a + [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified + arguments, defining the encoder and decoder configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + + - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the encoder config. + - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the decoder config. 
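+
+    Both sub-configurations are required at instantiation time (`has_no_defaults_at_init = True`);
+    the usual entry point is the `from_encoder_decoder_configs` classmethod shown in the example
+    below.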
+
+    Examples:
+
+    ```python
+    >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
+
+    >>> # Initializing a Wav2Vec2 & BERT style configuration
+    >>> config_encoder = Wav2Vec2Config()
+    >>> config_decoder = BertConfig()
+
+    >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+    >>> # Initializing a Wav2Vec2 & google-bert/bert-base-uncased style model from their configurations
+    >>> model = SpeechEncoderDecoderModel(config=config)
+
+    >>> # Accessing the model configuration
+    >>> config_encoder = model.config.encoder
+    >>> config_decoder = model.config.decoder
+    >>> # set decoder config to causal lm
+    >>> config_decoder.is_decoder = True
+    >>> config_decoder.add_cross_attention = True
+
+    >>> # Saving the model, including its configuration
+    >>> model.save_pretrained("my-model")
+
+    >>> # loading model and config from pretrained folder
+    >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
+    >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+    ```"""
+
+    model_type = "speech-encoder-decoder"
+    sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
+    has_no_defaults_at_init = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError(
+                f"A configuration of type {self.model_type} requires both an `encoder` and a `decoder`"
+                f" sub-configuration, but received only {kwargs}."
+            )
+
+        encoder_config = kwargs.pop("encoder")
+        encoder_model_type = encoder_config.pop("model_type")
+        decoder_config = kwargs.pop("decoder")
+        decoder_model_type = decoder_config.pop("model_type")
+
+        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
+        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_encoder_decoder_configs(
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+    ) -> PretrainedConfig:
+        r"""
+        Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
+        configuration and decoder model configuration.
+
+        Returns:
+            [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
+        """
+        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+        decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
+
+        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["SpeechEncoderDecoderConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3614c5d4981b4e010f9ba7732ad537865357046d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -0,0 +1,930 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Flax Speech-Encoder-Decoder architectures""" + +import os +from typing import Optional, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput +from ...modeling_flax_utils import FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM +from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" + +SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech + autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is + loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via + [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder + and should be fine-tuned on a downstream generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech + Translation](https://huggingface.co/papers/2104.06678) it is shown how leveraging large pretrained speech models for speech + translation yields a significant performance improvement. + + After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other + models (see the examples for more information). + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Parameters: + config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. 
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): + Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac` + or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* + via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or + [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type + `torch.FloatTensor`. + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. +""" + +SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): + Float values of input raw speech waveform or speech features. 
Values can be obtained by loading a *.flac* + or *.wav* audio file into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should be used + for padding and conversion into a tensor of type *torch.FloatTensor*. + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. +""" + +SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. 
Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a + plain tuple. +""" + + +class FlaxSpeechEncoderDecoderModule(nn.Module): + config: SpeechEncoderDecoderConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + encoder_config = self.config.encoder + decoder_config = self.config.decoder + + # Copied from `modeling_hybrid_clip.py` with modifications. + from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING + + encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class + decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class + + self.encoder = encoder_module(encoder_config, dtype=self.dtype) + self.decoder = decoder_module(decoder_config, dtype=self.dtype) + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Dense( + self.decoder.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), + dtype=self.dtype, + ) + else: + self.enc_to_dec_proj = None + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.encoder.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride) + + return input_lengths + + def _get_encoder_module(self): + return self.encoder + + def _get_projection_module(self): + return self.enc_to_dec_proj + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + inputs, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_outputs=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + freeze_feature_encoder: bool = False, + ): + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + freeze_feature_encoder=freeze_feature_encoder, + ) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if self.enc_to_dec_proj is not None: 
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # compute correct encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + encoder_hidden_states.shape[1], attention_mask + ) + else: + encoder_attention_mask = None + + # flax script modeling_flax_wav2vec2.py + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) +class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): + r""" + [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture + with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one + as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the + encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config_class = SpeechEncoderDecoderConfig + base_model_prefix: str = "speech_encoder_decoder" + module_class = FlaxSpeechEncoderDecoderModule + + def __init__( + self, + config: SpeechEncoderDecoderConfig, + input_shape: Optional[tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if not _do_init: + raise ValueError( + "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + + if config.decoder.cross_attention_hidden_size is not None: + # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer) + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." 
+ ) + + # make sure input & output embeddings are not tied + config.tie_word_embeddings = False + module = self.module_class(config=config, dtype=dtype, **kwargs) + + if input_shape is None: + # speech encoders almost always downsample the sequence length dimension + encoder_input_length = 1024 + decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) + input_shape = ((1, encoder_input_length), (1, decoder_input_length)) + + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + encoder_input_shape, decoder_input_shape = input_shape + + # init input DeviceArrays + inputs = jnp.zeros(encoder_input_shape, dtype="f4") + attention_mask = jnp.ones_like(inputs, dtype="i4") + decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = inputs.shape + + decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape + if not decoder_batch_size == batch_size: + raise ValueError( + f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" + f" and {decoder_batch_size} for decoder." + ) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + inputs, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) + + @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def encode( + self, + inputs: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + freeze_feature_encoder: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel + + >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" + ... 
) + + >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) + >>> encoder_outputs = model.encode(inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(inputs, dtype="i4") + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, inputs, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(inputs, attention_mask, **kwargs) + + outputs = self.module.apply( + {"params": params or self.params}, + inputs=jnp.array(inputs, dtype="f4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + freeze_feature_encoder=freeze_feature_encoder, + rngs=rngs, + method=_encoder_forward, + ) + + if return_dict: + outputs = FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + + @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel + >>> import jax.numpy as jnp + + >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" + ... 
) + + >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) + >>> encoder_outputs = model.encode(inputs) + + >>> decoder_start_token_id = model.config.decoder.bos_token_id + >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + params = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + params["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward( + module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs + ): + projection_module = module._get_projection_module() + decoder_module = module._get_decoder_module() + + # optionally project encoder_hidden_states + if projection_module is not None: + encoder_hidden_states = projection_module(encoder_hidden_states) + + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + + outputs = self.module.apply( + params, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def __call__( + self, + 
inputs: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + freeze_feature_encoder: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel, AutoTokenizer + + >>> # load a fine-tuned wav2vec2-2-bart model + >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") + >>> # load output tokenizer + >>> tokenizer_output = AutoTokenizer.from_pretrained("facebook/bart-large") + + >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) + + >>> # use bart's special bos, pad and eos tokens + >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id + >>> model.config.pad_token_id = model.decoder.config.pad_token_id + >>> model.config.eos_token_id = model.decoder.config.eos_token_id + + >>> outputs = model.generate(inputs) + # Assert something? More interesting input? dtype correct? + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(inputs, dtype="i4") + + # prepare decoder inputs + if decoder_input_ids is None: + raise ValueError( + "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must" + " be specified as an input argument." + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + inputs=jnp.array(inputs, dtype="f4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + freeze_feature_encoder=freeze_feature_encoder, + rngs=rngs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": decoder_position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + *model_args, + **kwargs, + ) -> FlaxPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + Params: + encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel + + >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" + ... 
) + >>> # saving model after fine-tuning + >>> model.save_pretrained("./wav2vec2-2-bart-large") + >>> # load fine-tuned model + >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder: + del kwargs["encoder_" + key] + for key in kwargs_decoder: + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = FlaxAutoModel.from_pretrained( + encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder + ) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + dtype = kwargs.pop("dtype", jnp.float32) + config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # make sure input & output word embeddings are not tied + config.tie_word_embeddings = False + + # init model + model = cls(config, dtype=dtype) + model.params["encoder"] = encoder.params + model.params["decoder"] = decoder.params + + return model + + +__all__ = ["FlaxSpeechEncoderDecoderModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..272ebdc741bc6bfee509266b8203ccd57a81d960 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -0,0 +1,505 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Speech-Encoder-Text-Decoder architectures""" + +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM +from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. 
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@auto_docstring +class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin): + r""" + [`SpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with + one of the base model classes of the library as encoder and another one as decoder when created with the + :meth*~transformers.AutoModel.from_pretrained* class method for the encoder and + :meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config: SpeechEncoderDecoderConfig + base_model_prefix = "speech_encoder_decoder" + main_input_name = "inputs" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + _supports_flash_attn = True + _supports_sdpa = True + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + r""" + encoder (`PreTrainedModel`, *optional*): + The encoder model to use. + decoder (`PreTrainedModel`, *optional*): + The decoder model to use. + """ + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." 
+ ) + + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = AutoModel.from_config(config.encoder) + + if decoder is None: + decoder = AutoModelForCausalLM.from_config(config.decoder) + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # get encoder output hidden size + self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size) + if ( + self.encoder_output_dim != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + # encoder outputs might need to be projected to different dimension for decoder + self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + def get_encoder(self): + return self.encoder + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder of the speech encoder so + that its parameters will not be updated during training. + """ + self.encoder.freeze_feature_encoder() + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[str] = None, + decoder_pretrained_model_name_or_path: Optional[str] = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. 
This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import SpeechEncoderDecoderModel + + >>> # initialize a wav2vec2bert from a pretrained Wav2Vec2 and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized + >>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-base-960h", "google-bert/bert-base-uncased" + ... ) + >>> # saving model after fine-tuning + >>> model.save_pretrained("./wav2vec2bert") + >>> # load fine-tuned model + >>> model = SpeechEncoderDecoderModel.from_pretrained("./wav2vec2bert") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder: + del kwargs["encoder_" + key] + for key in kwargs_decoder: + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." 
+ ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + return cls(encoder=encoder, decoder=decoder, config=config) + + @auto_docstring + def forward( + self, + inputs: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + input_values: Optional[torch.FloatTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): + Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac` + or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* + via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or + [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type + `torch.FloatTensor`. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the + right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding and conversion + into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details. + + Examples: + + ```python + >>> from transformers import SpeechEncoderDecoderModel, AutoProcessor + >>> from datasets import load_dataset + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") + >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values + >>> # Inference: Translate English speech to German + >>> generated = model.generate(input_values) + >>> decoded = processor.batch_decode(generated, skip_special_tokens=True)[0] + >>> decoded + 'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.' 
+ + >>> # Training: Train model on English transcription + >>> labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids + + >>> loss = model(input_values, labels=labels).loss + >>> loss.backward() + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) + + if encoder_outputs is None: + if inputs is None: + if input_values is not None and input_features is not None: + raise ValueError("You cannot specify both input_values and input_features at the same time") + elif input_values is not None: + inputs = input_values + elif input_features is not None: + inputs = input_features + else: + raise ValueError("You have to specify either input_values or input_features") + + encoder_outputs = self.encoder( + inputs, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder_output_dim != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # compute correct encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + encoder_hidden_states.shape[1], attention_mask + ) + else: + encoder_attention_mask = None + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def 
prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + +__all__ = ["SpeechEncoderDecoderModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52fee59ae9b96d14361a771cbc58089d5e5811f2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speecht5 import * + from .feature_extraction_speecht5 import * + from .modeling_speecht5 import * + from .processing_speecht5 import * + from .tokenization_speecht5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/configuration_speecht5.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/configuration_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..b73471a29dc4cda3538b08c3da40b9b34d337b14 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/configuration_speecht5.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SpeechT5 model configuration""" + +import functools +import operator + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SpeechT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SpeechT5Model`]. 
It is used to instantiate a
+    SpeechT5 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the SpeechT5
+    [microsoft/speecht5_asr](https://huggingface.co/microsoft/speecht5_asr) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 81):
+            Vocabulary size of the SpeechT5 model. Defines the number of different tokens that can be represented by
+            the `input_ids` passed to the forward method of [`SpeechT5Model`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability for the encoder. See the
+            [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer decoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability for the decoder. See the
+            [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        positional_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the text position encoding layers.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+            The norm to be applied to 1D convolutional layers in the speech encoder pre-net. One of `"group"` for group
+            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+            convolutional layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for output of the speech encoder pre-net.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            speech encoder pre-net. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the speech encoder pre-net. The
+            length of *conv_stride* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the speech encoder pre-net.
+            The length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+            Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+            embeddings layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of 1D convolutional positional embeddings layer.
+        apply_spec_augment (`bool`, *optional*, defaults to `True`):
+            Whether to apply *SpecAugment* data augmentation to the outputs of the speech encoder pre-net. For
+            reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+            Recognition](https://huggingface.co/papers/1904.08779).
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment` is `True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespectively of `mask_time_prob`. Only relevant if
+            `mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks`.
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
+            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`.
+            Note that overlap may decrease the actual percentage of masked vectors. This is only relevant if
+            `apply_spec_augment` is `True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespectively of `mask_feature_prob`. Only relevant if
+            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+        num_mel_bins (`int`, *optional*, defaults to 80):
+            Number of mel features used per input feature. Used by the speech decoder pre-net. Should correspond to
+            the value used in the [`SpeechT5Processor`] class.
+        speech_decoder_prenet_layers (`int`, *optional*, defaults to 2):
+            Number of layers in the speech decoder pre-net.
+        speech_decoder_prenet_units (`int`, *optional*, defaults to 256):
+            Dimensionality of the layers in the speech decoder pre-net.
+        speech_decoder_prenet_dropout (`float`, *optional*, defaults to 0.5):
+            The dropout probability for the speech decoder pre-net layers.
+        speaker_embedding_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of the *XVector* embedding vectors.
+        speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
+            Number of layers in the speech decoder post-net.
+        speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
+            Dimensionality of the layers in the speech decoder post-net.
+        speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
+            Kernel size of the 1D convolutional layers in the speech decoder post-net.
+        speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
+            The dropout probability for the speech decoder post-net layers.
+        reduction_factor (`int`, *optional*, defaults to 2):
+            Spectrogram length reduction factor for the speech decoder inputs.
+        max_speech_positions (`int`, *optional*, defaults to 4000):
+            The maximum sequence length of speech features that this model might ever be used with.
+        max_text_positions (`int`, *optional*, defaults to 450):
+            The maximum sequence length of text features that this model might ever be used with.
+        encoder_max_relative_position (`int`, *optional*, defaults to 160):
+            Maximum distance for relative position embedding in the encoder.
+        use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
+            Whether to apply guided attention loss while training the TTS model.
+        guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
+            Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
+            attention heads.
+        guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
+            Standard deviation for guided attention loss.
+        guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
+            Scaling coefficient for guided attention loss (also known as lambda).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
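+
+    With the default `conv_stride` of `(5, 2, 2, 2, 2, 2, 2)`, the speech encoder pre-net downsamples the raw
+    waveform by the product of the strides, 5 * 2**6 = 320; this ratio is exposed as the config's
+    `inputs_to_logits_ratio` property.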
+ + Example: + + ```python + >>> from transformers import SpeechT5Model, SpeechT5Config + + >>> # Initializing a "microsoft/speecht5_asr" style configuration + >>> configuration = SpeechT5Config() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_asr" style configuration + >>> model = SpeechT5Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speecht5" + attribute_map = {"num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers"} + + def __init__( + self, + vocab_size=81, + hidden_size=768, + encoder_layers=12, + encoder_attention_heads=12, + encoder_ffn_dim=3072, + encoder_layerdrop=0.1, + decoder_layers=6, + decoder_ffn_dim=3072, + decoder_attention_heads=12, + decoder_layerdrop=0.1, + hidden_act="gelu", + positional_dropout=0.1, + hidden_dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + scale_embedding=False, + feat_extract_norm="group", + feat_proj_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + decoder_start_token_id=2, + num_mel_bins=80, + speech_decoder_prenet_layers=2, + speech_decoder_prenet_units=256, + speech_decoder_prenet_dropout=0.5, + speaker_embedding_dim=512, + speech_decoder_postnet_layers=5, + speech_decoder_postnet_units=256, + speech_decoder_postnet_kernel=5, + speech_decoder_postnet_dropout=0.5, + reduction_factor=2, + max_speech_positions=4000, + max_text_positions=450, + encoder_max_relative_position=160, + use_guided_attention_loss=True, + guided_attention_loss_num_heads=2, + guided_attention_loss_sigma=0.4, + guided_attention_loss_scale=10.0, + use_cache=True, + is_encoder_decoder=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + self.decoder_layerdrop = decoder_layerdrop + self.hidden_act = hidden_act + self.positional_dropout = positional_dropout + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.scale_embedding = scale_embedding + + self.feat_extract_norm = feat_extract_norm + self.feat_proj_dropout = feat_proj_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) 
+            or (len(self.conv_dim) != self.num_feat_extract_layers)
+        ):
+            raise ValueError(
+                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+            )
+
+        # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
+        self.apply_spec_augment = apply_spec_augment
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks
+
+        self.num_mel_bins = num_mel_bins
+        self.speech_decoder_prenet_layers = speech_decoder_prenet_layers
+        self.speech_decoder_prenet_units = speech_decoder_prenet_units
+        self.speech_decoder_prenet_dropout = speech_decoder_prenet_dropout
+        self.speaker_embedding_dim = speaker_embedding_dim
+
+        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
+        self.speech_decoder_postnet_units = speech_decoder_postnet_units
+        self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel
+        self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout
+        self.reduction_factor = reduction_factor
+
+        self.max_speech_positions = max_speech_positions
+        self.max_text_positions = max_text_positions
+        self.encoder_max_relative_position = encoder_max_relative_position
+
+        self.use_guided_attention_loss = use_guided_attention_loss
+        self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
+        self.guided_attention_loss_sigma = guided_attention_loss_sigma
+        self.guided_attention_loss_scale = guided_attention_loss_scale
+
+        self.use_cache = use_cache
+        self.is_encoder_decoder = is_encoder_decoder
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            decoder_start_token_id=decoder_start_token_id,
+            **kwargs,
+        )
+
+    @property
+    def inputs_to_logits_ratio(self):
+        return functools.reduce(operator.mul, self.conv_stride, 1)
+
+
+class SpeechT5HifiGanConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SpeechT5HifiGanModel`]. It is used to
+    instantiate a SpeechT5 HiFi-GAN vocoder model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    SpeechT5 [microsoft/speecht5_hifigan](https://huggingface.co/microsoft/speecht5_hifigan) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        model_in_dim (`int`, *optional*, defaults to 80):
+            The number of frequency bins in the input log-mel spectrogram.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
+        upsample_initial_channel (`int`, *optional*, defaults to 512):
+            The number of input channels into the upsampling network.
+        upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network.
The + length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. + upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 8, 8]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The + length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of + *upsample_rates*. + resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field + fusion (MRF) module. + resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + multi-receptive field fusion (MRF) module. + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation. + normalize_before (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance. + + Example: + + ```python + >>> from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig + + >>> # Initializing a "microsoft/speecht5_hifigan" style configuration + >>> configuration = SpeechT5HifiGanConfig() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_hifigan" style configuration + >>> model = SpeechT5HifiGan(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "hifigan" + + def __init__( + self, + model_in_dim=80, + sampling_rate=16000, + upsample_initial_channel=512, + upsample_rates=[4, 4, 4, 4], + upsample_kernel_sizes=[8, 8, 8, 8], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + initializer_range=0.01, + leaky_relu_slope=0.1, + normalize_before=True, + **kwargs, + ): + self.model_in_dim = model_in_dim + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.initializer_range = initializer_range + self.leaky_relu_slope = leaky_relu_slope + self.normalize_before = normalize_before + super().__init__(**kwargs) + + +__all__ = ["SpeechT5Config", "SpeechT5HifiGanConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/feature_extraction_speecht5.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/feature_extraction_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..822ae01a88d7d9e8714d2ef42bf66aab73418a49 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/feature_extraction_speecht5.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for SpeechT5."""
+
+import warnings
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs a SpeechT5 feature extractor.
+
+    This class can pre-process a raw speech signal by (optionally) normalizing to zero-mean unit-variance, for use by
+    the SpeechT5 speech encoder prenet.
+
+    This class can also extract log-mel filter bank features from raw speech, for use by the SpeechT5 speech decoder
+    prenet.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 1):
+            The feature dimension of the extracted features.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+        padding_value (`float`, *optional*, defaults to 0.0):
+            The value that is used to fill the padding values.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+            improve the performance for some models.
+        num_mel_bins (`int`, *optional*, defaults to 80):
+            The number of mel-frequency bins in the extracted spectrogram features.
+        hop_length (`int`, *optional*, defaults to 16):
+            Number of ms between windows. Otherwise referred to as "shift" in many papers.
+        win_length (`int`, *optional*, defaults to 64):
+            Number of ms per window.
+        win_function (`str`, *optional*, defaults to `"hann_window"`):
+            Name for the window function used for windowing, must be accessible via `torch.{win_function}`.
+        frame_signal_scale (`float`, *optional*, defaults to 1.0):
+            Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
+        fmin (`float`, *optional*, defaults to 80):
+            Minimum mel frequency in Hz.
+        fmax (`float`, *optional*, defaults to 7600):
+            Maximum mel frequency in Hz.
+        mel_floor (`float`, *optional*, defaults to 1e-10):
+            Minimum value of mel frequency banks.
+        reduction_factor (`int`, *optional*, defaults to 2):
+            Spectrogram length reduction factor. This argument is deprecated.
+        return_attention_mask (`bool`, *optional*, defaults to `True`):
+            Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
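+
+    Example (an illustrative sketch; the all-zeros waveform below is a stand-in for real audio):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import SpeechT5FeatureExtractor
+
+    >>> feature_extractor = SpeechT5FeatureExtractor()
+    >>> waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
+
+    >>> # waveform features for the speech encoder pre-net
+    >>> inputs = feature_extractor(audio=waveform, sampling_rate=16000, return_tensors="np")
+    >>> inputs.input_values.shape
+    (1, 16000)
+
+    >>> # log-mel spectrogram features for the speech decoder pre-net
+    >>> targets = feature_extractor(audio_target=waveform, sampling_rate=16000, return_tensors="np")
+    >>> targets.input_values.shape[-1]  # num_mel_bins
+    80
+    ```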
+ """ + + model_input_names = ["input_values", "attention_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 16000, + padding_value: float = 0.0, + do_normalize: bool = False, + num_mel_bins: int = 80, + hop_length: int = 16, + win_length: int = 64, + win_function: str = "hann_window", + frame_signal_scale: float = 1.0, + fmin: float = 80, + fmax: float = 7600, + mel_floor: float = 1e-10, + reduction_factor: int = 2, + return_attention_mask: bool = True, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.do_normalize = do_normalize + self.return_attention_mask = return_attention_mask + + self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.win_length = win_length + self.win_function = win_function + self.frame_signal_scale = frame_signal_scale + self.fmin = fmin + self.fmax = fmax + self.mel_floor = mel_floor + self.reduction_factor = reduction_factor + + self.sample_size = win_length * sampling_rate // 1000 + self.sample_stride = hop_length * sampling_rate // 1000 + self.n_fft = optimal_fft_length(self.sample_size) + self.n_freqs = (self.n_fft // 2) + 1 + + self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True) + + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_freqs, + num_mel_filters=self.num_mel_bins, + min_frequency=self.fmin, + max_frequency=self.fmax, + sampling_rate=self.sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + if frame_signal_scale != 1.0: + warnings.warn( + "The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + if reduction_factor != 2.0: + warnings.warn( + "The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + + @staticmethod + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 + ) -> list[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _extract_mel_features( + self, + one_waveform: np.ndarray, + ) -> np.ndarray: + """ + Extracts log-mel filterbank features for one waveform array (unbatched). 
+ """ + log_mel_spec = spectrogram( + one_waveform, + window=self.window, + frame_length=self.sample_size, + hop_length=self.sample_stride, + fft_length=self.n_fft, + mel_filters=self.mel_filters, + mel_floor=self.mel_floor, + log_mel="log10", + ) + return log_mel_spec.T + + def __call__( + self, + audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None, + audio_target: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None, + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Pass in a value for `audio` to extract waveform features. Pass in a value for `audio_target` to extract log-mel + spectrogram features. + + Args: + audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must + be mono channel audio, not stereo, i.e. single float per timestep. + audio_target (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*): + The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a + list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel + spectrogram features. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. 
+ sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` or `audio_target` input was sampled. It is strongly recommended + to pass `sampling_rate` at the forward call to prevent silent errors. + """ + if audio is None and audio_target is None: + raise ValueError("You must provide either `audio` or `audio_target` values.") + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if audio is not None: + inputs = self._process_audio( + audio, + False, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_attention_mask, + return_tensors, + **kwargs, + ) + else: + inputs = None + + if audio_target is not None: + inputs_target = self._process_audio( + audio_target, + True, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_attention_mask, + return_tensors, + **kwargs, + ) + + if inputs is None: + return inputs_target + else: + inputs["labels"] = inputs_target["input_values"] + decoder_attention_mask = inputs_target.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + def _process_audio( + self, + speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + is_target: bool = False, + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1 + if is_batched_numpy and len(speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + speech = [np.asarray(speech, dtype=np.float32) for speech in speech] + elif not is_batched and not isinstance(speech, np.ndarray): + speech = np.asarray(speech, dtype=np.float32) + elif isinstance(speech, np.ndarray) and speech.dtype is np.dtype(np.float64): + speech = speech.astype(np.float32) + + # always return batch + if not is_batched: + speech = [speech] + + # needed to make pad() work on spectrogram inputs + feature_size_hack = self.feature_size + + # convert into correct format for padding + if is_target: + features = [self._extract_mel_features(waveform) for waveform in speech] + encoded_inputs = BatchFeature({"input_values": features}) + self.feature_size = self.num_mel_bins + else: + encoded_inputs = BatchFeature({"input_values": speech}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.feature_size = feature_size_hack + + # convert input values to correct format + input_values = padded_inputs["input_values"] + if not 
isinstance(input_values[0], np.ndarray): + padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] + elif ( + not isinstance(input_values, np.ndarray) + and isinstance(input_values[0], np.ndarray) + and input_values[0].dtype is np.dtype(np.float64) + ): + padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] + elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64): + padded_inputs["input_values"] = input_values.astype(np.float32) + + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + # zero-mean and unit-variance normalization + if not is_target and self.do_normalize: + attention_mask = ( + attention_mask + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + else None + ) + padded_inputs["input_values"] = self.zero_mean_unit_var_norm( + padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value + ) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + + # Don't serialize these as they are derived from the other properties. + names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"] + for name in names: + if name in output: + del output[name] + + return output + + +__all__ = ["SpeechT5FeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/modeling_speecht5.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/modeling_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e79a46680cc2d2fc47e8e259038a55214e7a44 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/modeling_speecht5.py @@ -0,0 +1,3242 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SpeechT5 model.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSpectrogramOutput, +) +from ...modeling_utils import EmbeddingAccessMixin, PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 1 + +# General docstring + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def shift_spectrograms_right( + input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: Optional[torch.Tensor] = None +): + """ + Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length. + """ + # thin out frames for reduction factor + if reduction_factor > 1: + input_values = input_values[:, reduction_factor - 1 :: reduction_factor] + if attention_mask is not None: + attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor] + + shifted_input_values = input_values.new_zeros(input_values.shape) + shifted_input_values[:, 1:] = input_values[:, :-1].clone() + + # replace possible -100 values in labels by zeros + shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0) + + return shifted_input_values, attention_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. 
The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
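+        # (e.g. if max_num_masked_span is 3 but only one span was sampled for this
+        # input, the dummy start index is appended twice more; masking the same
+        # span repeatedly is a no-op, so the padding does not change the result)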
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + 
self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5 +class SpeechT5SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. 
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor of input token ids.
+
+        Returns: `torch.Tensor`
+        """
+        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+        mask = input_ids.ne(padding_idx).int()
+        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+        return incremental_indices.long() + padding_idx
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5
+class SpeechT5PositionalConvEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=config.num_conv_pos_embeddings,
+            padding=config.num_conv_pos_embeddings // 2,
+            groups=config.num_conv_pos_embedding_groups,
+        )
+
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
+            if hasattr(self.conv, "parametrizations"):
+                weight_g = self.conv.parametrizations.weight.original0
+                weight_v = self.conv.parametrizations.weight.original1
+            else:
+                weight_g = self.conv.weight_g
+                weight_v = self.conv.weight_v
+            deepspeed.zero.register_external_parameter(self, weight_v)
+            deepspeed.zero.register_external_parameter(self, weight_g)
+        else:
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+        self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.transpose(1, 2)
+
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.padding(hidden_states)
+        hidden_states = self.activation(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class SpeechT5ScaledPositionalEncoding(nn.Module):
+    """
+    Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
+    """
+
+    def __init__(self, dropout, dim, max_len=5000):
+        pe = torch.zeros(max_len, dim)
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
+        pe[:, 0::2] = torch.sin(position.float() * div_term)
+        pe[:, 1::2] = torch.cos(position.float() * div_term)
+        pe = pe.unsqueeze(0)
+        super().__init__()
+        self.register_buffer("pe", pe, persistent=False)
+        self.dropout = nn.Dropout(p=dropout)
+        self.dim = dim
+        self.alpha = nn.Parameter(torch.tensor(1.0))
+
+    def forward(self, emb):
+        emb = emb + self.alpha * self.pe[:, : emb.size(1)]
+        emb = self.dropout(emb)
+        return emb
+
+
+class SpeechT5RelativePositionalEncoding(torch.nn.Module):
+    def __init__(self, dim, max_length=1000):
+        super().__init__()
+        self.dim = dim
+        self.max_length = max_length
+        self.pe_k = torch.nn.Embedding(2 * max_length, dim)
+
+    def forward(self, hidden_states):
+        seq_len = hidden_states.shape[1]
+        pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
+        pos_seq = pos_seq[:, None] - pos_seq[None, :]
+
+        pos_seq[pos_seq < -self.max_length] = -self.max_length
+        pos_seq[pos_seq >= self.max_length] = self.max_length - 1
+        pos_seq = pos_seq + self.max_length
+
+        return self.pe_k(pos_seq)
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5
+class
SpeechT5SamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5 +class SpeechT5FeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [ + SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5 +class SpeechT5FeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class SpeechT5SpeechEncoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.feature_encoder = SpeechT5FeatureEncoder(config) + self.feature_projection = SpeechT5FeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.pos_conv_embed = SpeechT5PositionalConvEmbedding(config) + self.pos_sinusoidal_embed = SpeechT5SinusoidalPositionalEmbedding( + config.max_speech_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def freeze_feature_encoder(self): + self.feature_encoder._freeze_parameters() + + def forward( + self, + input_values: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + ): + extract_features = self.feature_encoder(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature 
vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], + attention_mask, + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + positional_conv_embedding = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + positional_conv_embedding + + if attention_mask is not None: + padding_mask = attention_mask.ne(1).long() + else: + padding_mask = torch.zeros(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device) + + positional_sinusoidal_embeddings = self.pos_sinusoidal_embed(padding_mask) + hidden_states = hidden_states + positional_sinusoidal_embeddings + + return hidden_states, attention_mask + + # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://huggingface.co/papers/1904.08779). 
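+ Time-masked positions are replaced with the learned `masked_spec_embed` vector, while masked feature channels
+ are zeroed out across the whole sequence. Masking is skipped when `config.apply_spec_augment` is `False`, and the
+ random masks are only sampled in training mode (an explicitly provided `mask_time_indices` is always applied).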
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + +class SpeechT5SpeechDecoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList( + [ + nn.Linear( + config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units, + config.speech_decoder_prenet_units, + ) + for i in range(config.speech_decoder_prenet_layers) + ] + ) + + self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_speech_positions, + ) + self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size) + + def _consistent_dropout(self, inputs_embeds, p): + mask = torch.bernoulli(inputs_embeds[0], p=p) + all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1) + return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p) + + def forward( + self, + input_values: torch.Tensor, + speaker_embeddings: Optional[torch.Tensor] = None, + ): + # Dropout is always applied, even when evaluating. See §2.2 in https://huggingface.co/papers/1712.05884. 
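+ # A single dropout mask is sampled per forward pass and shared across the batch (see `_consistent_dropout`
+ # above): every sequence drops the same prenet positions, and the surviving activations are rescaled by
+ # 1 / (1 - p), as in standard inverted dropout.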
+ + inputs_embeds = input_values + for layer in self.layers: + inputs_embeds = nn.functional.relu(layer(inputs_embeds)) + inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout) + + inputs_embeds = self.final_layer(inputs_embeds) + inputs_embeds = self.encode_positions(inputs_embeds) + + if speaker_embeddings is not None: + speaker_embeddings = nn.functional.normalize(speaker_embeddings) + speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1) + inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1) + inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) + + return inputs_embeds + + +class SpeechT5BatchNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + + if layer_id == 0: + in_conv_dim = config.num_mel_bins + else: + in_conv_dim = config.speech_decoder_postnet_units + + if layer_id == config.speech_decoder_postnet_layers - 1: + out_conv_dim = config.num_mel_bins + else: + out_conv_dim = config.speech_decoder_postnet_units + + self.conv = nn.Conv1d( + in_conv_dim, + out_conv_dim, + kernel_size=config.speech_decoder_postnet_kernel, + stride=1, + padding=(config.speech_decoder_postnet_kernel - 1) // 2, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(out_conv_dim) + + if layer_id < config.speech_decoder_postnet_layers - 1: + self.activation = nn.Tanh() + else: + self.activation = None + + self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + if self.activation is not None: + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SpeechT5SpeechDecoderPostnet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor) + self.prob_out = nn.Linear(config.hidden_size, config.reduction_factor) + + self.layers = nn.ModuleList( + [SpeechT5BatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)] + ) + + def forward(self, hidden_states: torch.Tensor): + outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins) + outputs_after_postnet = self.postnet(outputs_before_postnet) + logits = self.prob_out(hidden_states).view(hidden_states.size(0), -1) + return outputs_before_postnet, outputs_after_postnet, logits + + def postnet(self, hidden_states: torch.Tensor): + layer_output = hidden_states.transpose(1, 2) + for layer in self.layers: + layer_output = layer(layer_output) + return hidden_states + layer_output.transpose(1, 2) + + +class SpeechT5TextEncoderPrenet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_text_positions, + ) + + def forward(self, input_ids: torch.Tensor): + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.encode_positions(inputs_embeds) + return inputs_embeds + + +class SpeechT5TextDecoderPrenet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.dropout = 
nn.Dropout(config.positional_dropout) + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.embed_positions = SpeechT5SinusoidalPositionalEmbedding( + config.max_text_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + ): + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + else: + raise ValueError("You have to specify `decoder_input_ids`") + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + positions = self.embed_positions(input_ids, past_key_values_length) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds += positions + inputs_embeds = self.dropout(inputs_embeds) + + return inputs_embeds, attention_mask + + +class SpeechT5TextDecoderPostnet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, hidden_states: torch.Tensor): + return self.lm_head(hidden_states) + + def get_output_embeddings(self): + # Post-net has no token embeddings, but its lm_head must still be + # tied to the decoder weights when `tie_word_embeddings=True`. + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + +class SpeechT5Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see + https://aclanthology.org/N18-2074.pdf) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[bool] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # relative attention bias + if position_bias is not None: + reshape_q = query_states.contiguous().view(bsz * 
self.num_heads, -1, self.head_dim).transpose(0, 1) + rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1)) + rel_pos_bias = rel_pos_bias.transpose(0, 1).view( + bsz * self.num_heads, position_bias.size(0), position_bias.size(1) + ) + attn_weights += rel_pos_bias + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
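+ # At this point `attn_output` has shape (bsz, tgt_len, num_heads, head_dim); the reshape below merges the
+ # last two dimensions back into `embed_dim` before the output projection.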
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class SpeechT5FeedForward(nn.Module): + def __init__(self, config, intermediate_size): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SpeechT5EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SpeechT5Config): + super().__init__() + self.attention = SpeechT5Attention( + embed_dim=config.hidden_size, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = SpeechT5FeedForward(config, config.encoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + position_bias (`torch.FloatTensor`): + relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)` + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
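+ The layer uses a post-norm residual layout: the attention output is added to the residual and normalized,
+ then the feed-forward output is added and normalized by `final_layer_norm`.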
+ """ + residual = hidden_states + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + position_bias=position_bias, + output_attentions=output_attentions, + ) + + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SpeechT5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SpeechT5Config, layer_idx=None): + super().__init__() + self.self_attn = SpeechT5Attention( + embed_dim=config.hidden_size, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.encoder_attn = SpeechT5Attention( + config.hidden_size, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class SpeechT5PreTrainedModel(PreTrainedModel): + config: SpeechT5Config + base_model_prefix = "speecht5" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, SpeechT5PositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SpeechT5ScaledPositionalEncoding): + module.alpha.data.fill_(1.0) + elif isinstance(module, SpeechT5FeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if hasattr(module, "masked_spec_embed"): + nn.init.uniform_(module.masked_spec_embed) + + +class SpeechT5Encoder(SpeechT5PreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`]. 
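+ The prenet output is layer-normalized and dropped out once at the encoder input, and a shared relative
+ position bias computed by [`SpeechT5RelativePositionalEncoding`] is passed to every layer's attention.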
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5EncoderLayer(config) for _ in range(config.encoder_layers)]) + + self.embed_positions = SpeechT5RelativePositionalEncoding( + config.hidden_size // config.encoder_attention_heads, config.encoder_max_relative_position + ) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the encoder prenet. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + position_bias = self.embed_positions(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to + hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + hidden_states, attention_mask = self.prenet(input_values, attention_mask) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features. 
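+ The text prenet is a token embedding lookup followed by [`SpeechT5ScaledPositionalEncoding`].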
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + hidden_states = self.prenet(input_values) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with + [`SpeechT5Model`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + return self.wrapped_encoder( + hidden_states=input_values, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SpeechT5Decoder(SpeechT5PreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SpeechT5DecoderLayer`] + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the decoder prenet. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = hidden_states.size()[:-1] + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] + ) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + if skip_the_layer and not synced_gpus: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden + features. 
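+ When `speaker_embeddings` are provided, the prenet L2-normalizes them, concatenates them to every time step
+ and projects the result back to `hidden_size` followed by a ReLU (see [`SpeechT5SpeechDecoderPrenet`]).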
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeddings: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + decoder_hidden_states = self.prenet(input_values, speaker_embeddings) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load 
pretrained checkpoints when used in combination with + [`SpeechT5Model`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + outputs = self.wrapped_decoder( + hidden_states=input_values, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + return outputs + + +class SpeechT5GuidedMultiheadAttentionLoss(nn.Module): + """ + Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention](https://huggingface.co/papers/1710.08969), adapted for multi-head attention. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.sigma = config.guided_attention_loss_sigma + self.scale = config.guided_attention_loss_scale + + def forward( + self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor + ) -> torch.Tensor: + """ + Compute the attention loss. + + Args: + attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`): + Batch of multi-head attention weights + input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`): + Input attention mask as booleans. + output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`): + Target attention mask as booleans. 
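+ The guided attention penalty for output step `t` (of `T`) attending to input step `n` (of `N`) is
+ `1 - exp(-((n / N - t / T) ** 2) / (2 * sigma ** 2))`, so attention mass far from the diagonal is penalized;
+ the masked mean of this penalty times the attention weights is scaled by `config.guided_attention_loss_scale`.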
+ + Returns: + `torch.Tensor` with the loss value + """ + guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device) + masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2) + masks = masks.to(attentions.device).unsqueeze(1) + + losses = guided_attn_masks * attentions + loss = torch.mean(losses.masked_select(masks)) + return self.scale * loss + + def _make_guided_attention_masks(self, input_masks, output_masks, device): + input_lengths = input_masks.sum(-1) + output_lengths = output_masks.sum(-1) + + guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device) + + for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device) + + return guided_attn_masks.unsqueeze(1) + + @staticmethod + def _make_guided_attention_mask(input_length, output_length, sigma, device): + grid_y, grid_x = torch.meshgrid( + torch.arange(input_length, device=device), + torch.arange(output_length, device=device), + indexing="xy", + ) + grid_x = grid_x.float() / output_length + grid_y = grid_y.float() / input_length + return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2))) + + +class SpeechT5SpectrogramLoss(nn.Module): + """ + Loss computation used by SpeechT5ForTextToSpeech. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.use_guided_attention_loss = config.use_guided_attention_loss + self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads + self.reduction_factor = config.reduction_factor + + self.l1_criterion = L1Loss() + self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0)) + + if self.use_guided_attention_loss: + self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config) + + def forward( + self, + attention_mask: torch.LongTensor, + outputs_before_postnet: torch.FloatTensor, + outputs_after_postnet: torch.FloatTensor, + logits: torch.FloatTensor, + labels: torch.FloatTensor, + cross_attentions: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + padding_mask = labels != -100.0 + + # mask out the padded portions + labels = labels.masked_select(padding_mask) + outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask) + outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask) + + # spectrogram loss + l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels) + + # construct stop labels from the padding mask + masks = padding_mask[:, :, 0] + stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1) + stop_labels = stop_labels[:, 1:].masked_select(masks) + logits = logits.masked_select(masks) + + # stop token loss + bce_loss = self.bce_criterion(logits, stop_labels) + + # combined loss + loss = l1_loss + bce_loss + + # guided attention loss + if self.use_guided_attention_loss: + attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1) + input_masks = attention_mask == 1 + output_masks = padding_mask[:, :, 0] + if self.reduction_factor > 1: + output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor] + attn_loss = self.attn_criterion(attn, input_masks, output_masks) + loss += attn_loss + + return loss + + +@auto_docstring( + custom_intro=""" + The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any 
specific pre- or post-nets. + """ +) +class SpeechT5Model(SpeechT5PreTrainedModel): + def __init__( + self, + config: SpeechT5Config, + encoder: Optional[nn.Module] = None, + decoder: Optional[nn.Module] = None, + ): + r""" + encoder (`PreTrainedModel`, *optional*): + The encoder model to use. + decoder (`PreTrainedModel`, *optional*): + The decoder model to use. + """ + super().__init__(config) + self.config = config + self.encoder = SpeechT5EncoderWithoutPrenet(config) if encoder is None else encoder + self.decoder = SpeechT5DecoderWithoutPrenet(config) if decoder is None else decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + return self.encoder.get_input_embeddings() + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + return self.decoder.get_input_embeddings() + raise NotImplementedError + + def set_input_embeddings(self, value): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + self.encoder.set_input_embeddings(value) + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + self.decoder.set_input_embeddings(value) + + def get_encoder(self): + return self.encoder + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + if isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + self.encoder.prenet.freeze_feature_encoder() + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Depending on which encoder is being used, the `input_values` are either: float values of the input raw + speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states. + decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel + filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in + the vocabulary, or hidden states. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. 
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_values=input_values, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # downsample encoder attention mask (only for encoders with speech input) + if attention_mask is not None and isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = self.encoder.prenet._get_feature_vector_attention_mask( + encoder_outputs[0].shape[1], attention_mask + ) + else: + encoder_attention_mask = attention_mask + + if isinstance(self.decoder, SpeechT5DecoderWithSpeechPrenet): + decoder_args = {"speaker_embeddings": speaker_embeddings} + else: + decoder_args = {} + + decoder_outputs = self.decoder( + input_values=decoder_input_values, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **decoder_args, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a speech encoder and a text decoder. 
+ """ +) +class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"] + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." + ) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + text_decoder = SpeechT5DecoderWithTextPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, text_decoder) + + self.text_decoder_postnet = SpeechT5TextDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.get_encoder().prenet.freeze_feature_encoder() + + def get_output_embeddings(self): + return self.text_decoder_postnet.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.text_decoder_postnet.set_output_embeddings(new_embeddings) + + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). 
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText + >>> from datasets import load_dataset + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation" + ... ) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr") + >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> predicted_ids = model.generate(**inputs, max_length=100) + + >>> # transcribe speech + >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + >>> transcription[0] + 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel' + ``` + + ```python + >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids + + >>> # compute loss + >>> loss = model(**inputs).loss + >>> round(loss.item(), 2) + 19.68 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + logits = self.text_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = 
(logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +def _generate_speech( + model: SpeechT5PreTrainedModel, + input_values: torch.FloatTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, +) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]: + if speaker_embeddings is None: + raise ValueError( + """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following + the code snippet provided in this link: + https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors + """ + ) + + if attention_mask is None: + encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int() + else: + encoder_attention_mask = attention_mask + + bsz = input_values.size(0) + + encoder_out = model.speecht5.encoder( + input_values=input_values, + attention_mask=encoder_attention_mask, + return_dict=True, + ) + + encoder_last_hidden_state = encoder_out.last_hidden_state + + # downsample encoder attention mask + if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( + encoder_out[0].shape[1], encoder_attention_mask + ) + + maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor) + minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor) + + # Start the output sequence with a mel spectrum that is all zeros. + output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins) + + spectrogram = [] + cross_attentions = [] + past_key_values = None + idx = 0 + result_spectrogram = {} + + while True: + idx += 1 + + # Run the decoder prenet on the entire output sequence. + decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) + # Run the decoder layers on the last element of the prenet output. + decoder_out = model.speecht5.decoder.wrapped_decoder( + hidden_states=decoder_hidden_states[:, -1:], + attention_mask=None, + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=True, + output_attentions=output_cross_attentions, + return_dict=True, + ) + + if output_cross_attentions: + cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) + + last_decoder_output = decoder_out.last_hidden_state.squeeze(1) + past_key_values = decoder_out.past_key_values + + # Predict the new mel spectrum for this step in the sequence. + spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) + spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) + spectrogram.append(spectrum) + + # Extend the output sequence with the new mel spectrum. 
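+            # Note: when the model's `reduction_factor` is greater than 1, each decoder step predicts
+            # `reduction_factor` mel frames (all of them are kept for the final spectrogram), but only the
+            # last predicted frame is appended to `output_sequence` as the next autoregressive decoder input.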
+ new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins) + output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1) + # Predict the probability that this is the stop token. + prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) + + if idx < minlen: + continue + else: + # If the generation loop is less than maximum length time, check the ones in the batch that have met + # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch. + if idx < maxlen: + meet_thresholds = torch.sum(prob, dim=-1) >= threshold + meet_indexes = torch.where(meet_thresholds)[0].tolist() + else: + meet_indexes = range(len(prob)) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if len(meet_indexes) > 0: + spectrograms = torch.stack(spectrogram) + spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for meet_index in meet_indexes: + result_spectrogram[meet_index] = spectrograms[meet_index] + if len(result_spectrogram) >= bsz: + break + spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: + spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + if vocoder is not None: + outputs = vocoder(spectrogram) + else: + outputs = spectrogram + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + if bsz > 1: + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (outputs, cross_attentions) + else: + # batched return values should also include the spectrogram/waveform lengths + spectrogram_lengths = [] + for i in range(bsz): + spectrogram_lengths.append(spectrograms[i].size(0)) + if vocoder is None: + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + outputs = (spectrograms, spectrogram_lengths) + else: + waveforms = [] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + waveforms = vocoder(spectrograms) + waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + outputs = (waveforms, waveform_lengths) + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (*outputs, cross_attentions) + return outputs + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a text encoder and a speech decoder. + """ +) +class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): + main_input_name = "input_ids" + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." 
+ ) + + text_encoder = SpeechT5EncoderWithTextPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, text_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def can_generate(cls) -> bool: + # Speecht5 has a unique model structure, where the external class (`SpeechT5ForTextToSpeech`) doesn't need to inherit from + # `GenerationMixin` (it has a non-standard generation method). This means that the base `can_generate()` will return `False`, + # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase. + return True + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqSpectrogramOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. 
+ labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss + computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`] + for details. + stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Binary tensor indicating the position of the stop token in the sequence. + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed + >>> import torch + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") + >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt") + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([15872]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + if self.config.use_guided_attention_loss: + output_attentions = True + + outputs = self.speecht5( + input_values=input_ids, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + criterion = SpeechT5SpectrogramLoss(self.config) + loss = criterion( + attention_mask, + outputs_before_postnet, + outputs_after_postnet, + logits, + labels, + outputs.cross_attentions, + ) + + if not return_dict: + output = (outputs_after_postnet,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=outputs_after_postnet, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + **kwargs, + ) -> Union[torch.FloatTensor, 
tuple[torch.FloatTensor, torch.FloatTensor]]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Attention mask from the tokenizer, required for batched inference to signal to the model where to + ignore padded tokens from the input_ids. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. 
+ - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch_size." + ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + @torch.no_grad() + def generate_speech( + self, + input_ids: torch.LongTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. 
+ - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch size." + ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a speech encoder and a speech decoder. + """ +) +class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
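+
+        A minimal usage sketch (the checkpoint name is the same one used in the `forward` example of this class;
+        any SpeechT5 checkpoint with a speech encoder would work):
+
+        ```python
+        >>> from transformers import SpeechT5ForSpeechToSpeech
+
+        >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
+        >>> model.freeze_feature_encoder()  # the speech prenet's feature encoder is now excluded from gradient updates
+        ```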
+ """ + self.get_encoder().prenet.freeze_feature_encoder() + + @auto_docstring + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, Seq2SeqSpectrogramOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into + a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See + [`SpeechT5Processor.__call__`] for details. + stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Binary tensor indicating the position of the stop token in the sequence. 
+ + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation" + ... ) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc") + >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([77824]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + ) + + _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + + if not return_dict: + output = (spectrogram,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=spectrogram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate_speech( + self, + input_values: torch.FloatTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> torch.FloatTensor: + r""" + Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a + speech waveform using a vocoder. + + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. 
+ + Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`, + a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) + or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. 
+ - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is None: + speaker_embeddings = torch.zeros((1, 512), device=input_values.device) + + return _generate_speech( + self, + input_values, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +@auto_docstring( + custom_intro=""" + HiFi-GAN vocoder. 
+ """ +) +class SpeechT5HifiGan(PreTrainedModel): + config: SpeechT5HifiGanConfig + main_input_name = "spectrogram" + + def __init__(self, config: SpeechT5HifiGanConfig): + super().__init__(config) + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + config.model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + self.register_buffer("mean", torch.zeros(config.model_in_dim)) + self.register_buffer("scale", torch.ones(config.model_in_dim)) + + # Initialize weights and apply final processing + self.post_init() + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) + for layer in self.upsampler: + weight_norm(layer) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + for layer in self.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.conv_post) + + @auto_docstring( + custom_intro=""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + """ + ) + def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor: + r""" + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`. + + Returns: + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. 
+ """ + if self.config.normalize_before: + spectrogram = (spectrogram - self.mean) / self.scale + + is_batched = spectrogram.dim() == 3 + if not is_batched: + spectrogram = spectrogram.unsqueeze(0) + + hidden_states = spectrogram.transpose(2, 1) + + hidden_states = self.conv_pre(hidden_states) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + if not is_batched: + # remove batch dim and collapse tensor to 1-d audio waveform + waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) + else: + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +__all__ = [ + "SpeechT5ForSpeechToText", + "SpeechT5ForSpeechToSpeech", + "SpeechT5ForTextToSpeech", + "SpeechT5Model", + "SpeechT5PreTrainedModel", + "SpeechT5HifiGan", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/number_normalizer.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/number_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3314c24f24c1f8b9bc760c4ece69e0a2819888 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/number_normalizer.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Number Normalizer class for SpeechT5.""" + +import re + + +class EnglishNumberNormalizer: + def __init__(self): + self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] + self.teens = [ + "", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ] + self.tens = ["", "ten", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + self.thousands = [ + "", + "thousand", + "million", + "billion", + "trillion", + "quadrillion", + "quintillion", + "sextillion", + "septillion", + "octillion", + "nonillion", + "decillion", + ] + + # Define a dictionary to map currency symbols to their names + # Top most traded currencies according to + # https://en.wikipedia.org/wiki/Template:Most_traded_currencies + self.currency_symbols = { + "$": " dollars", + "€": " euros", + "£": " pounds", + "¢": " cents", + "¥": " japanese yen", + "﷼": " saudi riyal", + "₹": " indian rupees", + "₽": " russian rubles", + "฿": " thai baht", + "₺": " turkish liras", + "₴": " ukrainian hryvnia", + "₣": " swiss francs", + "₡": " costa rican colon", + "₱": " philippine peso", + "₪": " israeli shekels", + "₮": " mongolian tögrög", + "₩": " south korean won", + "₦": " nigerian naira", + "₫": " vietnamese Đồng", + } + + def spell_number(self, num): + if num == 0: + return "zero" + + parts = [] + for i in range(0, len(self.thousands)): + if num % 1000 != 0: + part = "" + hundreds = num % 1000 // 100 + tens_units = num % 100 + + if hundreds > 0: + part += self.ones[hundreds] + " hundred" + if tens_units > 0: + part += " and " + + if tens_units > 10 and tens_units < 20: + part += self.teens[tens_units - 10] + else: + tens_digit = self.tens[tens_units // 10] + ones_digit = self.ones[tens_units % 10] + if tens_digit: + part += tens_digit + if ones_digit: + if tens_digit: + part += " " + part += ones_digit + + parts.append(part) + + num //= 1000 + + return " ".join(reversed(parts)) + + def convert(self, number): + """ + Converts an individual number passed in string form to spelt-out form + """ + if "." 
in number:
+            integer_part, decimal_part = number.split(".")
+        else:
+            integer_part, decimal_part = number, "00"
+
+        # Extract currency symbol if present
+        currency_symbol = ""
+        for symbol, name in self.currency_symbols.items():
+            if integer_part.startswith(symbol):
+                currency_symbol = name
+                integer_part = integer_part[len(symbol) :]
+                break
+
+            if integer_part.startswith("-"):
+                if integer_part[1:].startswith(symbol):
+                    currency_symbol = name
+                    integer_part = "-" + integer_part[len(symbol) + 1 :]
+                    break
+
+        # Extract 'minus' prefix for negative numbers
+        minus_prefix = ""
+        if integer_part.startswith("-"):
+            minus_prefix = "minus "
+            integer_part = integer_part[1:]
+        elif integer_part.startswith("minus"):
+            minus_prefix = "minus "
+            integer_part = integer_part[len("minus") :]
+
+        percent_suffix = ""
+        if "%" in integer_part or "%" in decimal_part:
+            percent_suffix = " percent"
+            integer_part = integer_part.replace("%", "")
+            decimal_part = decimal_part.replace("%", "")
+
+        integer_part = integer_part.zfill(3 * ((len(integer_part) - 1) // 3 + 1))
+
+        parts = []
+        for i in range(0, len(integer_part), 3):
+            chunk = int(integer_part[i : i + 3])
+            if chunk > 0:
+                part = self.spell_number(chunk)
+                unit = self.thousands[len(integer_part[i:]) // 3 - 1]
+                if unit:
+                    part += " " + unit
+                parts.append(part)
+
+        spelled_integer = " ".join(parts)
+
+        # Format the spelt-out number based on conditions, such as:
+        # If it has decimal parts, currency symbol, minus prefix, etc
+        if decimal_part == "00":
+            return (
+                f"{minus_prefix}{spelled_integer}{percent_suffix}{currency_symbol}"
+                if minus_prefix or currency_symbol
+                else f"{spelled_integer}{percent_suffix}"
+            )
+        else:
+            spelled_decimal = " ".join([self.spell_number(int(digit)) for digit in decimal_part])
+            return (
+                f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}{currency_symbol}"
+                if minus_prefix or currency_symbol
+                else f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}"
+            )
+
+    def __call__(self, text):
+        """
+        Convert numbers / number-like quantities in a string to their spelt-out counterparts
+        """
+        # Form part of the pattern for all currency symbols: one capturing group holding an optional
+        # minus sign, an optional currency symbol, the digits (with optional commas or a decimal point)
+        # and an optional percent sign
+        currency_pattern = "|".join([re.escape(symbol) for symbol in self.currency_symbols.keys()])
+        pattern = rf"(?<![\w\d])(-?(?:{currency_pattern})?[0-9][0-9\,\.]*%?)(?![\w\d])"
+
+        # Find and replace commas in numbers (15,000 -> 15000, etc)
+        text = re.sub(r"(\d+,\d+)", lambda match: match.group(1).replace(",", ""), text)
+
+        # Use regex to find and replace numbers in the text
+        converted_text = re.sub(pattern, lambda match: self.convert(match.group(1)), text)
+        converted_text = re.sub(" +", " ", converted_text)
+
+        return converted_text
diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/processing_speecht5.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/processing_speecht5.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be19f8ae9ed1c1cde5c4579b39344221b970758
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/processing_speecht5.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Speech processor class for SpeechT5.""" + +from ...processing_utils import ProcessorMixin + + +class SpeechT5Processor(ProcessorMixin): + r""" + Constructs a SpeechT5 processor which wraps a feature extractor and a tokenizer into a single processor. + + [`SpeechT5Processor`] offers all the functionalities of [`SpeechT5FeatureExtractor`] and [`SpeechT5Tokenizer`]. See + the docstring of [`~SpeechT5Processor.__call__`] and [`~SpeechT5Processor.decode`] for more information. + + Args: + feature_extractor (`SpeechT5FeatureExtractor`): + An instance of [`SpeechT5FeatureExtractor`]. The feature extractor is a required input. + tokenizer (`SpeechT5Tokenizer`): + An instance of [`SpeechT5Tokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "SpeechT5FeatureExtractor" + tokenizer_class = "SpeechT5Tokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__(self, *args, **kwargs): + """ + Processes audio and text input, as well as audio and text targets. + + You can process audio by using the argument `audio`, or process audio targets by using the argument + `audio_target`. This forwards the arguments to SpeechT5FeatureExtractor's + [`~SpeechT5FeatureExtractor.__call__`]. + + You can process text by using the argument `text`, or process text labels by using the argument `text_target`. + This forwards the arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.__call__`]. + + Valid input combinations are: + + - `text` only + - `audio` only + - `text_target` only + - `audio_target` only + - `text` and `audio_target` + - `audio` and `audio_target` + - `text` and `text_target` + - `audio` and `text_target` + + Please refer to the docstring of the above two methods for more information. + """ + audio = kwargs.pop("audio", None) + text = kwargs.pop("text", None) + text_target = kwargs.pop("text_target", None) + audio_target = kwargs.pop("audio_target", None) + sampling_rate = kwargs.pop("sampling_rate", None) + + if audio is not None and text is not None: + raise ValueError( + "Cannot process both `audio` and `text` inputs. Did you mean `audio_target` or `text_target`?" + ) + if audio_target is not None and text_target is not None: + raise ValueError( + "Cannot process both `audio_target` and `text_target` inputs. Did you mean `audio` or `text`?" + ) + if audio is None and audio_target is None and text is None and text_target is None: + raise ValueError( + "You need to specify either an `audio`, `audio_target`, `text`, or `text_target` input to process." 
+ ) + + if audio is not None: + inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + elif text is not None: + inputs = self.tokenizer(text, **kwargs) + else: + inputs = None + + if audio_target is not None: + targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs) + labels = targets["input_values"] + elif text_target is not None: + targets = self.tokenizer(text_target, **kwargs) + labels = targets["input_ids"] + else: + targets = None + + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + def pad(self, *args, **kwargs): + """ + Collates the audio and text inputs, as well as their targets, into a padded batch. + + Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded + by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`]. + + Valid input combinations are: + + - `input_ids` only + - `input_values` only + - `labels` only, either log-mel spectrograms or text tokens + - `input_ids` and log-mel spectrogram `labels` + - `input_values` and text `labels` + + Please refer to the docstring of the above two methods for more information. + """ + input_values = kwargs.pop("input_values", None) + input_ids = kwargs.pop("input_ids", None) + labels = kwargs.pop("labels", None) + + if input_values is not None and input_ids is not None: + raise ValueError("Cannot process both `input_values` and `input_ids` inputs.") + if input_values is None and input_ids is None and labels is None: + raise ValueError( + "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded." + ) + + if input_values is not None: + inputs = self.feature_extractor.pad(input_values, *args, **kwargs) + elif input_ids is not None: + inputs = self.tokenizer.pad(input_ids, **kwargs) + else: + inputs = None + + if labels is not None: + if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]): + targets = self.tokenizer.pad(labels, **kwargs) + labels = targets["input_ids"] + else: + feature_size_hack = self.feature_extractor.feature_size + self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins + targets = self.feature_extractor.pad(labels, *args, **kwargs) + self.feature_extractor.feature_size = feature_size_hack + labels = targets["input_values"] + else: + targets = None + + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + +__all__ = ["SpeechT5Processor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/speecht5/tokenization_speecht5.py b/venv/lib/python3.13/site-packages/transformers/models/speecht5/tokenization_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..de19db8568e542b6b2506e0652d48c89850c0e40 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/speecht5/tokenization_speecht5.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for SpeechT5."""
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+from .number_normalizer import EnglishNumberNormalizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"}
+
+
+@requires(backends=("sentencepiece",))
+class SpeechT5Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a SpeechT5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The begin of sequence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        normalize (`bool`, *optional*, defaults to `False`):
+            Whether to convert numeric quantities in the text to their spelt-out English counterparts.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
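+
+    Example (a minimal sketch; "microsoft/speecht5_tts" is one public checkpoint that ships this tokenizer):
+
+    ```python
+    >>> from transformers import SpeechT5Tokenizer
+
+    >>> tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")
+    >>> # with normalize=True, number-like spans such as "$15" are spelt out (e.g. "fifteen dollars")
+    >>> input_ids = tokenizer("I owe you $15", normalize=True)["input_ids"]
+    ```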
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        normalize=False,
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.vocab_file = vocab_file
+        self.normalize = normalize
+        self._normalizer = None
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            normalize=normalize,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        normalize = kwargs.pop("normalize", self.normalize)
+        if is_split_into_words:
+            text = " " + text
+        if normalize:
+            text = self.normalizer(text)
+        return (text, kwargs)
+
+    @property
+    def vocab_size(self):
+        return self.sp_model.get_piece_size()
+
+    @property
+    def normalizer(self):
+        if self._normalizer is None:
+            self._normalizer = EnglishNumberNormalizer()
+        return self._normalizer
+
+    @normalizer.setter
+    def normalizer(self, value):
+        self._normalizer = value
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def _tokenize(self, text: str) -> list[str]:
+        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
+        """Build model inputs from a sequence by appending eos_token_id."""
+        if token_ids_1 is None:
+            return token_ids_0 + [self.eos_token_id]
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] =
None, already_has_special_tokens: bool = False + ) -> list[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + suffix_ones = [1] + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + suffix_ones + return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["SpeechT5Tokenizer"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/stablelm/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/stablelm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..004a82f49cd66857a6d10e614da2451046a1dcad --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/stablelm/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_stablelm import * + from .modeling_stablelm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/stablelm/configuration_stablelm.py b/venv/lib/python3.13/site-packages/transformers/models/stablelm/configuration_stablelm.py new file mode 100644 index 0000000000000000000000000000000000000000..b919470c7f3d55c5f61b8f3ebea987e7a678fd53 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/stablelm/configuration_stablelm.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2024 Stability AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""StableLM model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class StableLmConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`~StableLmModel`].
+    It is used to instantiate a StableLM model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+    the StableLM [stabilityai/stablelm-3b-4e1t](https://huggingface.co/stabilityai/stablelm-3b-4e1t) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used
+    to control the model outputs. Read the documentation from [`PretrainedConfig`]
+    for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50304):
+            Vocabulary size of the StableLM model. Defines the number of different tokens that
+            can be represented by the `input_ids` passed when calling [`StableLmModel`].
+        intermediate_size (`int`, *optional*, defaults to 6912):
+            Dimension of the MLP representations.
+        hidden_size (`int`, *optional*, defaults to 2560):
+            Dimension of the hidden representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string).
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing
+            all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions
+            (not used by all models). Only relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to `10000.0`):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings.
+            NOTE: if you apply a new rope type and you expect the model to work on a longer
+            `max_position_embeddings`, we recommend updating this value accordingly. A minimal example
+            dictionary is shown after this argument list.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        use_qkv_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should use bias for qkv layers.
+        qk_layernorm (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize, per head, the Queries and Keys after projecting the hidden states.
+        use_parallel_residual (`bool`, *optional*, defaults to `False`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales.
+        hidden_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after applying the MLP to the hidden states.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+            Percentage of the query and keys which will have rotary embedding.
+        bos_token_id (int, *optional*, defaults to 0):
+            The id of the `BOS` token in the vocabulary.
+        eos_token_id (int, *optional*, defaults to 0):
+            The id of the `EOS` token in the vocabulary.
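+
+    For instance, a `rope_scaling` dictionary for the 'linear' rope type might look like the
+    following (a minimal sketch with illustrative values):
+
+    ```python
+    >>> rope_scaling = {"rope_type": "linear", "factor": 2.0}
+    ```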
+ + Example: + + ```python + >>> from transformers import StableLmModel, StableLmConfig + + >>> # Initializing a StableLM stablelm-3b style configuration + >>> configuration = StableLmConfig() + ```""" + + model_type = "stablelm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + intermediate_size=6912, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + layer_norm_eps=1.0e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10_000, + rope_scaling=None, + use_qkv_bias=False, + qk_layernorm=False, + use_parallel_residual=False, + hidden_dropout=0.0, + attention_dropout=0.0, + partial_rotary_factor=0.25, + bos_token_id=0, + eos_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.use_qkv_bias = use_qkv_bias + self.qk_layernorm = qk_layernorm + self.use_parallel_residual = use_parallel_residual + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.partial_rotary_factor = partial_rotary_factor + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["StableLmConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/stablelm/modeling_stablelm.py b/venv/lib/python3.13/site-packages/transformers/models/stablelm/modeling_stablelm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b31565a1b1d9ad7e37b970c7a9eeb95dcaac52c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/stablelm/modeling_stablelm.py @@ -0,0 +1,996 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
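+# A worked example of the partial-rotary sizing used throughout this file (values taken
+# from the default StableLmConfig): with hidden_size=2560 and num_attention_heads=32,
+# each head has head_dim = 2560 // 32 = 80, and with partial_rotary_factor=0.25 only the
+# first int(80 * 0.25) = 20 dimensions of each query/key head receive rotary embeddings;
+# the remaining 60 pass through unchanged (see `rotary_ndims` in `StableLmAttention`).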
+"""PyTorch StableLM model.""" + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_stablelm import StableLmConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm +class StableLmRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: StableLmConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. 
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm +class StableLmMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class StableLmLayerNormPerHead(nn.Module): + def __init__(self, dim, num_heads, eps=1e-5, bias=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)]) + + def forward(self, hidden_states: torch.Tensor): + # Split along the num_heads axis to get per-head inputs + # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads + states_per_heads = torch.split(hidden_states, 1, dim=1) + # Normalize and merge the heads back together + return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class StableLmAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rope_theta = config.rope_theta + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps) + self.k_layernorm = StableLmLayerNormPerHead( + self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps + ) + + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.rotary_emb = StableLmRotaryEmbedding(config=self.config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = 
self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_values is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class StableLmSdpaAttention(StableLmAttention): + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_values is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
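+        # Concretely: is_causal=True only during prefill (q_len > 1) with no explicit mask,
+        # letting SDPA build its own causal mask; during single-token decoding (q_len == 1)
+        # the query may attend to every cached position, so no causal masking is applied.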
+ is_causal = bool(causal_mask is None and q_len > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +class StableLmFlashAttention2(StableLmAttention): + """ + StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # StableLmFlashAttention2 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_values is not None: + cache_kwargs 
= { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout.p if self.training else 0.0 + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +ATTENTION_CLASSES = { + "eager": StableLmAttention, + "sdpa": StableLmSdpaAttention, + "flash_attention_2": StableLmFlashAttention2, +} + + +class StableLmDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: StableLmConfig, layer_idx: int): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.hidden_size = config.hidden_size + self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.mlp = StableLmMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = None + if not self.use_parallel_residual: + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + self_attn_output, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward + if self.use_parallel_residual: + # x = x + attn(ln1(x)) + mlp(ln1(x)) + # Fully Connected + mlp_output = self.mlp(hidden_states) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + self_attn_output + mlp_output + else: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + residual = residual + self_attn_output + # Fully Connected + mlp_output = self.mlp(self.post_attention_layernorm(residual)) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + mlp_output + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class StableLmPreTrainedModel(PreTrainedModel): + config: StableLmConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["StableLmDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +@auto_docstring +class StableLmModel(StableLmPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`StableLmDecoderLayer`] + + Args: + config: StableLmConfig + """ + + def __init__(self, config: StableLmConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.rotary_emb = StableLmRotaryEmbedding(config=config) + + self._attn_implementation = config._attn_implementation + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
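+        # Put differently: returning `None` from this method tells SDPA to handle causality
+        # itself (the fast path), while returning a 4D float mask forces the explicit-mask path.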
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm +class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm + def __init__(self, config): + super().__init__(config) + self.model = StableLmModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, StableLmForCausalLM + + >>> model = StableLmForCausalLM.from_pretrained("adept/persimmon-8b-base") + >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base") + + >>> prompt = "human: Hey, what should I eat for dinner?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs.last_hidden_state + # No upscaling to float was ever done for StableLm + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class StableLmForSequenceClassification(GenericForSequenceClassification, StableLmPreTrainedModel): ... + + +class StableLmForTokenClassification(GenericForTokenClassification, StableLmPreTrainedModel): ... + + +__all__ = [ + "StableLmForCausalLM", + "StableLmModel", + "StableLmPreTrainedModel", + "StableLmForSequenceClassification", + "StableLmForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/superpoint/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/superpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccec260fa5e7d33818153d651d833d43226224e7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/superpoint/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
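(Side note on the StableLm hunk above, before the SuperPoint files: the `_prepare_4d_causal_attention_mask_with_cache_position` helper expands a 2D padding mask into the additive 4D mask consumed by the attention layers. A minimal sketch of calling it directly, assuming the `transformers` build from this diff is importable; the toy shapes and the printed pattern are illustrative only, not part of the library's documented API.)

```python
import torch
from transformers import StableLmModel  # exported via __all__ in the hunk above

# Build the (batch, 1, query_len, kv_len) additive mask for a 4-token prefill
# with no padding and an empty cache, mirroring what the model forward does internally.
mask = StableLmModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 4, dtype=torch.long),  # 2D padding mask, all tokens visible
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask.shape)               # torch.Size([1, 1, 4, 4])
print((mask == 0).int()[0, 0])  # lower-triangular 1s: each query attends to itself and earlier keys
```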
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_superpoint import * + from .image_processing_superpoint import * + from .image_processing_superpoint_fast import * + from .modeling_superpoint import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/superpoint/configuration_superpoint.py b/venv/lib/python3.13/site-packages/transformers/models/superpoint/configuration_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4271960726fd34235e8377fd42d438bda1d35c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/superpoint/configuration_superpoint.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SuperPointConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a + SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SuperPoint + [magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_hidden_sizes (`List`, *optional*, defaults to `[64, 64, 128, 128]`): + The number of channels in each convolutional layer in the encoder. + decoder_hidden_size (`int`, *optional*, defaults to 256): The hidden size of the decoder. + keypoint_decoder_dim (`int`, *optional*, defaults to 65): The output dimension of the keypoint decoder. + descriptor_decoder_dim (`int`, *optional*, defaults to 256): The output dimension of the descriptor decoder. + keypoint_threshold (`float`, *optional*, defaults to 0.005): + The threshold to use for extracting keypoints. + max_keypoints (`int`, *optional*, defaults to -1): + The maximum number of keypoints to extract. If `-1`, will extract all keypoints. + nms_radius (`int`, *optional*, defaults to 4): + The radius for non-maximum suppression. + border_removal_distance (`int`, *optional*, defaults to 4): + The distance from the border to remove keypoints. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + ```python + >>> from transformers import SuperPointConfig, SuperPointForKeypointDetection + + >>> # Initializing a SuperPoint superpoint style configuration + >>> configuration = SuperPointConfig() + >>> # Initializing a model from the superpoint style configuration + >>> model = SuperPointForKeypointDetection(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "superpoint" + + def __init__( + self, + encoder_hidden_sizes: list[int] = [64, 64, 128, 128], + decoder_hidden_size: int = 256, + keypoint_decoder_dim: int = 65, + descriptor_decoder_dim: int = 256, + keypoint_threshold: float = 0.005, + max_keypoints: int = -1, + nms_radius: int = 4, + border_removal_distance: int = 4, + initializer_range=0.02, + **kwargs, + ): + self.encoder_hidden_sizes = encoder_hidden_sizes + self.decoder_hidden_size = decoder_hidden_size + self.keypoint_decoder_dim = keypoint_decoder_dim + self.descriptor_decoder_dim = descriptor_decoder_dim + self.keypoint_threshold = keypoint_threshold + self.max_keypoints = max_keypoints + self.nms_radius = nms_radius + self.border_removal_distance = border_removal_distance + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +__all__ = ["SuperPointConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint.py b/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..4c895b035feb41dd419e3364be6cc9b15c185497 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint.py @@ -0,0 +1,352 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SuperPoint.""" + +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np + +from ... import is_torch_available, is_vision_available +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, logging, requires_backends + + +if is_torch_available(): + import torch + +if TYPE_CHECKING: + from .modeling_superpoint import SuperPointKeypointDescriptionOutput + +if is_vision_available(): + import PIL + +logger = logging.get_logger(__name__) + + +def is_grayscale( + image: np.ndarray, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True + return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] 
== image[2, ...]) + elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True + return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) + + +def convert_to_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch + and tensorflow grayscale conversion + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (Image): + The image to convert. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + requires_backends(convert_to_grayscale, ["vision"]) + + if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image + if input_data_format == ChannelDimension.FIRST: + gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=0) + elif input_data_format == ChannelDimension.LAST: + gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=-1) + return gray_image + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("L") + return image + + +class SuperPointImageProcessor(BaseImageProcessor): + r""" + Constructs a SuperPoint image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden + by `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`): + Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to + `True`. Can be overridden by `size` in the `preprocess` method. + resample (`Resampling`, *optional*, defaults to `2`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_grayscale (`bool`, *optional*, defaults to `False`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_grayscale: + images = [convert_to_grayscale(image, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_detection( + self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, list[tuple]] + ) -> list[dict[str, "torch.Tensor"]]: + """ + Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + + Args: + outputs ([`SuperPointKeypointDescriptionOutput`]): + Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors. + target_sizes (`torch.Tensor` or `list[tuple[int, int]]`): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. This must be the original + image size (before any processing). + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according + to target_sizes, scores and descriptors for an image in the batch as predicted by the model. 
+ """ + if len(outputs.mask) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + + if isinstance(target_sizes, list): + image_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_sizes = target_sizes + + # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates + image_sizes = torch.flip(image_sizes, [1]) + masked_keypoints = outputs.keypoints * image_sizes[:, None] + + # Convert masked_keypoints to int + masked_keypoints = masked_keypoints.to(torch.int32) + + results = [] + for image_mask, keypoints, scores, descriptors in zip( + outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors + ): + indices = torch.nonzero(image_mask).squeeze(1) + keypoints = keypoints[indices] + scores = scores[indices] + descriptors = descriptors[indices] + results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors}) + + return results + + +__all__ = ["SuperPointImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint_fast.py b/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..54f95fa75af617456c1bb2375fff65c741cad025 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/superpoint/image_processing_superpoint_fast.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Superpoint.""" + +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, +) + + +if TYPE_CHECKING: + from .modeling_superpoint import SuperPointKeypointDescriptionOutput + +import torchvision.transforms.v2.functional as F + + +def is_grayscale( + image: "torch.Tensor", +): + """Checks if an image is grayscale (all RGB channels are identical).""" + if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1: + return True + return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all( + image[..., 1, :, :] == image[..., 2, :, :] + ) + + +class SuperPointFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. 
+ """ + + do_grayscale: Optional[bool] = True + + +def convert_to_grayscale( + image: "torch.Tensor", +) -> "torch.Tensor": + """ + Converts an image to grayscale format using the NTSC formula. Only support torch.Tensor. + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (torch.Tensor): + The image to convert. + """ + if is_grayscale(image): + return image + return F.rgb_to_grayscale(image, num_output_channels=3) + + +@auto_docstring +class SuperPointImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + size = {"height": 480, "width": 640} + default_to_square = False + do_resize = True + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = None + valid_kwargs = SuperPointFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[SuperPointFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + size: Union[dict[str, int], SizeDict], + rescale_factor: float, + do_rescale: bool, + do_resize: bool, + interpolation: Optional["F.InterpolationMode"], + do_grayscale: bool, + disable_grouping: bool, + return_tensors: Union[str, TensorType], + **kwargs, + ) -> BatchFeature: + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_grayscale: + stacked_images = convert_to_grayscale(stacked_images) + if do_resize: + stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation) + if do_rescale: + stacked_images = self.rescale(stacked_images, rescale_factor) + processed_images_grouped[shape] = stacked_images + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature(data={"pixel_values": processed_images}) + + def post_process_keypoint_detection( + self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, list[tuple]] + ) -> list[dict[str, "torch.Tensor"]]: + """ + Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + + Args: + outputs ([`SuperPointKeypointDescriptionOutput`]): + Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. This must be the original + image size (before any processing). + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according + to target_sizes, scores and descriptors for an image in the batch as predicted by the model. 
+ """ + if len(outputs.mask) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + + if isinstance(target_sizes, list): + image_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_sizes = target_sizes + + # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates + image_sizes = torch.flip(image_sizes, [1]) + masked_keypoints = outputs.keypoints * image_sizes[:, None] + + # Convert masked_keypoints to int + masked_keypoints = masked_keypoints.to(torch.int32) + + results = [] + for image_mask, keypoints, scores, descriptors in zip( + outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors + ): + indices = torch.nonzero(image_mask).squeeze(1) + keypoints = keypoints[indices] + scores = scores[indices] + descriptors = descriptors[indices] + results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors}) + + return results + + +__all__ = ["SuperPointImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/superpoint/modeling_superpoint.py b/venv/lib/python3.13/site-packages/transformers/models/superpoint/modeling_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..efd3113eb3b0f8518734a36ff451f347605737b2 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/superpoint/modeling_superpoint.py @@ -0,0 +1,477 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SuperPoint model.""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from transformers import PreTrainedModel +from transformers.modeling_outputs import ( + BaseModelOutputWithNoAttention, +) +from transformers.models.superpoint.configuration_superpoint import SuperPointConfig + +from ...utils import ( + ModelOutput, + auto_docstring, + logging, +) + + +logger = logging.get_logger(__name__) + + +def remove_keypoints_from_borders( + keypoints: torch.Tensor, scores: torch.Tensor, border: int, height: int, width: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Removes keypoints (and their associated scores) that are too close to the border""" + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints: torch.Tensor, scores: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + """Keeps the k keypoints with highest score""" + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor: + """Applies non-maximum suppression on scores""" + if nms_radius < 0: + raise ValueError("Expected positive values for nms_radius") + + def max_pool(x): + return nn.functional.max_pool2d(x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of + keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, + the maximum number of keypoints is set as the dimension of the keypoints, scores and descriptors tensors. The mask + tensor is used to indicate which values in the keypoints, scores and descriptors tensors are keypoint information + and which are padding. + """ +) +class SuperPointKeypointDescriptionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Relative (x, y) coordinates of predicted keypoints in a given image. + scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`): + Scores of predicted keypoints. + descriptors (`torch.FloatTensor` of shape `(batch_size, num_keypoints, descriptor_size)`): + Descriptors of predicted keypoints. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in keypoints, scores and descriptors are keypoint information. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. 
Hidden-states + (also called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.IntTensor] = None + scores: Optional[torch.FloatTensor] = None + descriptors: Optional[torch.FloatTensor] = None + mask: Optional[torch.BoolTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +class SuperPointConvBlock(nn.Module): + def __init__( + self, config: SuperPointConfig, in_channels: int, out_channels: int, add_pooling: bool = False + ) -> None: + super().__init__() + self.conv_a = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_b = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if add_pooling else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.relu(self.conv_a(hidden_states)) + hidden_states = self.relu(self.conv_b(hidden_states)) + if self.pool is not None: + hidden_states = self.pool(hidden_states) + return hidden_states + + +class SuperPointEncoder(nn.Module): + """ + SuperPoint encoder module. It is made of 4 convolutional layers with ReLU activation and max pooling, reducing the + dimensionality of the image. + """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + # SuperPoint uses 1 channel images + self.input_dim = 1 + + conv_blocks = [] + conv_blocks.append( + SuperPointConvBlock(config, self.input_dim, config.encoder_hidden_sizes[0], add_pooling=True) + ) + for i in range(1, len(config.encoder_hidden_sizes) - 1): + conv_blocks.append( + SuperPointConvBlock( + config, config.encoder_hidden_sizes[i - 1], config.encoder_hidden_sizes[i], add_pooling=True + ) + ) + conv_blocks.append( + SuperPointConvBlock( + config, config.encoder_hidden_sizes[-2], config.encoder_hidden_sizes[-1], add_pooling=False + ) + ) + self.conv_blocks = nn.ModuleList(conv_blocks) + + def forward( + self, + input, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, BaseModelOutputWithNoAttention]: + all_hidden_states = () if output_hidden_states else None + + for conv_block in self.conv_blocks: + input = conv_block(input) + if output_hidden_states: + all_hidden_states = all_hidden_states + (input,) + output = input + if not return_dict: + return tuple(v for v in [output, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=output, + hidden_states=all_hidden_states, + ) + + +class SuperPointInterestPointDecoder(nn.Module): + """ + The SuperPointInterestPointDecoder uses the output of the SuperPointEncoder to compute the keypoint with scores. + The scores are first computed by a convolutional layer, then a softmax is applied to get a probability distribution + over the 65 possible keypoint classes. The keypoints are then extracted from the scores by thresholding and + non-maximum suppression. Post-processing is then applied to remove keypoints too close to the image borders as well + as to keep only the k keypoints with highest score. 
+ """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + self.keypoint_threshold = config.keypoint_threshold + self.max_keypoints = config.max_keypoints + self.nms_radius = config.nms_radius + self.border_removal_distance = config.border_removal_distance + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv_score_a = nn.Conv2d( + config.encoder_hidden_sizes[-1], + config.decoder_hidden_size, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_score_b = nn.Conv2d( + config.decoder_hidden_size, config.keypoint_decoder_dim, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, encoded: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + scores = self._get_pixel_scores(encoded) + keypoints, scores = self._extract_keypoints(scores) + + return keypoints, scores + + def _get_pixel_scores(self, encoded: torch.Tensor) -> torch.Tensor: + """Based on the encoder output, compute the scores for each pixel of the image""" + scores = self.relu(self.conv_score_a(encoded)) + scores = self.conv_score_b(scores) + scores = nn.functional.softmax(scores, 1)[:, :-1] + batch_size, _, height, width = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(batch_size, height, width, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(batch_size, height * 8, width * 8) + scores = simple_nms(scores, self.nms_radius) + return scores + + def _extract_keypoints(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation. + The keypoints are in the form of relative (x, y) coordinates. + """ + _, height, width = scores.shape + + # Threshold keypoints by score value + keypoints = torch.nonzero(scores[0] > self.keypoint_threshold) + scores = scores[0][tuple(keypoints.t())] + + # Discard keypoints near the image borders + keypoints, scores = remove_keypoints_from_borders( + keypoints, scores, self.border_removal_distance, height * 8, width * 8 + ) + + # Keep the k keypoints with highest score + if self.max_keypoints >= 0: + keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints) + + # Convert (y, x) to (x, y) + keypoints = torch.flip(keypoints, [1]).to(scores.dtype) + + return keypoints, scores + + +class SuperPointDescriptorDecoder(nn.Module): + """ + The SuperPointDescriptorDecoder uses the outputs of both the SuperPointEncoder and the + SuperPointInterestPointDecoder to compute the descriptors at the keypoints locations. + + The descriptors are first computed by a convolutional layer, then normalized to have a norm of 1. The descriptors + are then interpolated at the keypoints locations. 
+ """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv_descriptor_a = nn.Conv2d( + config.encoder_hidden_sizes[-1], + config.decoder_hidden_size, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_descriptor_b = nn.Conv2d( + config.decoder_hidden_size, + config.descriptor_decoder_dim, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, encoded: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor: + """Based on the encoder output and the keypoints, compute the descriptors for each keypoint""" + descriptors = self.conv_descriptor_b(self.relu(self.conv_descriptor_a(encoded))) + descriptors = nn.functional.normalize(descriptors, p=2, dim=1) + + descriptors = self._sample_descriptors(keypoints[None], descriptors[0][None], 8)[0] + + # [descriptor_dim, num_keypoints] -> [num_keypoints, descriptor_dim] + descriptors = torch.transpose(descriptors, 0, 1) + + return descriptors + + @staticmethod + def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor: + """Interpolate descriptors at keypoint locations""" + batch_size, num_channels, height, width = descriptors.shape + keypoints = keypoints - scale / 2 + 0.5 + divisor = torch.tensor([[(width * scale - scale / 2 - 0.5), (height * scale - scale / 2 - 0.5)]]) + divisor = divisor.to(keypoints) + keypoints /= divisor + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + kwargs = {"align_corners": True} + # [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2] + keypoints = keypoints.view(batch_size, 1, -1, 2) + descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs) + # [batch_size, descriptor_decoder_dim, num_channels, num_keypoints] -> [batch_size, descriptor_decoder_dim, num_keypoints] + descriptors = descriptors.reshape(batch_size, num_channels, -1) + descriptors = nn.functional.normalize(descriptors, p=2, dim=1) + return descriptors + + +@auto_docstring +class SuperPointPreTrainedModel(PreTrainedModel): + config: SuperPointConfig + base_model_prefix = "superpoint" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: + """ + Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same, + extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for SuperPoint. 
This is + a workaround for the issue discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width) + + Returns: + pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width) + + """ + return pixel_values[:, 0, :, :][:, None, :, :] + + +@auto_docstring( + custom_intro=""" + SuperPoint model outputting keypoints and descriptors. + """ +) +class SuperPointForKeypointDetection(SuperPointPreTrainedModel): + """ + SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a + SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and + Description `__ by Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. It + is a fully convolutional neural network that extracts keypoints and descriptors from an image. It is trained in a + self-supervised manner, using a combination of a photometric loss and a loss based on the homographic adaptation of + keypoints. It is made of a convolutional encoder and two decoders: one for keypoints and one for descriptors. + """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__(config) + + self.config = config + + self.encoder = SuperPointEncoder(config) + self.keypoint_decoder = SuperPointInterestPointDecoder(config) + self.descriptor_decoder = SuperPointDescriptorDecoder(config) + + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SuperPointKeypointDescriptionOutput]: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SuperPointForKeypointDetection + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint") + >>> model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint") + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + loss = None + if labels is not None: + raise ValueError("SuperPoint does not support training for now.") + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + pixel_values = self.extract_one_channel_pixel_values(pixel_values) + + batch_size, _, height, width = pixel_values.shape + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + + list_keypoints_scores = [ + self.keypoint_decoder(last_hidden_state[None, ...]) for last_hidden_state in last_hidden_state + ] + + list_keypoints = [keypoints_scores[0] for keypoints_scores in list_keypoints_scores] + list_scores = [keypoints_scores[1] for keypoints_scores in list_keypoints_scores] + + list_descriptors = [ + self.descriptor_decoder(last_hidden_state[None, ...], keypoints[None, ...]) + for last_hidden_state, keypoints in zip(last_hidden_state, list_keypoints) + ] + + maximum_num_keypoints = max(keypoints.shape[0] for keypoints in list_keypoints) + + 
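+        # Pad each image's keypoints, scores and descriptors up to maximum_num_keypoints; mask marks real entries (1) vs. zero padding (0).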
keypoints = torch.zeros((batch_size, maximum_num_keypoints, 2), device=pixel_values.device) + scores = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device) + descriptors = torch.zeros( + (batch_size, maximum_num_keypoints, self.config.descriptor_decoder_dim), + device=pixel_values.device, + ) + mask = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device, dtype=torch.int) + + for i, (_keypoints, _scores, _descriptors) in enumerate(zip(list_keypoints, list_scores, list_descriptors)): + keypoints[i, : _keypoints.shape[0]] = _keypoints + scores[i, : _scores.shape[0]] = _scores + descriptors[i, : _descriptors.shape[0]] = _descriptors + mask[i, : _scores.shape[0]] = 1 + + # Convert to relative coordinates + keypoints = keypoints / torch.tensor([width, height], device=keypoints.device) + + hidden_states = encoder_outputs[1] if output_hidden_states else None + if not return_dict: + return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None) + + return SuperPointKeypointDescriptionOutput( + loss=loss, + keypoints=keypoints, + scores=scores, + descriptors=descriptors, + mask=mask, + hidden_states=hidden_states, + ) + + +__all__ = ["SuperPointForKeypointDetection", "SuperPointPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/swin/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/swin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc5871b0375e8a5553e30dce1e5dde807a1b493 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/swin/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swin import * + from .modeling_swin import * + from .modeling_tf_swin import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/swin/configuration_swin.py b/venv/lib/python3.13/site-packages/transformers/models/swin/configuration_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..41462d51edbedbd098fcf7a97eec6e412516cd62 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/swin/configuration_swin.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swin Transformer model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class SwinConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Swin + [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. 
+ encoder_stride (`int`, *optional*, defaults to 32): + Factor to increase the spatial resolution by in the decoder head for masked image modeling. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + + ```python + >>> from transformers import SwinConfig, SwinModel + + >>> # Initializing a Swin microsoft/swin-tiny-patch4-window7-224 style configuration + >>> configuration = SwinConfig() + + >>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration + >>> model = SwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_stride=32, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.encoder_stride = encoder_stride + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class SwinOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = 
["SwinConfig", "SwinOnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_swin.py b/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..18b61abbd3a4212f69ef3f777493c17798a722f9 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_swin.py @@ -0,0 +1,1278 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Swin Transformer model.""" + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging, torch_int +from ...utils.backbone_utils import BackboneMixin +from .configuration_swin import SwinConfig + + +logger = logging.get_logger(__name__) + + +# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. + + +@dataclass +@auto_docstring( + custom_intro=""" + Swin encoder's outputs, with potential hidden states and attentions. + """ +) +class SwinEncoderOutput(ModelOutput): + r""" + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Swin model's outputs that also contains a pooling of the last hidden states. + """ +) +class SwinModelOutput(ModelOutput): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. 
+ + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Swin masked image model outputs. + """ +) +class SwinMaskedImageModelingOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + + +@dataclass +@auto_docstring( + custom_intro=""" + Swin outputs for image classification. + """ +) +class SwinImageClassifierOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. 
+ """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +class SwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = SwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class SwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +class SwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
+ """ + + def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be divisible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin +class SwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class SwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + hidden_shape = (batch_size, dim, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class SwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = SwinSelfAttention(config, dim, num_heads, window_size) + self.output = SwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, 
output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = SwinIntermediate(config, dim) + self.output = SwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + 
input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class SwinStage(GradientCheckpointingLayer): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + SwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, 
input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.layers = nn.ModuleList( + [ + SwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=SwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, SwinEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size 
= hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return SwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +@auto_docstring +class SwinPreTrainedModel(PreTrainedModel): + config: SwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["SwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, SwinSelfAttention): + module.relative_position_bias_table.data.zero_() + + +@auto_docstring +class SwinModel(SwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + r""" + add_pooling_layer (`bool`, *optional*, defaults to `True`): + Whether or not to apply pooling layer. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether or not to create and apply mask tokens in the embedding layer. + """ + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = SwinEncoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. 
Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return SwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886). + + + + Note that we provide a script to pre-train this model on custom data in our [examples + directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining). + + + """ +) +class SwinForMaskedImageModeling(SwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True) + + num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1 + ), + nn.PixelShuffle(config.encoder_stride), + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SwinMaskedImageModelingOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ + Examples: + ```python + >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192") + >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 192, 192] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output.transpose(1, 2) + batch_size, num_channels, sequence_length = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return SwinMaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
+ + + """ +) +class SwinForImageClassification(SwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.swin = SwinModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SwinImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Swin backbone, to be used with frameworks like DETR and MaskFormer. 
+ """ +) +class SwinBackbone(SwinPreTrainedModel, BackboneMixin): + def __init__(self, config: SwinConfig): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, self.embeddings.patch_grid) + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self._out_features, self.channels): + hidden_states_norms[stage] = nn.LayerNorm(num_channels) + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=None, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + always_partition=True, + return_dict=True, + ) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + batch_size, num_channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state = hidden_state.view(batch_size, height * width, num_channels) + hidden_state = self.hidden_states_norms[stage](hidden_state) + hidden_state = hidden_state.view(batch_size, height, width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SwinForImageClassification", + "SwinForMaskedImageModeling", + "SwinModel", + "SwinPreTrainedModel", + "SwinBackbone", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_tf_swin.py 
b/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_tf_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..7fa54e958046cc2c8b9e40c69959775013755d68 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/swin/modeling_tf_swin.py @@ -0,0 +1,1639 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 Swin Transformer model.""" + +from __future__ import annotations + +import collections.abc +import math +import warnings +from collections.abc import Iterable +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_swin import SwinConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwinConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow +# implementations of PyTorch functionalities in the timm library. + + +@dataclass +class TFSwinEncoderOutput(ModelOutput): + """ + Swin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFSwinModelOutput(ModelOutput): + """ + Swin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFSwinMaskedImageModelingOutput(ModelOutput): + """ + Swin masked image model outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: tf.Tensor | None = None + reconstruction: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + + +@dataclass +class TFSwinImageClassifierOutput(ModelOutput): + """ + Swin outputs for image classification. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + attentions: tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None + + +def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: + """ + Partitions the given input into windows. 
+ """ + batch_size, height, width, num_channels = shape_list(input_feature) + input_feature = tf.reshape( + input_feature, + (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels), + ) + windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5)) + windows = tf.reshape(windows, (-1, window_size, window_size, num_channels)) + return windows + + +def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor: + """ + Merges windows to produce higher resolution features. + """ + x = tf.shape(windows)[0] + y = tf.cast(height * width / (window_size * window_size), tf.int32) + batch_size = tf.math.floordiv(x, y) + windows = tf.reshape( + windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1) + ) + windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5)) + windows = tf.reshape(windows, (batch_size, height, width, -1)) + return windows + + +def drop_path( + input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +) -> tf.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + input_shape = shape_list(input) + ndim = len(input_shape) + shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = tf.random.uniform(shape) + random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0) + if keep_prob > 0.0 and scale_by_keep: + random_tensor /= keep_prob + return input * random_tensor + + +class TFSwinEmbeddings(keras.layers.Layer): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None: + super().__init__(**kwargs) + self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings") + self.num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.embed_dim = config.embed_dim + self.use_mask_token = use_mask_token + self.use_absolute_embeddings = config.use_absolute_embeddings + + self.norm = keras.layers.LayerNormalization(name="norm", epsilon=1e-5) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def build(self, input_shape: tf.TensorShape) -> None: + if self.use_mask_token: + self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token") + else: + self.mask_token = None + + if self.use_absolute_embeddings: + self.position_embeddings = self.add_weight( + (1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings" + ) + else: + self.position_embeddings = None + + if self.built: + return + self.built = True + if getattr(self, "patch_embeddings", None) is not None: + with tf.name_scope(self.patch_embeddings.name): + self.patch_embeddings.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build([None, None, self.config.embed_dim]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + def call( + self, pixel_values: tf.Tensor, bool_masked_pos: bool | None = None, training: bool = False + ) -> tuple[tf.Tensor, tuple[int, int]]: + embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training) + embeddings = 
self.norm(embeddings, training=training) + batch_size, seq_len, _ = shape_list(embeddings) + + if bool_masked_pos is not None: + mask_tokens = tf.repeat(self.mask_token, batch_size, 0) + mask_tokens = tf.repeat(mask_tokens, seq_len, 1) + # replace the masked visual tokens by mask_tokens + mask = tf.expand_dims(bool_masked_pos, -1) + mask = tf.cast(mask, mask_tokens.dtype) + + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings, training=training) + + return embeddings, output_dimensions + + +class TFSwinPatchEmbeddings(keras.layers.Layer): + """ + Image to Patch Embedding. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = keras.layers.Conv2D( + filters=hidden_size, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + name="projection", + ) + + def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor: + if width % self.patch_size[1] != 0: + pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1])) + pixel_values = tf.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0)) + pixel_values = tf.pad(pixel_values, pad_values) + return pixel_values + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tuple[tf.Tensor, tuple[int, int]]: + _, num_channels, height, width = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + + # B,C,H,W -> B,H,W,C + pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1)) + + embeddings = self.projection(pixel_values, training=training) + + # B,H,W,C -> B,C,H,W + embeddings = tf.transpose(embeddings, (0, 3, 1, 2)) + + batch_size, channels, height, width = shape_list(embeddings) + output_dimensions = (height, width) + + embeddings = tf.reshape(embeddings, (batch_size, channels, -1)) + embeddings = tf.transpose(embeddings, (0, 2, 1)) + return embeddings, output_dimensions + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFSwinPatchMerging(keras.layers.Layer): + """ + Patch Merging Layer. + + Args: + input_resolution (`tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. 
+ norm_layer (`keras.layer.Layer`, *optional*, defaults to `keras.layers.LayerNormalization`): + Normalization layer class. + """ + + def __init__( + self, input_resolution: tuple[int, int], dim: int, norm_layer: Callable | None = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.input_resolution = input_resolution + self.dim = dim + self.reduction = keras.layers.Dense(2 * dim, use_bias=False, name="reduction") + if norm_layer is None: + # Use same default epsilon as PyTorch + self.norm = keras.layers.LayerNormalization(epsilon=1e-5, name="norm") + else: + self.norm = norm_layer(name="norm") + + def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor: + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0)) + input_feature = tf.pad(input_feature, pad_values) + + return input_feature + + def call(self, input_feature: tf.Tensor, input_dimensions: tuple[int, int], training: bool = False) -> tf.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, _, num_channels = shape_list(input_feature) + + input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels)) + # pad input to be divisible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = tf.reshape( + input_feature, (batch_size, -1, 4 * num_channels) + ) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature, training=training) + input_feature = self.reduction(input_feature, training=training) + + return input_feature + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "reduction", None) is not None: + with tf.name_scope(self.reduction.name): + self.reduction.build([None, None, 4 * self.dim]) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build([None, None, 4 * self.dim]) + + +class TFSwinDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float | None = None, scale_by_keep: bool = True, **kwargs) -> None: + super().__init__(**kwargs) + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor: + return drop_path(input, self.drop_prob, training, self.scale_by_keep) + + +class TFSwinSelfAttention(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None: + super().__init__(**kwargs) + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + window_size = 
config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.query = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="query", + ) + self.key = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="key", + ) + self.value = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="value", + ) + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + + def build(self, input_shape: tf.TensorShape) -> None: + self.relative_position_bias_table = self.add_weight( + shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads), + initializer="zeros", + name="relative_position_bias_table", + ) + self.relative_position_index = self.add_weight( + shape=(self.window_size[0] ** 2, self.window_size[1] ** 2), + trainable=False, + dtype=tf.int32, + name="relative_position_index", + ) + + # get pair-wise relative position index for each token inside the window + coords_h = tf.range(self.window_size[0]) + coords_w = tf.range(self.window_size[1]) + coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij")) + coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1)) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = tf.transpose(relative_coords, (1, 2, 0)) + + stack_0, stack_1 = tf.unstack(relative_coords, axis=2) + stack_0 += self.window_size[0] - 1 + stack_0 *= 2 * self.window_size[1] - 1 + stack_1 += self.window_size[1] - 1 + relative_coords = tf.stack([stack_0, stack_1], axis=2) + + self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32)) + + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.all_head_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.all_head_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.all_head_size]) + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tuple[tf.Tensor, ...]: + batch_size, dim, _ = shape_list(hidden_states) + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2))) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + relative_position_bias = tf.gather( + self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,)) + ) + relative_position_bias = tf.reshape( + relative_position_bias, + (self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1), + ) + + relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1)) + attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SwinModel call() function) + mask_shape = shape_list(attention_mask)[0] + attention_scores = tf.reshape( + attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim) + ) + attention_mask = tf.expand_dims(attention_mask, 1) + attention_mask = tf.expand_dims(attention_mask, 0) + attention_scores = attention_scores + attention_mask + attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim)) + + # Normalize the attention scores to probabilities. + attention_probs = tf.nn.softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [ + self.all_head_size, + ] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TFSwinSelfOutput(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(dim, name="dense") + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout") + self.dim = dim + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.dim]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +class TFSwinAttention(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None: + super().__init__(**kwargs) + self.self = TFSwinSelfAttention(config, dim, num_heads, name="self") + self.self_output = TFSwinSelfOutput(config, dim, name="output") + self.pruned_heads = set() + + def prune_heads(self, heads): + """ + Prunes heads of the model. 
See base class PreTrainedModel heads: dict of {layer_num: list of heads to prune in + this layer} + """ + raise NotImplementedError + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tf.Tensor: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training) + attention_output = self.self_output(self_outputs[0], hidden_states, training=training) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "self_output", None) is not None: + with tf.name_scope(self.self_output.name): + self.self_output.build(None) + + +class TFSwinIntermediate(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(int(config.mlp_ratio * dim), name="dense") + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dim = dim + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.dim]) + + +class TFSwinOutput(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(dim, name="dense") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, "dropout") + self.config = config + self.dim = dim + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, int(self.config.mlp_ratio * self.dim)]) + + +class TFSwinLayer(keras.layers.Layer): + def __init__( + self, + config, + dim, + input_resolution: tuple[int, int], + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + min_res = tf.reduce_min(input_resolution) + self.window_size = min_res if min_res <= config.window_size else config.window_size + self.shift_size = 0 if min_res <= self.window_size else shift_size + self.input_resolution = input_resolution + + self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.attention = TFSwinAttention(config, dim, num_heads, name="attention") + self.drop_path = ( + TFSwinDropPath(drop_path_rate, name="drop_path") + if drop_path_rate > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + 
self.intermediate = TFSwinIntermediate(config, dim, name="intermediate") + self.swin_output = TFSwinOutput(config, dim, name="output") + self.dim = dim + + def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> tf.Tensor | None: + img_mask = tf.zeros((height, width)) + height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) + width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) + + # calculate attention mask for SW-MSA + if shift_size > 0: + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1) + width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1) + indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2)) + if len(indices) >= 1: + updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count + img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates) + count += 1 + + img_mask = tf.expand_dims(img_mask, -1) + img_mask = tf.expand_dims(img_mask, 0) + + mask_windows = window_partition(img_mask, window_size) + mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size)) + attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2) + attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask) + attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask) + return attn_mask + + def maybe_pad( + self, hidden_states: tf.Tensor, window_size: int, height: int, width: int + ) -> tuple[tf.Tensor, tf.Tensor]: + pad_right = (window_size - width % window_size) % window_size + pad_bottom = (window_size - height % window_size) % window_size + pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]] + hidden_states = tf.pad(hidden_states, pad_values) + pad_values = tf.reshape(pad_values, (-1,)) + return hidden_states, pad_values + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tf.Tensor: + # if window size is larger than input resolution, we don't partition windows + min_res = tf.reduce_min(input_dimensions) + shift_size = 0 if min_res <= self.window_size else self.shift_size + window_size = min_res if min_res <= self.window_size else self.window_size + + height, width = input_dimensions + batch_size, _, channels = shape_list(hidden_states) + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states, training=training) + hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels)) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width) + + _, height_pad, width_pad, _ = shape_list(hidden_states) + # cyclic shift + if shift_size > 0: + shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, window_size) + hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels)) + attn_mask = self.get_attn_mask( + height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training + ) + + attention_output = 
attention_outputs[0] + + attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels)) + shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad) + + # reverse cyclic shift + if shift_size > 0: + attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :] + + attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels)) + + hidden_states = shortcut + self.drop_path(attention_windows, training=training) + + layer_output = self.layernorm_after(hidden_states, training=training) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.swin_output(layer_output, training=training) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.dim]) + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.dim]) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "swin_output", None) is not None: + with tf.name_scope(self.swin_output.name): + self.swin_output.build(None) + + +class TFSwinStage(keras.layers.Layer): + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: tuple[int, int], + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Callable | None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.config = config + self.dim = dim + self.blocks = [ + TFSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + drop_path_rate=drop_path[i], + name=f"blocks.{i}", + ) + for i in range(depth) + ] + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, + dim=dim, + norm_layer=partial(keras.layers.LayerNormalization, epsilon=1e-5), + name="downsample", + ) + else: + self.downsample = None + + self.pointing = False + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = False, + training: bool = False, + ) -> tuple[tf.Tensor, ...]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training + ) + + hidden_states = layer_outputs[0] + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = 
self.downsample(layer_outputs[0], input_dimensions, training=training) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "downsample", None) is not None: + with tf.name_scope(self.downsample.name): + self.downsample.build(None) + if getattr(self, "blocks", None) is not None: + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwinEncoder(keras.layers.Layer): + def __init__(self, config: SwinConfig, grid_size: tuple[int, int], **kwargs): + super().__init__(**kwargs) + self.num_layers = len(config.depths) + self.config = config + dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy()) + self.layers = [ + TFSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + name=f"layers.{i_layer}", + ) + for i_layer in range(self.num_layers) + ] + + self.gradient_checkpointing = False + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> tuple[tf.Tensor, ...] | TFSwinEncoderOutput: + all_input_dimensions = () + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = shape_list(hidden_states) + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) + reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training + ) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[1] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + batch_size, _, hidden_size = shape_list(hidden_states) + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) + reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return TFSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + 
reshaped_hidden_states=all_reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwinPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + + +SWIN_START_DOCSTRING = r""" + This model is a Tensorflow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SwinConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def normalize_data_format(value: str) -> str: + """ + From tensorflow addons + https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py#L71 + """ + if value is None: + value = keras.backend.image_data_format() + data_format = value.lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + 'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value) + ) + return data_format + + +class AdaptiveAveragePooling1D(keras.layers.Layer): + """ + Args: + Average 1D Pooling with adaptive kernel size. + output_size: An integer or tuple/list of a single integer, specifying pooled_features. + The new size of output channels. + data_format: A string, + one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds + to inputs with shape `(batch, channels, steps)`. + Input shape: + - If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`. + - If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`. + Output shape: + - If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`. 
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`. + + Adapted from [tensorflow-addon's adaptive pooling.py]( + https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120 + ) + """ + + def __init__( + self, + output_size: int | Iterable[int], + reduce_function: Callable = tf.reduce_mean, + data_format: str | None = None, + **kwargs, + ) -> None: + self.data_format = normalize_data_format(data_format) + self.reduce_function = reduce_function + self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size) + super().__init__(**kwargs) + + def call(self, inputs: tf.Tensor, *args) -> None: + bins = self.output_size[0] + if self.data_format == "channels_last": + splits = tf.split(inputs, bins, axis=1) + splits = tf.stack(splits, axis=1) + out_vect = self.reduce_function(splits, axis=2) + else: + splits = tf.split(inputs, bins, axis=2) + splits = tf.stack(splits, axis=2) + out_vect = self.reduce_function(splits, axis=3) + return out_vect + + def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape: + input_shape = tf.TensorShape(input_shape).as_list() + if self.data_format == "channels_last": + shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]]) + else: + shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]]) + return shape + + def get_config(self) -> dict[str, Any]: + config = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_serializable +class TFSwinMainLayer(keras.layers.Layer): + config_class = SwinConfig + + def __init__( + self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs + ) -> None: + super().__init__(**kwargs) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings") + self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder") + + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None + + def get_input_embeddings(self) -> TFSwinPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: dict[int, list]): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_head_mask(self, head_mask: Any | None) -> list: + if head_mask is not None: + raise NotImplementedError + return [None] * len(self.config.depths) + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFSwinModelOutput | tuple[tf.Tensor, ...]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, training=training + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output, training=training) + + pooled_output = None + if self.pooler is not None: + batch_size, _, num_features = shape_list(sequence_output) + pooled_output = self.pooler(sequence_output) + pooled_output = tf.reshape(pooled_output, (batch_size, num_features)) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + return output + + return TFSwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.num_features]) + + +@add_start_docstrings( + "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, +) +class TFSwinModel(TFSwinPreTrainedModel): + def __init__( + self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs + ) -> None: + super().__init__(config, **kwargs) + self.config = config + self.swin = TFSwinMainLayer(config, name="swin") + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + 
@add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFSwinModelOutput | tuple[tf.Tensor, ...]: + r""" + bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + swin_outputs = self.swin( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return swin_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + + +class TFSwinPixelShuffle(keras.layers.Layer): + """TF layer implementation of torch.nn.PixelShuffle""" + + def __init__(self, upscale_factor: int, **kwargs) -> None: + super().__init__(**kwargs) + if not isinstance(upscale_factor, int) or upscale_factor < 2: + raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}") + self.upscale_factor = upscale_factor + + def call(self, x: tf.Tensor) -> tf.Tensor: + hidden_states = x + batch_size, _, _, num_input_channels = shape_list(hidden_states) + block_size_squared = self.upscale_factor**2 + output_depth = int(num_input_channels / block_size_squared) + # When the number of output channels >= 2, PyTorch's PixelShuffle and + # TF's depth_to_space differ in their output as the order of channels selected for combining + # is a permutation of the other c.f. 
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1 + permutation = tf.constant( + [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]] + ) + hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1) + hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC") + return hidden_states + + +class TFSwinDecoder(keras.layers.Layer): + def __init__(self, config: SwinConfig, **kwargs): + super().__init__(**kwargs) + self.conv2d = keras.layers.Conv2D( + filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0" + ) + self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1") + self.config = config + + def call(self, x: tf.Tensor) -> tf.Tensor: + hidden_states = x + # B,C,H,W -> B,H,W,C + hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1)) + hidden_states = self.conv2d(hidden_states) + hidden_states = self.pixel_shuffle(hidden_states) + # B,H,W,C -> B,C,H,W + hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2)) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv2d", None) is not None: + with tf.name_scope(self.conv2d.name): + self.conv2d.build([None, None, None, self.config.hidden_size]) + if getattr(self, "pixel_shuffle", None) is not None: + with tf.name_scope(self.pixel_shuffle.name): + self.pixel_shuffle.build(None) + + +@add_start_docstrings( + "Swin Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://huggingface.co/papers/2111.09886).", + SWIN_START_DOCSTRING, +) +class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): + def __init__(self, config: SwinConfig): + super().__init__(config) + + self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin") + + self.decoder = TFSwinDecoder(config, name="decoder") + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple | TFSwinMaskedImageModelingOutput: + r""" + bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, TFSwinForMaskedImageModeling + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224") + >>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5 + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = tf.transpose(sequence_output, (0, 2, 1)) + batch_size, num_channels, sequence_length = shape_list(sequence_output) + height = width = int(sequence_length**0.5) + sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width)) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size)) + mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1) + mask = tf.repeat(mask, self.config.patch_size, 2) + mask = tf.expand_dims(mask, 1) + mask = tf.cast(mask, tf.float32) + + reconstruction_loss = keras.losses.mean_absolute_error( + # Swap axes as metric calculation reduces over the final dimension + tf.transpose(pixel_values, (1, 2, 3, 0)), + tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)), + ) + reconstruction_loss = tf.expand_dims(reconstruction_loss, 0) + total_loss = tf.reduce_sum(reconstruction_loss * mask) + num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels + masked_im_loss = total_loss / num_masked_pixels + masked_im_loss = tf.reshape(masked_im_loss, (1,)) + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return TFSwinMaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + """ + Swin Model transformer with an image classification head on top (a 
linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + SWIN_START_DOCSTRING, +) +class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: SwinConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.swin = TFSwinMainLayer(config, name="swin") + + # Classifier head + self.classifier = ( + keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier") + ) + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSwinImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> tuple[tf.Tensor, ...] | TFSwinImageClassifierOutput: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.swin.num_features]) + + +__all__ = ["TFSwinForImageClassification", "TFSwinForMaskedImageModeling", "TFSwinModel", "TFSwinPreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..366eab10826b19cd75e87b013f58f55976969abe --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5 import * + from .modeling_flax_t5 import * + from .modeling_t5 import * + from .modeling_tf_t5 import * + from .tokenization_t5 import * + from .tokenization_t5_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/configuration_t5.py b/venv/lib/python3.13/site-packages/transformers/models/t5/configuration_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..21d2552008719e47946886fa5c20b79ef98bde9f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/configuration_t5.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""T5 model configuration""" + +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class T5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to + instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the T5 + [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `T5Block`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. 
+ num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the + `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } + + def __init__( + self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 + + +__all__ = ["T5Config", "T5OnnxConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_flax_t5.py b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_flax_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..c829a084e4d96e15efd28cad90dd52f02be2d6dd --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_flax_t5.py @@ -0,0 +1,1801 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax T5 model.""" + +import copy +from typing import Callable, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-t5/t5-small" +_CONFIG_FOR_DOC = "T5Config" + +remat = nn_partitioning.remat + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; No bias and no subtraction of mean. 
+ """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +class FlaxT5DenseActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5DenseGatedActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5LayerFF(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +class FlaxT5Attention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = 
self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. 
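+            # pad_mask is True for cache slots that are already filled (indices < cur_index + num_updated_cache_vectors),
+            # broadcast to shape (*batch_dims, 1, num_updated_cache_vectors, max_length) so it can be merged with attention_mask below.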
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one 
position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxT5LayerSelfAttention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxT5Attention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5LayerCrossAttention(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = 
self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5Block(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxT5LayerSelfAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +class FlaxT5LayerCollection(nn.Module): + config: T5Config + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +class 
FlaxT5BlockCollection(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) + self.blocks = [ + FlaxT5CheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxT5LayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxT5Stack(nn.Module): + config: T5Config + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxT5BlockCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +T5_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: T5Config, + input_shape: tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + +T5_START_DOCSTRING = r""" + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. 
+ + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5Module(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxT5Stack( + decoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5Model(FlaxT5PreTrainedModel): + module_class = FlaxT5Module + + +append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + +FLAX_T5_MODEL_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="np" + ... 
).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5EncoderModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class 
FlaxT5ForConditionalGenerationModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = FlaxT5Stack( + encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxT5Stack( + decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): + module_class = FlaxT5ForConditionalGenerationModule + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) + + +__all__ = ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_t5.py b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bdfe6e18e0651bd06de61dd04c3a567712460c --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_t5.py @@ -0,0 +1,2407 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch T5 model.""" + +import copy +import math +import os +import warnings +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + auto_docstring, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_t5 import T5Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = 
r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`dict[int, list]`, *optional*): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - google-t5/t5-small: 6 + - google-t5/t5-base: 12 + - google-t5/t5-large: 24 + - google-t5/t5-3b: 24 + - google-t5/t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with google-t5/t5-3b: + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm") +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5LayerFF(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(config)
+        else:
+            self.DenseReluDense = T5DenseActDense(config)
+
+        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+class T5Attention(nn.Module):
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias=False,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None and self.is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+        )
+        # Prune linear layers
+        self.q = prune_linear_layer(self.q, index)
+        self.k = prune_linear_layer(self.k, index)
+        self.v = prune_linear_layer(self.v, index)
+        self.o = prune_linear_layer(self.o, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.inner_dim = self.key_value_proj_dim * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention.
The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_values=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + is_updated = False + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = 
nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_values=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(GradientCheckpointingLayer): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config, 
layer_idx=layer_idx)) + + self.layer.append(T5LayerFF(config)) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: T5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + 
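+# A minimal sketch (not part of the upstream sources) of how the relative position bias defined in
+# `T5Attention` above can be inspected. The ad-hoc `T5Config` sizes below are assumptions chosen only
+# to keep the example small; `compute_bias` returns an additive score of shape
+# (1, num_heads, query_length, key_length) that is added to the raw attention scores before softmax:
+#
+#     from transformers.models.t5.configuration_t5 import T5Config
+#     from transformers.models.t5.modeling_t5 import T5Attention
+#
+#     config = T5Config(d_model=64, d_kv=8, num_heads=4, num_layers=2)
+#     attention = T5Attention(config, has_relative_attention_bias=True)
+#     bias = attention.compute_bias(query_length=6, key_length=6)
+#     print(bias.shape)  # torch.Size([1, 4, 6, 6])
+#
+# Short relative distances get dedicated buckets, while larger distances share logarithmically
+# spaced buckets up to `relative_attention_max_distance`, which is why the bias generalizes to
+# sequences longer than those seen during training.
+
+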
+@auto_docstring +class T5PreTrainedModel(PreTrainedModel): + config: T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _can_compile_fullgraph = True + + _no_split_modules = ["T5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, T5ForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, T5ClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " + "See T5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache( + DynamicCache(config=self.config), DynamicCache(config=self.config) + ) + else: + past_key_values = DynamicCache(config=self.config) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = 
position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch 
on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@auto_docstring +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + Example: + + ```python + >>> from transformers import AutoTokenizer, T5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = T5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. 
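+        >>> # _shift_right prepends config.decoder_start_token_id (the pad token for T5) and drops the last token, +        >>> # so each decoder position is predicted from the target tokens before it.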
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + T5 Model with a `language modeling` head on top. 
+ """ +) +class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: 
Optional[torch.FloatTensor] = None, +        decoder_inputs_embeds: Optional[torch.FloatTensor] = None, +        labels: Optional[torch.LongTensor] = None, +        use_cache: Optional[bool] = None, +        output_attentions: Optional[bool] = None, +        output_hidden_states: Optional[bool] = None, +        return_dict: Optional[bool] = None, +        cache_position: Optional[torch.LongTensor] = None, +    ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: +        r""" +        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): +            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you +            should be able to pad the inputs on both the right and the left. + +            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +            [`PreTrainedTokenizer.__call__`] for details. + +            [What are input IDs?](../glossary#input-ids) + +            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training). +        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): +            Indices of decoder input sequence tokens in the vocabulary. + +            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and +            [`PreTrainedTokenizer.__call__`] for details. + +            [What are decoder input IDs?](../glossary#decoder-input-ids) + +            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` +            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + +            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 +            Training](./t5#training). +        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): +            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also +            be used by default. +        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): +            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, +            1]`: + +            - 1 indicates the head is **not masked**, +            - 0 indicates the head is **masked**. +        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): +            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in +            `[0, 1]`: + +            - 1 indicates the head is **not masked**, +            - 0 indicates the head is **masked**. +        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): +            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ..., +            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for +            labels in `[0, ..., config.vocab_size - 1]` + +        Examples: + +        ```python +        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration + +        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") +        >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + +        >>> # training (span corruption: masked spans are replaced by <extra_id_N> sentinel tokens) +        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids +        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids +        >>> outputs = model(input_ids=input_ids, labels=labels) +        >>> loss = outputs.loss +        >>> logits = outputs.logits + +        >>> # inference +        >>> input_ids = tokenizer( +        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" +        ... 
).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, 
lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5EncoderModel(T5PreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight"] + _keys_to_ignore_on_load_unexpected = [r"decoder"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = config + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + + Example: + + ```python + >>> from transformers import AutoTokenizer, T5EncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs + + +@auto_docstring( + custom_intro=""" + T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. 
+ """ +) +class T5ForSequenceClassification(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.transformer = T5Model(config) + self.classification_head = T5ClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + batch_size, _, hidden_size = sequence_output.shape + sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + 
encoder_attentions=outputs.encoder_attentions, + ) + + +@auto_docstring +class T5ForTokenClassification(T5PreTrainedModel): + _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = T5EncoderModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, outputs[2:-1]) + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class T5ForQuestionAnswering(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.tie_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = 
nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # Copied from models.bart.modeling_bart.BartModel.forward + # different to other models, T5 automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = 
end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +__all__ = [ + "T5EncoderModel", + "T5ForConditionalGeneration", + "T5Model", + "T5PreTrainedModel", + "load_tf_weights_in_t5", + "T5ForQuestionAnswering", + "T5ForSequenceClassification", + "T5ForTokenClassification", +] diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_tf_t5.py b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_tf_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..142a0f73115e2adc2a057cdc4911d20fa131e501 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/modeling_tf_t5.py @@ -0,0 +1,1676 @@ +# coding=utf-8 +# Copyright 2020 T5 Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 T5 model.""" + +from __future__ import annotations + +import copy +import itertools +import math +import warnings + +import numpy as np +import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_slice + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" + + +#################################################### +# TF 2.0 Models are constructed using Keras imperative API by sub-classing +# - keras.layers.Layer for the layers and +# - TFPreTrainedModel for the models (it-self a sub-class of keras.Model) +#################################################### + + +class TFT5LayerNorm(keras.layers.Layer): + def __init__(self, hidden_size, epsilon=1e-6, **kwargs): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. + """ + super().__init__(**kwargs) + self.variance_epsilon = epsilon + self.hidden_size = hidden_size + + def build(self, input_shape): + """Build shared word embedding layer""" + self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones") + super().build(input_shape) + + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +class TFT5DenseActDense(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + self.config = config + + def call(self, hidden_states, training=False): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wi", None) is not None: + with tf.name_scope(self.wi.name): + self.wi.build([None, None, self.config.d_model]) + if getattr(self, "wo", None) is not None: + with tf.name_scope(self.wo.name): + self.wo.build([None, None, self.config.d_ff]) + + +class TFT5DenseGatedActDense(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = 
keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi_0 = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wi_1 = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + self.config = config + + def call(self, hidden_states, training=False): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wi_0", None) is not None: + with tf.name_scope(self.wi_0.name): + self.wi_0.build([None, None, self.config.d_model]) + if getattr(self, "wi_1", None) is not None: + with tf.name_scope(self.wi_1.name): + self.wi_1.build([None, None, self.config.d_model]) + if getattr(self, "wo", None) is not None: + with tf.name_scope(self.wo.name): + self.wo.build([None, None, self.config.d_ff]) + + +class TFT5LayerFF(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if config.is_gated_act: + self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense") + else: + self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense") + + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call(self, hidden_states, training=False): + normed_hidden_states = self.layer_norm(hidden_states) + dense_output = self.DenseReluDense(normed_hidden_states, training=training) + hidden_states = hidden_states + self.dropout(dense_output, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + if getattr(self, "DenseReluDense", None) is not None: + with tf.name_scope(self.DenseReluDense.name): + self.DenseReluDense.build(None) + + +class TFT5Attention(keras.layers.Layer): + NEW_ID = itertools.count() + + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.layer_id = next(TFT5Attention.NEW_ID) + self.is_decoder = config.is_decoder + self.use_cache = config.use_cache + self.has_relative_attention_bias = has_relative_attention_bias + self.output_attentions = config.output_attentions + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + q_initializer = 
keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + ) + k_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + v_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + o_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + self.relative_attention_bias_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + + self.q = keras.layers.Dense( + self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer + ) # Update init weights as in flax + self.k = keras.layers.Dense( + self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer + ) # Update init weights as in flax + self.v = keras.layers.Dense( + self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer + ) # Update init weights as in flax + self.o = keras.layers.Dense( + self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + + self.pruned_heads = set() + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.has_relative_attention_bias: + with tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + initializer=self.relative_attention_bias_initializer, # Add initializer + ) + if getattr(self, "q", None) is not None: + with tf.name_scope(self.q.name): + self.q.build([None, None, self.d_model]) + if getattr(self, "k", None) is not None: + with tf.name_scope(self.k.name): + self.k.build([None, None, self.d_model]) + if getattr(self, "v", None) is not None: + with tf.name_scope(self.v.name): + self.v.build([None, None, self.d_model]) + if getattr(self, "o", None) is not None: + with tf.name_scope(self.o.name): + self.o.build([None, None, self.inner_dim]) + + def prune_heads(self, heads): + raise NotImplementedError + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
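+        For example, with the default num_buckets=32 and max_distance=128 in the causal (bidirectional=False) case, +        distances 0-15 each receive their own bucket (max_exact = num_buckets // 2 = 16) and distances from 16 up to +        127 are spread logarithmically over buckets 16-31; in the bidirectional case the buckets are split evenly +        between negative and positive relative positions.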
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + # n = -relative_position + if bidirectional: + num_buckets //= 2 + relative_buckets += ( + tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets + ) + relative_position = tf.math.abs(relative_position) + else: + relative_position = -tf.math.minimum(relative_position, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = tf.math.less(relative_position, max_exact) + relative_position_if_large = max_exact + tf.cast( + tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32)) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact), + dtype=relative_position.dtype, + ) + relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = tf.range(query_length)[:, None] + memory_position = tf.range(key_length)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = tf.gather( + self.relative_attention_bias, relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = tf.expand_dims( + tf.transpose(values, [2, 0, 1]), axis=0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def call( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + training=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, query_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = shape_list(hidden_states)[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert len(past_key_value) == 2, ( + f"past_key_value should have 2 past states: keys and values. 
Got {len(past_key_value)} past states" + ) + real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] + + def shape(hidden_states): + """projection""" + return tf.transpose( + tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) + ) + + def unshape(hidden_states): + """compute context""" + return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = tf.concat([past_key_value, hidden_states], axis=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head) + + # get key/value + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # to cope with keras serialization + if self.is_decoder and use_cache: + present_key_value_state = (key_states, value_states) + else: + present_key_value_state = None + + scores = tf.einsum( + "bnqd,bnkd->bnqk", query_states, key_states + ) # (batch_size, n_heads, query_length, key_length) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated we want only the last query position bias + if past_key_value is not None: + if not self.has_relative_attention_bias: + position_bias = position_bias[:, :, -seq_length:, :] + else: + # we might have a padded past structure, in which case we want to fetch the position bias slice + # right after the most recently filled past index + most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) + position_bias = dynamic_slice( + position_bias, + (0, 0, most_recently_filled_past_index + 1, 0), + (1, self.n_heads, seq_length, real_seq_length), + ) + + if mask is not None: + position_bias = tf.cast(position_bias, dtype=mask.dtype) + position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) + + scores += position_bias + weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) + weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.n_heads], + message=( + f"Head mask for a single layer should be of size {(self.n_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + weights = tf.reshape(layer_head_mask, (1, 
-1, 1, 1)) * weights + + attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) + + attn_output = self.o(unshape(attn_output)) + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (weights,) + + return outputs + + +class TFT5LayerSelfAttention(keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.SelfAttention = TFT5Attention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="SelfAttention", + ) + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "SelfAttention", None) is not None: + with tf.name_scope(self.SelfAttention.name): + self.SelfAttention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + + +class TFT5LayerCrossAttention(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.EncDecAttention = TFT5Attention( + config, + has_relative_attention_bias=False, + name="EncDecAttention", + ) + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + query_length=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "EncDecAttention", None) is not None: + with tf.name_scope(self.EncDecAttention.name): + self.EncDecAttention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + + +class TFT5Block(keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): 
+ super().__init__(**kwargs) + self.is_decoder = config.is_decoder + self.layer = [] + self.layer.append( + TFT5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="layer_._0", + ) + ) + if self.is_decoder: + self.layer.append( + TFT5LayerCrossAttention( + config, + name="layer_._1", + ) + ) + + self.layer.append(TFT5LayerFF(config, name=f"layer_._{len(self.layer)}")) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + encoder_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (key / value) for cross attention' if expected_num_past_key_values == 4 else ''}. " + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + if self.is_decoder and encoder_hidden_states is not None: + # the actual query length is unknown for cross attention + # if using past key value states. 
Need to inject it here + if present_key_value_state is not None: + query_length = shape_list(present_key_value_state[0])[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=encoder_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = cross_attention_outputs[0] + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, training=training) + outputs = (hidden_states,) + + # Add attentions if we output them + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + for layer_module in self.layer: + if hasattr(layer_module, "name"): + with tf.name_scope(layer_module.name): + layer_module.build(None) + + +#################################################### +# The full model without a specific pretrained or finetuning head is +# provided as a keras.layers.Layer usually called "TFT5MainLayer" +#################################################### +@keras_serializable +class TFT5MainLayer(keras.layers.Layer): + config_class = T5Config + + def __init__(self, config, embed_tokens=None, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.use_cache = config.use_cache + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_layers + + self.block = [ + TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}") + for i in range(config.num_layers) + ] + self.final_layer_norm = TFT5LayerNorm( + config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm" + ) + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError # Not implemented yet in the library for TF 2.0 models + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ) -> tuple: + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + 
err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length + ) + + if attention_mask is None: + attention_mask = tf.fill((batch_size, mask_seq_length), 1) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = shape_list(encoder_hidden_states)[1] + encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype) + num_dims_attention_mask = len(shape_list(attention_mask)) + if num_dims_attention_mask == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif num_dims_attention_mask == 2: + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + if past_key_values[0] is not None: + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -1e9 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # extended_attention_mask = tf.math.equal(extended_attention_mask, + # tf.transpose(extended_attention_mask, perm=(-1, -2))) + + extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 + + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + else: + encoder_extended_attention_mask = None + + present_key_value_states = () if use_cache and self.is_decoder else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds, training=training) + + for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer stores them + # layer_outputs = hidden-states, past_key_values, (self-attention weights), + # (self-attention position bias), (cross-attention weights), (cross-attention position bias) + position_bias = layer_outputs[2] + + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + # append next layer key value states + if present_key_value_state is not None and use_cache and self.is_decoder: + 
present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = (hidden_states,) + # need to check if is decoder here as well for special cases when using keras compile + if use_cache and self.is_decoder: + outputs = outputs + (present_key_value_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + if self.is_decoder: + outputs = outputs + (all_cross_attentions,) + return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions) + + if self.is_decoder: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + else: + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build(None) + if getattr(self, "block", None) is not None: + for layer in self.block: + with tf.name_scope(layer.name): + layer.build(None) + + +#################################################### +# TFT5PreTrainedModel is a sub-class of keras.Model +# which takes care of loading and saving pretrained weights +# and various common utilities. +# Here you just need to specify a few (self-explanatory) +# pointers for your model. +#################################################### +class TFT5PreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"] + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + if hasattr(self, "decoder"): + self.decoder.embed_tokens = self.shared + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the" + " pad_token_id. See T5 docs for more information" + ) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
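Editor's note: the decoder-input shifting implemented by `_shift_right` (the method continues below with the `-100` replacement) is easier to see in isolation. A minimal, self-contained sketch with hypothetical toy values, assuming the usual T5 convention `decoder_start_token_id == pad_token_id == 0`:

```python
# Illustrative sketch only -- not part of the diff.
import tensorflow as tf

labels = tf.constant([[42, 7, -100, -100]])  # -100 marks label positions ignored by the loss
decoder_start_token_id = 0  # T5 conventionally reuses the pad token as the decoder start token
pad_token_id = 0

# prepend the start token and drop the last label position
start_tokens = tf.fill((1, 1), decoder_start_token_id)
shifted = tf.concat([tf.cast(start_tokens, labels.dtype), labels[:, :-1]], axis=-1)
# replace any remaining -100 markers with the pad token, as the method's next statement does
shifted = tf.where(shifted == -100, tf.fill(tf.shape(shifted), pad_token_id), shifted)
print(shifted.numpy())  # [[ 0 42  7  0]]
```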
+ # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal( + shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype) + ) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and + behavior. + + <Tip> + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + </Tip> + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `inputs` for pretraining take a look at [T5 Training](./t5#training). + decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for sequence to sequence training. T5 uses the `pad_token_id` as the starting token for + `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` + have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(tf.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + inputs (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + To know more on how to prepare `inputs` for pre-training take a look at [T5 Training](./t5#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +_HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, +num_heads))`. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class TFT5Model(TFT5PreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(self.config.initializer_factor), + name="shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "shared" + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder") + + def get_encoder(self): + return self.encoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFSeq2SeqModelOutput: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TFT5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = TFT5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. 
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_values=None, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + past = decoder_outputs[1] if use_cache else None + + if not return_dict: + if past_key_values is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model_dim = config.d_model + self.shared = keras.layers.Embedding( + config.vocab_size, + config.d_model, + name="shared", + embeddings_initializer=get_initializer(self.config.initializer_factor), + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "shared" + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder") + + if not config.tie_word_embeddings: + lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor) + self.lm_head = keras.layers.Dense( + config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + self.config = config + + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return self.get_input_embeddings() + else: + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + return tf.transpose(self.lm_head.kernel) + + def set_output_embeddings(self, value): + if self.config.tie_word_embeddings: + self.set_input_embeddings(value) + else: + lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor) + self.lm_head = keras.layers.Dense( + shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + transposed_value = tf.transpose(value) + self.lm_head.kernel = transposed_value + + def get_encoder(self): + return self.encoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: 
bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFSeq2SeqLMOutput: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TFT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> # training + >>> inputs = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="tf").input_ids + >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="tf").input_ids + >>> outputs = model(inputs, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> inputs = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(inputs) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you + ```""" + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + head_mask=decoder_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = decoder_outputs[0] + + # T5v1.1 does not tie output word embeddings and thus does not require downscaling + if self.config.tie_word_embeddings: + sequence_output = sequence_output * (self.model_dim**-0.5) + logits = tf.matmul(sequence_output, self.shared.weights, transpose_b=True) + else: + logits = self.lm_head(sequence_output) + + logits = tf.cast(logits, tf.float32) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + past = decoder_outputs[1] if use_cache else None + if not return_dict: + if past_key_values is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] + output = (logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif isinstance(encoder_outputs, tuple): + last_hidden_state = encoder_outputs[0] + hidden_states = None + attentions = None 
+ idx = 0 + if output_hidden_states: + idx += 1 + hidden_states = encoder_outputs[idx] + if output_attentions: + idx += 1 + attentions = encoder_outputs[idx] + + encoder_outputs = TFBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + return TFSeq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return self._shift_right(labels) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build([None, None, self.config.d_model]) + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class TFT5EncoderModel(TFT5PreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.shared = keras.layers.Embedding( + config.vocab_size, + config.d_model, + name="shared", + embeddings_initializer=get_initializer(self.config.initializer_factor), + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "shared" + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") + + def get_encoder(self): + return self.encoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> tuple | TFBaseModelOutput: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TFT5EncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = TFT5EncoderModel.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids) + ```""" + + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_values=None, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return encoder_outputs + + return TFBaseModelOutput( + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + + +__all__ = ["TFT5EncoderModel", "TFT5ForConditionalGeneration", "TFT5Model", "TFT5PreTrainedModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5.py b/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..0a25271345cfa3a2fc9139c06ec9eb7b59aeef23 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Copyright 2018 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model T5.""" + +import os +import re +import warnings +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Optional + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import AddedToken + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class T5Tokenizer(PreTrainedTokenizer): + """ + Construct a T5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `"</s>"`): + The end of sequence token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + </Tip> + + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 100): + Adds a number of extra ids to the vocabulary for use as sentinels. These tokens are + accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be + retrieved by calling the get_sentinel_tokens method, and their token ids can be retrieved by calling the + get_sentinel_token_ids method. + additional_special_tokens (`list[str]`, *optional*): + Additional special tokens used by the tokenizer. 
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite, samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        legacy (`bool`, *optional*):
+            Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622
+            and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
+            example:
+
+            - `legacy=True`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
+            >>> tokenizer.encode("Hello <extra_id_0>.")
+            [8774, 32099, 3, 5, 1]
+            ```
+            - `legacy=False`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
+            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
+            [8774, 32099, 5, 1]
+            ```
+            Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
+        add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+            any other word.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        extra_ids=100,
+        additional_special_tokens=None,
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        legacy=None,
+        add_prefix_space=True,
+        **kwargs,
+    ) -> None:
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.vocab_file = vocab_file
+        self._extra_ids = extra_ids
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        if additional_special_tokens is not None:
+            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
+            if len(extra_tokens) < 1:
+                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+            elif extra_ids > 0 and extra_ids != len(extra_tokens):
+                raise ValueError(
+                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                    " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+                    " tokens"
+                )
+        else:
+            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+            additional_special_tokens = extra_tokens
+
+        # for legacy purpose, we keep this. Will be removed and tests updated. (when `added_tokens_decoder` is not passed as kwargs)
+        self._added_tokens_decoder = {}
+        for i in range(len(extra_tokens)):
+            self._added_tokens_decoder[len(self.sp_model) - 1 + extra_ids - i] = AddedToken(
+                f"<extra_id_{i}>", single_word=False, lstrip=True, rstrip=True, special=True, normalized=False
+            )
+
+        if legacy is None:
+            logger.warning_once(
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
+                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
+                " means, and thoroughly read the reason why this was added as explained in"
+                " https://github.com/huggingface/transformers/pull/24565"
+            )
+            legacy = True
+
+        self.legacy = legacy
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
+        self.add_prefix_space = add_prefix_space
+
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            extra_ids=extra_ids,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            legacy=legacy,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
+    def get_spm_processor(self, from_slow=False):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy or from_slow:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    @staticmethod
+    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+        if pretrained_model_name_or_path in T5Tokenizer.max_model_input_sizes:
+            deprecated_max_model_length = T5Tokenizer.max_model_input_sizes[pretrained_model_name_or_path]
+            if init_max_model_length is not None and init_max_model_length != max_model_length:
+                return init_max_model_length
+            elif init_max_model_length is None:
+                warnings.warn(
+                    "This tokenizer was incorrectly instantiated with a model max length of"
+                    f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+                    " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+                    " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+                    f" {pretrained_model_name_or_path} automatically truncating your input to"
+                    f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+                    f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+                    " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+                    " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+                    FutureWarning,
+                )
+
+        return max_model_length
+
+    @property
+    def vocab_size(self):
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
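+
+    # A minimal usage sketch (illustrative, not from the original file; assumes the
+    # public "google-t5/t5-small" checkpoint, whose sentencepiece model has 32000 pieces):
+    #
+    #   tok = T5Tokenizer.from_pretrained("google-t5/t5-small")
+    #   tok.vocab_size           # 32000 -- sentinel tokens are added on top of this
+    #   vocab = tok.get_vocab()  # token string -> integer id, including added tokens
+    #   vocab["</s>"]            # 1, the eos token id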
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        # normal case: some special tokens
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)), self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+    def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]:
+        """Do not add eos again if user already added it."""
+        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+            warnings.warn(
+                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+                " eos tokens being added."
+            )
+            return token_ids
+        else:
+            return token_ids + [self.eos_token_id]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
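+
+        Example (an illustrative sketch; the ids assume the public "google-t5/t5-small"
+        checkpoint, where `eos_token_id == 1`):
+
+        ```python
+        >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
+        >>> tokenizer.build_inputs_with_special_tokens([8774])
+        [8774, 1]
+        ```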
+        """
+        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+        if token_ids_1 is None:
+            return token_ids_0
+        else:
+            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+            return token_ids_0 + token_ids_1
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
+        """
+        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
+        first token is special.
+        """
+        if self.legacy or len(text) == 0:
+            return super().tokenize(text, **kwargs)
+
+        text = text.replace(SPIECE_UNDERLINE, " ")
+        if self.add_prefix_space:
+            text = SPIECE_UNDERLINE + text
+
+        tokens = super().tokenize(text, **kwargs)
+
+        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+            tokens = tokens[1:]
+        return tokens
+
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Returns a tokenized string.
+
+        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
+        """
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return self.sp_model.encode(text, out_type=str)
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) to a single string."""
+        # since we manually add the prefix space, we have to remove it when decoding
+        if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
+            tokens[0] = tokens[0][1:]
+
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["T5Tokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5_fast.py b/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdba1a7928c8a9ef240aa6a6f694ca44e66e8c4a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/t5/tokenization_t5_fast.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright 2018 T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model T5."""
+
+import os
+import re
+import warnings
+from shutil import copyfile
+from typing import Optional
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+    from .tokenization_t5 import T5Tokenizer
+else:
+    T5Tokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+# TODO(PVP) - this should be removed in Transformers v5
+
+
+class T5TokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        extra_ids (`int`, *optional*, defaults to 100):
+            Number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
+            "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
+            calling the `get_sentinel_tokens` method, and their token ids by calling the `get_sentinel_token_ids`
+            method.
+        additional_special_tokens (`list[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        add_prefix_space (`bool`, *optional*):
+            Whether or not the tokenizer should automatically add a prefix space.
+        from_slow (`bool`, *optional*, defaults to `False`):
+            Whether or not the tokenizer should be converted from a slow one. If `add_prefix_space` is set, this will be set to `True`.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = T5Tokenizer
+
+    prefix_tokens: list[int] = []
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        extra_ids=100,
+        additional_special_tokens=None,
+        add_prefix_space=None,
+        **kwargs,
+    ):
+        # Add extra_ids to the special token list
+        if additional_special_tokens is not None:
+            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
+            if len(extra_tokens) < 1:
+                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+            elif extra_ids > 0 and extra_ids != len(extra_tokens):
+                raise ValueError(
+                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                    " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+                    " tokens"
+                )
+        else:
+            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+            additional_special_tokens = extra_tokens
+
+        if add_prefix_space is not None:
+            logger.warning_once(
+                "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
+            )
+            kwargs["from_slow"] = True
+
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            extra_ids=extra_ids,
+            additional_special_tokens=additional_special_tokens,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+        self._extra_ids = extra_ids
+
+    @staticmethod
+    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+        if pretrained_model_name_or_path in T5TokenizerFast.max_model_input_sizes:
+            deprecated_max_model_length = T5TokenizerFast.max_model_input_sizes[pretrained_model_name_or_path]
+            if init_max_model_length is not None and init_max_model_length != max_model_length:
+                return init_max_model_length
+            elif init_max_model_length is None:
+                warnings.warn(
+                    "This tokenizer was incorrectly instantiated with a model max length of"
+                    f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+                    " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+                    " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+                    f" {pretrained_model_name_or_path} automatically truncating your input to"
+                    f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+                    f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+                    " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+                    " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+                    FutureWarning,
+                )
+
+        return max_model_length
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+            logger.info(f"Copy vocab file to {out_vocab_file}")
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
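+
+        Example (an illustrative sketch for the sequence-pair case; the ids assume the
+        public "google-t5/t5-small" checkpoint, where `eos_token_id == 1`):
+
+        ```python
+        >>> tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")
+        >>> tokenizer.build_inputs_with_special_tokens([8774], [571])
+        [8774, 1, 571, 1]
+        ```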
+        """
+        token_ids_0 = token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return self.prefix_tokens + token_ids_0
+        else:
+            token_ids_1 = token_ids_1 + [self.eos_token_id]
+            return self.prefix_tokens + token_ids_0 + token_ids_1
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)), self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+
+__all__ = ["T5TokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be5cbab568a23814888cf71bcdaddc7a7b199e00
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_timm_wrapper import *
+    from .image_processing_timm_wrapper import *
+    from .modeling_timm_wrapper import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/configuration_timm_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e640ade8bfc9724a2c9df3257d77ee8bb4b0ea
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/configuration_timm_wrapper.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration for TimmWrapper models"""
+
+from typing import Any, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import is_timm_available, logging, requires_backends
+
+
+if is_timm_available():
+    from timm.data import ImageNetInfo, infer_imagenet_subset
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimmWrapperConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`].
+
+    It is used to instantiate a timm model according to the specified arguments, defining the model.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    The config loads ImageNet label descriptions and stores them in the `id2label` attribute; for default ImageNet
+    models the `label2id` attribute is set to `None`, since the label descriptions are not guaranteed to be unique
+    and the mapping cannot be reliably inverted.
+
+    Args:
+        architecture (`str`, *optional*, defaults to `"resnet50"`):
+            The timm architecture to load.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        do_pooling (`bool`, *optional*, defaults to `True`):
+            Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
+        model_args (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
+            for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.
+ + Example: + ```python + >>> from transformers import TimmWrapperModel + + >>> # Initializing a timm model + >>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k") + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "timm_wrapper" + + def __init__( + self, + architecture: str = "resnet50", + initializer_range: float = 0.02, + do_pooling: bool = True, + model_args: Optional[dict[str, Any]] = None, + **kwargs, + ): + self.architecture = architecture + self.initializer_range = initializer_range + self.do_pooling = do_pooling + self.model_args = model_args # named "model_args" for BC with timm + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, config_dict: dict[str, Any], **kwargs): + label_names = config_dict.get("label_names") + is_custom_model = "num_labels" in kwargs or "id2label" in kwargs + + # if no labels added to config, use imagenet labeller in timm + if label_names is None and not is_custom_model: + requires_backends(cls, ["timm"]) + imagenet_subset = infer_imagenet_subset(config_dict) + if imagenet_subset: + dataset_info = ImageNetInfo(imagenet_subset) + synsets = dataset_info.label_names() + label_descriptions = dataset_info.label_descriptions(as_dict=True) + label_names = [label_descriptions[synset] for synset in synsets] + + if label_names is not None and not is_custom_model: + kwargs["id2label"] = dict(enumerate(label_names)) + + # if all label names are unique, create label2id mapping as well + if len(set(label_names)) == len(label_names): + kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} + else: + kwargs["label2id"] = None + + # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. + # We are removing these attributes in order to have the native `transformers` num_labels attribute in config + # and to avoid duplicate attributes + num_labels_in_kwargs = kwargs.pop("num_labels", None) + num_labels_in_dict = config_dict.pop("num_classes", None) + + # passed num_labels has priority over num_classes in config_dict + kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict + + # pop num_classes from "pretrained_cfg", + # it is not necessary to have it, only root one is used in timm + if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: + config_dict["pretrained_cfg"].pop("num_classes", None) + + return super().from_dict(config_dict, **kwargs) + + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + output.setdefault("num_classes", self.num_labels) + output.setdefault("label_names", list(self.id2label.values())) + output.pop("id2label", None) + output.pop("label2id", None) + return output + + +__all__ = ["TimmWrapperConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/image_processing_timm_wrapper.py b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/image_processing_timm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..89c8883cbd6a1680ab85a23e1059e82e9b0a143d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/image_processing_timm_wrapper.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, Optional, Union + +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import to_pil_image +from ...image_utils import ImageInput, make_flat_list_of_images +from ...utils import TensorType, logging, requires_backends +from ...utils.import_utils import is_timm_available, is_torch_available, requires + + +if is_timm_available(): + import timm + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch", "timm", "torchvision")) +class TimmWrapperImageProcessor(BaseImageProcessor): + """ + Wrapper class for timm models to be used within transformers. + + Args: + pretrained_cfg (`dict[str, Any]`): + The configuration of the pretrained model used to resolve evaluation and + training transforms. + architecture (`Optional[str]`, *optional*): + Name of the architecture of the model. + """ + + main_input_name = "pixel_values" + + def __init__( + self, + pretrained_cfg: dict[str, Any], + architecture: Optional[str] = None, + **kwargs, + ): + requires_backends(self, "timm") + super().__init__(architecture=architecture) + + self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False) + self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False) + + # useful for training, see examples/pytorch/image-classification/run_image_classification.py + self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True) + + # If `ToTensor` is in the transforms, then the input should be numpy array or PIL image. + # Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used + # which can handle both numpy arrays / PIL images and tensors. + self._not_supports_tensor_input = any( + transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms + ) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + """ + output = super().to_dict() + output.pop("train_transforms", None) + output.pop("val_transforms", None) + output.pop("_not_supports_tensor_input", None) + return output + + @classmethod + def get_image_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Get the image processor dict for the model. + """ + image_processor_filename = kwargs.pop("image_processor_filename", "config.json") + return super().get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs + ) + + def preprocess( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = "pt", + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. 
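+
+        Example (an illustrative sketch, assuming the public "timm/resnet18.a1_in1k"
+        checkpoint, whose eval transform yields 224x224 crops):
+
+        ```python
+        >>> from PIL import Image
+        >>> from transformers import AutoImageProcessor
+
+        >>> processor = AutoImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
+        >>> image = Image.new("RGB", (256, 256))
+        >>> batch = processor.preprocess(image)
+        >>> batch.pixel_values.shape
+        torch.Size([1, 3, 224, 224])
+        ```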
+ """ + if return_tensors != "pt": + raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}") + + if self._not_supports_tensor_input and isinstance(images, torch.Tensor): + images = images.cpu().numpy() + + # If the input is a torch tensor, then no conversion is needed + # Otherwise, we need to pass in a list of PIL images + if isinstance(images, torch.Tensor): + images = self.val_transforms(images) + # Add batch dimension if a single image + images = images.unsqueeze(0) if images.ndim == 3 else images + else: + images = make_flat_list_of_images(images) + images = [to_pil_image(image) for image in images] + images = torch.stack([self.val_transforms(image) for image in images]) + + return BatchFeature({"pixel_values": images}, tensor_type=return_tensors) + + def save_pretrained(self, *args, **kwargs): + # disable it to make checkpoint the same as in `timm` library. + logger.warning_once( + "The `save_pretrained` method is disabled for TimmWrapperImageProcessor. " + "The image processor configuration is saved directly in `config.json` when " + "`save_pretrained` is called for saving the model." + ) + + +__all__ = ["TimmWrapperImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/modeling_timm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d388ff05297ff1c01e7f9c1fe3353763088e51a6 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import Tensor, nn + +from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_timm_available, requires_backends +from .configuration_timm_wrapper import TimmWrapperConfig + + +if is_timm_available(): + import timm + + +@dataclass +@auto_docstring( + custom_intro=""" + Output class for models TimmWrapperModel, containing the last hidden states, an optional pooled output, + and optional hidden states. + """ +) +class TimmWrapperModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor`): + The last hidden state of the model, output before applying the classification head. + pooler_output (`torch.FloatTensor`, *optional*): + The pooled output derived from the last hidden state, if applicable. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`): + A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True` is set or if `config.output_attentions=True`.): + A tuple containing the intermediate attention weights of the model at the output of each layer. + Note: Currently, Timm models do not support attentions output. + """ + + last_hidden_state: torch.FloatTensor + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_kwargs): + """ + Creates a timm model and provides a clear error message if the model is not found, + suggesting a library update. + """ + try: + model = timm.create_model( + config.architecture, + pretrained=False, + **model_kwargs, + ) + return model + except RuntimeError as e: + if "Unknown model" in str(e): + # A good general check for unknown models. + raise ImportError( + f"The model architecture '{config.architecture}' is not supported in your version of timm ({timm.__version__}). " + "Please upgrade timm to a more recent version with `pip install -U timm`." + ) from e + raise e + + +@auto_docstring +class TimmWrapperPreTrainedModel(PreTrainedModel): + main_input_name = "pixel_values" + config: TimmWrapperConfig + _no_split_modules = [] + model_tags = ["timm"] + + # used in Trainer to avoid passing `loss_kwargs` to model forward + accepts_loss_kwargs = False + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision", "timm"]) + super().__init__(*args, **kwargs) + + def post_init(self): + self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing() + super().post_init() + + @staticmethod + def _fix_state_dict_key_on_load(key) -> tuple[str, bool]: + """ + Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. + We don't want this behavior for timm wrapped models. Instead, this method adds a + "timm_model." prefix to enable loading official timm Hub checkpoints. + """ + if "timm_model." not in key: + return f"timm_model.{key}", True + return key, False + + def _fix_state_dict_key_on_save(self, key): + """ + Overrides original method to remove "timm_model." prefix from state_dict keys. + Makes the saved checkpoint compatible with the `timm` library. + """ + return key.replace("timm_model.", ""), True + + def load_state_dict(self, state_dict, *args, **kwargs): + """ + Override original method to fix state_dict keys on load for cases when weights are loaded + without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint). + """ + state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} + return super().load_state_dict(state_dict, *args, **kwargs) + + def _init_weights(self, module): + """ + Initialize weights function to properly initialize Linear layer weights. + Since model architectures may vary, we assume only the classifier requires + initialization, while all other weights should be loaded from the checkpoint. + """ + if isinstance(module, (nn.Linear)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def _timm_model_supports_gradient_checkpointing(self): + """ + Check if the timm model supports gradient checkpointing by checking if the `set_grad_checkpointing` method is available. 
+ Some timm models will have the method but will raise an AssertionError when called so in this case we return False. + """ + if not hasattr(self.timm_model, "set_grad_checkpointing"): + return False + + try: + self.timm_model.set_grad_checkpointing(enable=True) + self.timm_model.set_grad_checkpointing(enable=False) + return True + except Exception: + return False + + def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs): + self.timm_model.set_grad_checkpointing(enable) + + +class TimmWrapperModel(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + # using num_classes=0 to avoid creating classification head + extra_init_kwargs = config.model_args or {} + self.features_only = extra_init_kwargs.get("features_only", False) + self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs) + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, list[int]]] = None, + return_dict: Optional[bool] = None, + do_pooling: Optional[bool] = None, + **kwargs, + ) -> Union[TimmWrapperModelOutput, tuple[Tensor, ...]]: + r""" + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. + do_pooling (`bool`, *optional*): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the + `do_pooling` value from the config is used. + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModel, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... )) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModel.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # Get pooled output + >>> pooled_output = outputs.pooler_output + + >>> # Get last hidden state + >>> last_hidden_state = outputs.last_hidden_state + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). 
Please consider using a " + "different architecture or updating the timm package to a compatible version." + ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if self.features_only: + last_hidden_state = self.timm_model.forward(pixel_values, **kwargs) + hidden_states = last_hidden_state if output_hidden_states else None + pooler_output = None + else: + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + else: + last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs) + hidden_states = None + + if do_pooling: + # classification head is not created, applying pooling only + pooler_output = self.timm_model.forward_head(last_hidden_state) + else: + pooler_output = None + + if not return_dict: + outputs = (last_hidden_state, pooler_output, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return TimmWrapperModelOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=hidden_states, + ) + + +class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers for image classification. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + + if config.num_labels == 0: + raise ValueError( + "You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. " + "Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, " + "or use `TimmWrapperModel` for feature extraction." + ) + + extra_init_kwargs = config.model_args or {} + self.timm_model = _create_timm_model_with_error_handling( + config, num_classes=config.num_labels, **extra_init_kwargs + ) + self.num_labels = config.num_labels + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, list[int]]] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[ImageClassifierOutput, tuple[Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + **kwargs: + Additional keyword arguments passed along to the `timm` model forward. + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModelForImageClassification, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... )) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # Get top 5 predictions + >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). Please consider using a " + "different architecture or updating the timm package to a compatible version." + ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + logits = self.timm_model.forward_head(last_hidden_state) + else: + logits = self.timm_model(pixel_values, **kwargs) + hidden_states = None + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + outputs = (loss, logits, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) + + +__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7033f7a2e7f332115bd93eadba2761c269747709 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_tvp import *
+    from .image_processing_tvp import *
+    from .image_processing_tvp_fast import *
+    from .modeling_tvp import *
+    from .processing_tvp import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/configuration_tvp.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/configuration_tvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..be7c785084d17c4e35720ee353bb1b217eff95f7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/configuration_tvp.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TVP model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class TvpConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`TvpModel`]. It is used to instantiate a Tvp
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Tvp
+    [Intel/tvp-base](https://huggingface.co/Intel/tvp-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint,
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        distance_loss_weight (`float`, *optional*, defaults to 1.0):
+            The weight of the distance loss.
+        duration_loss_weight (`float`, *optional*, defaults to 0.1):
+            The weight of the duration loss.
+        visual_prompter_type (`str`, *optional*, defaults to `"framepad"`):
+            The type of visual prompt, i.e. the type of padding; "framepad" pads every frame. Should be one of
+            "framepad" or "framedownpad".
+        visual_prompter_apply (`str`, *optional*, defaults to `"replace"`):
+            How the visual prompt is applied; "replace" uses the prompt value to overwrite the original value in the
+            visual inputs. Should be one of "replace", "add", or "remove".
+        visual_prompt_size (`int`, *optional*, defaults to 96):
+            The size of the visual prompt.
+        max_img_size (`int`, *optional*, defaults to 448):
+            The maximum size of a frame.
+        num_frames (`int`, *optional*, defaults to 48):
+            The number of frames extracted from a video.
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the Tvp text model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`TvpModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        max_grid_col_position_embeddings (`int`, *optional*, defaults to 100):
+            The largest number of horizontal patches from a video frame.
+        max_grid_row_position_embeddings (`int`, *optional*, defaults to 100):
+            The largest number of vertical patches from a video frame.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability of hidden layers.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability of attention layers.
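+
+    Example (a minimal usage sketch):
+
+    ```python
+    >>> from transformers import TvpConfig, TvpModel
+
+    >>> # Initializing a Tvp style configuration
+    >>> configuration = TvpConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = TvpModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```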
+ """ + + model_type = "tvp" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + distance_loss_weight=1.0, + duration_loss_weight=0.1, + visual_prompter_type="framepad", + visual_prompter_apply="replace", + visual_prompt_size=96, + max_img_size=448, + num_frames=48, + vocab_size=30522, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=512, + max_grid_col_position_embeddings=100, + max_grid_row_position_embeddings=100, + hidden_dropout_prob=0.1, + hidden_act="gelu", + layer_norm_eps=1e-12, + initializer_range=0.02, + attention_probs_dropout_prob=0.1, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.distance_loss_weight = distance_loss_weight + self.duration_loss_weight = duration_loss_weight + self.visual_prompter_type = visual_prompter_type + self.visual_prompter_apply = visual_prompter_apply + self.visual_prompt_size = visual_prompt_size + self.max_img_size = max_img_size + self.num_frames = num_frames + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.max_grid_col_position_embeddings = max_grid_col_position_embeddings + self.max_grid_row_position_embeddings = max_grid_row_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_dropout_prob = hidden_dropout_prob + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.attention_probs_dropout_prob = attention_probs_dropout_prob + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + @classmethod + def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`TvpConfig`] (or a derived class) from a pre-trained backbone model configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + Returns: + [`TvpConfig`]: An instance of a configuration object + """ + return cls(backbone_config=backbone_config, **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
+ + Returns: + `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["TvpConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp.py new file mode 100644 index 0000000000000000000000000000000000000000..0d4ee226525de8127a218629ca9cf12695f5e18f --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License=, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing=, software +# distributed under the License is distributed on an "AS IS" BASIS=, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND=, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for TVP.""" + +from collections.abc import Iterable +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + PaddingMode, + flip_channel_order, + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + is_valid_image, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.vivit.image_processing_vivit.make_batched +def make_batched(videos) -> list[list[ImageInput]]: + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + return videos + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + return [videos] + + elif is_valid_image(videos): + return [[videos]] + + raise ValueError(f"Could not make batched video from {videos}") + + +def get_resize_output_image_size( + input_image: np.ndarray, + max_size: int = 448, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + height, width = get_image_size(input_image, input_data_format) + if height >= width: + ratio = width * 1.0 / height + new_height = max_size + new_width = new_height * ratio + else: + ratio = height * 1.0 / width + new_width = max_size + new_height = new_width * ratio + size = (int(new_height), int(new_width)) + + return size + + +class TvpImageProcessor(BaseImageProcessor): + r""" + Constructs a Tvp image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. 
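A small worked example of the `get_resize_output_image_size` helper defined above, which scales the longest edge to `max_size` while preserving aspect ratio (the 720x1280 frame is only an illustrative input):

```python
import numpy as np

from transformers.models.tvp.image_processing_tvp import get_resize_output_image_size

frame = np.zeros((720, 1280, 3), dtype=np.uint8)          # (height, width, channels)
print(get_resize_output_image_size(frame, max_size=448))  # (252, 448): width is the longest edge -> 448
```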
+ size (`dict[str, int]` *optional*, defaults to `{"longest_edge": 448}`): + Size of the output image after resizing. The longest edge of the image will be resized to + `size["longest_edge"]` while maintaining the aspect ratio of the original image. Can be overridden by + `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop` + parameter in the `preprocess` method. + crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 448, "width": 448}`): + Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter + in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` method. + pad_size (`dict[str, int]`, *optional*, defaults to `{"height": 448, "width": 448}`): + Size of the image after applying the padding. Can be overridden by the `pad_size` parameter in the + `preprocess` method. + constant_values (`Union[float, Iterable[float]]`, *optional*, defaults to 0): + The fill value to use when padding the image. + pad_mode (`PaddingMode`, *optional*, defaults to `PaddingMode.CONSTANT`): + Use what kind of mode in padding. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + do_flip_channel_order (`bool`, *optional*, defaults to `True`): + Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order` + parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
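Given the defaults listed above, a minimal usage sketch for this processor (frame count and input sizes are illustrative; the output key and 448x448 target come from the code below):

```python
import numpy as np

from transformers import TvpImageProcessor

processor = TvpImageProcessor()  # defaults documented above: resize/crop/pad to 448, then BGR flip
frames = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(8)]

inputs = processor(frames, return_tensors="pt")  # a flat list of frames is treated as one video
print(inputs["pixel_values"].shape)              # torch.Size([1, 8, 3, 448, 448])
```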
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_pad: bool = True, + pad_size: Optional[dict[str, int]] = None, + constant_values: Union[float, Iterable[float]] = 0, + pad_mode: PaddingMode = PaddingMode.CONSTANT, + do_normalize: bool = True, + do_flip_channel_order: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 448} + crop_size = crop_size if crop_size is not None else {"height": 448, "width": 448} + pad_size = pad_size if pad_size is not None else {"height": 448, "width": 448} + + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_pad = do_pad + self.pad_size = pad_size + self.constant_values = constant_values + self.pad_mode = pad_mode + self.do_normalize = do_normalize + self.do_flip_channel_order = do_flip_channel_order + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will + have the size `(h, w)`. If `size` is of the form `{"longest_edge": s}`, the output image will have its + longest edge of length `s` while keeping the aspect ratio of the original image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size, default_to_square=False) + if "height" in size and "width" in size: + output_size = (size["height"], size["width"]) + elif "longest_edge" in size: + output_size = get_resize_output_image_size(image, size["longest_edge"], input_data_format) + else: + raise ValueError(f"Size must have 'height' and 'width' or 'longest_edge' as keys. 
Got {size.keys()}") + + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def pad_image( + self, + image: np.ndarray, + pad_size: Optional[dict[str, int]] = None, + constant_values: Union[float, Iterable[float]] = 0, + pad_mode: PaddingMode = PaddingMode.CONSTANT, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Pad an image with zeros to the given size. + + Args: + image (`np.ndarray`): + Image to pad. + pad_size (`dict[str, int]`) + Size of the output image with pad. + constant_values (`Union[float, Iterable[float]]`) + The fill value to use when padding the image. + pad_mode (`PaddingMode`) + The pad mode, default to PaddingMode.CONSTANT + data_format (`ChannelDimension` or `str`, *optional*) + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + height, width = get_image_size(image, channel_dim=input_data_format) + max_height = pad_size.get("height", height) + max_width = pad_size.get("width", width) + + pad_right, pad_bottom = max_width - width, max_height - height + if pad_right < 0 or pad_bottom < 0: + raise ValueError("The padding size must be greater than image size") + + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=pad_mode, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + + return padded_image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_pad: bool = True, + pad_size: Optional[dict[str, int]] = None, + constant_values: Optional[Union[float, Iterable[float]]] = None, + pad_mode: Optional[PaddingMode] = None, + do_normalize: Optional[bool] = None, + do_flip_channel_order: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """Preprocesses a single image.""" + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. 
+ image = to_numpy_array(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image.astype(np.float32), mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + if do_pad: + image = self.pad_image( + image=image, + pad_size=pad_size, + constant_values=constant_values, + pad_mode=pad_mode, + input_data_format=input_data_format, + ) + + # the pretrained checkpoints assume images are BGR, not RGB + if do_flip_channel_order: + image = flip_channel_order(image=image, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + videos: Union[ImageInput, list[ImageInput], list[list[ImageInput]]], + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[dict[str, int]] = None, + constant_values: Optional[Union[float, Iterable[float]]] = None, + pad_mode: Optional[PaddingMode] = None, + do_normalize: Optional[bool] = None, + do_flip_channel_order: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + videos (`ImageInput` or `list[ImageInput]` or `list[list[ImageInput]]`): + Frames to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after applying resize. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_centre_crop`): + Whether to centre crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after applying the centre crop. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` method. + pad_size (`dict[str, int]`, *optional*, defaults to `{"height": 448, "width": 448}`): + Size of the image after applying the padding. Can be overridden by the `pad_size` parameter in the + `preprocess` method. 
+ constant_values (`Union[float, Iterable[float]]`, *optional*, defaults to 0): + The fill value to use when padding the image. + pad_mode (`PaddingMode`, *optional*, defaults to "PaddingMode.CONSTANT"): + Use what kind of mode in padding. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): + Whether to flip the channel order of the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the inferred channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_pad = do_pad if do_pad is not None else self.do_pad + pad_size = pad_size if pad_size is not None else self.pad_size + constant_values = constant_values if constant_values is not None else self.constant_values + pad_mode = pad_mode if pad_mode else self.pad_mode + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_flip_channel_order = ( + do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order + ) + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + if not valid_images(videos): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + + videos = make_batched(videos) + + videos = [ + np.array( + [ + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_pad=do_pad, + pad_size=pad_size, + constant_values=constant_values, + pad_mode=pad_mode, + do_normalize=do_normalize, + do_flip_channel_order=do_flip_channel_order, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in video + ] + ) + for video in videos + ] + + data = {"pixel_values": videos} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["TvpImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp_fast.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..5d74e6efb71fd447e0fce8cdad3ba01cddc63624 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/image_processing_tvp_fast.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for TVP.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, + make_nested_list_of_images, +) +from ...processing_utils import Unpack +from ...utils import TensorType, auto_docstring + + +class TvpFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + do_flip_channel_order (`bool`, *optional*): + Whether to flip the channel order of the image from RGB to BGR. + constant_values (`float` or `List[float]`, *optional*): + Value used to fill the padding area when `pad_mode` is `'constant'`. + pad_mode (`str`, *optional*): + Padding mode to use — `'constant'`, `'edge'`, `'reflect'`, or `'symmetric'`. 
+ """ + + do_flip_channel_order: Optional[bool] + constant_values: Optional[Union[float, list[float]]] + pad_mode: Optional[str] + + +@auto_docstring +class TvpImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"longest_edge": 448} + default_to_square = False + crop_size = {"height": 448, "width": 448} + do_resize = True + do_center_crop = True + do_rescale = True + rescale_factor = 1 / 255 + do_pad = True + pad_size = {"height": 448, "width": 448} + constant_values = 0 + pad_mode = "constant" + do_normalize = True + do_flip_channel_order = True + valid_kwargs = TvpFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[TvpFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess( + self, + videos: Union[ImageInput, list[ImageInput], list[list[ImageInput]]], + **kwargs: Unpack[TvpFastImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(videos, **kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + **kwargs, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. + """ + return make_nested_list_of_images(images, **kwargs) + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image to the specified size. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict` or `dict`): + Size dictionary. If `size` has `longest_edge`, resize the longest edge to that value + while maintaining aspect ratio. Otherwise, use the base class resize method. + interpolation (`F.InterpolationMode`, *optional*): + Interpolation method to use. + antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing. + + Returns: + `torch.Tensor`: The resized image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + + # Handle longest_edge case (TVP-specific) + if size.longest_edge: + # Get current dimensions + current_height, current_width = image.shape[-2:] + + # Calculate new dimensions maintaining aspect ratio + if current_height >= current_width: + ratio = current_width * 1.0 / current_height + new_height = size.longest_edge + new_width = int(new_height * ratio) + else: + ratio = current_height * 1.0 / current_width + new_width = size.longest_edge + new_height = int(new_width * ratio) + + return super().resize( + image, SizeDict(height=new_height, width=new_width), interpolation=interpolation, antialias=antialias + ) + + # Use base class resize method for other cases + return super().resize(image, size, interpolation, antialias, **kwargs) + + def _flip_channel_order(self, frames: "torch.Tensor") -> "torch.Tensor": + """ + Flip channel order from RGB to BGR. + + The slow processor puts the red channel at the end (BGR format), + but the channel order is different. 
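As a quick illustration of the channel flip (implemented just below with `frames.flip(-3)`), reversing the channel axis of a channels-first tensor turns RGB into BGR:

```python
import torch

rgb = torch.arange(3.0).reshape(3, 1, 1).repeat(1, 2, 2)  # channel 0=R(0.), 1=G(1.), 2=B(2.)
bgr = rgb.flip(-3)                                         # reverse the channel axis
print(bgr[:, 0, 0])  # tensor([2., 1., 0.])
```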
We need to match the exact + channel order of the slow processor: + + Slow processor: + - Channel 0: Blue (originally Red) + - Channel 1: Green + - Channel 2: Red (originally Blue) + """ + # Assuming frames are in channels_first format (..., C, H, W) + frames = frames.flip(-3) + + return frames + + def _preprocess( + self, + images: list[list["torch.Tensor"]], + do_resize: bool, + size: Union[SizeDict, dict], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: Union[SizeDict, dict], + do_rescale: bool, + rescale_factor: float, + do_pad: bool, + pad_size: SizeDict, + constant_values: Union[float, list[float]], + pad_mode: str, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_flip_channel_order: bool, + return_tensors: Optional[Union[str, TensorType]], + disable_grouping: Optional[bool], + **kwargs, + ) -> BatchFeature: + """ + Preprocess videos using the fast image processor. + + This method processes each video frame through the same pipeline as the original + TVP image processor but uses torchvision operations for better performance. + """ + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) + processed_images_grouped = {} + for shape, stacked_frames in grouped_images.items(): + # Resize if needed + if do_resize: + stacked_frames = self.resize(stacked_frames, size, interpolation) + + # Center crop if needed + if do_center_crop: + stacked_frames = self.center_crop(stacked_frames, crop_size) + + # Rescale and normalize using fused method for consistency + stacked_frames = self.rescale_and_normalize( + stacked_frames, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + + # Pad if needed + if do_pad: + stacked_frames = self.pad(stacked_frames, pad_size, fill_value=constant_values, pad_mode=pad_mode) + stacked_frames = torch.stack(stacked_frames, dim=0) + + # Flip channel order if needed (RGB to BGR) + if do_flip_channel_order: + stacked_frames = self._flip_channel_order(stacked_frames) + + processed_images_grouped[shape] = stacked_frames + + processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True) + if return_tensors == "pt": + processed_images = [torch.stack(images, dim=0) for images in processed_images] + processed_images = torch.stack(processed_images, dim=0) + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["TvpImageProcessorFast"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/modeling_tvp.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/modeling_tvp.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8b626d2dd2c4caa83bec51a54284d17e841fd7 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/modeling_tvp.py @@ -0,0 +1,923 @@ +# coding=utf-8 +# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License=, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing=, software +# distributed under the License is distributed on an "AS IS" BASIS=, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND=, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TVP Model""" + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import prune_linear_layer +from ...utils import auto_docstring, logging +from ...utils.backbone_utils import load_backbone +from .configuration_tvp import TvpConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring +class TvpVideoGroundingOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Temporal-Distance IoU loss for video grounding. + logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the + input texts. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class TvpLoss(nn.Module): + """ + This class computes the losses for `TvpForVideoGrounding`. The process happens in two steps: 1) we compute + hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched + ground-truth / prediction (supervise class and box). + + Args: + losses (`list[str]`): + List of all the losses to be applied. + """ + + def __init__(self, losses): + super().__init__() + self.loss_map = { + "iou": self.loss_iou, + "distance": self.loss_distance, + "duration": self.loss_duration, + } + for loss in losses: + if loss not in self.loss_map: + raise ValueError(f"Loss {loss} not supported") + + self.losses = losses + + def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration): + """ + Measure the intersection over union. + """ + inter = torch.min(candidates_end_time, end_time) - torch.max(candidates_start_time, start_time) + union = torch.max(candidates_end_time, end_time) - torch.min(candidates_start_time, start_time) + iou = 1 - inter.clamp(min=0) / union + + return iou + + def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration): + """ + Measure the distance of mid points. + """ + mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0) + mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0) + distance_diff = torch.div( + torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration + ).clamp(min=0.2) + + return distance_diff + + def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration): + """ + Measure the difference of duration. 
+ """ + duration_candidates = torch.sub(candidates_end_time, candidates_start_time) + duration_groundtruth = torch.sub(end_time, start_time) + duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration)) + duration_diff = duration_diff.clamp(min=0.4) + + return duration_diff + + def forward(self, logits, labels): + """ + This performs the loss computation. + + Args: + logits (`torch.FloatTensor`): + The output logits of head module. + labels (`list[torch.FloatTensor]`): + List of tensors ([start, end, duration]), which contains start time, end time of the video corresponding to the text, and also the duration. + """ + duration, start_time, end_time = labels + candidates = torch.mul(logits, duration) + candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float() + + losses_dict = {} + for loss in self.losses: + losses_dict.update( + {loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)} + ) + + return losses_dict + + +class TvpVisionModel(nn.Module): + def __init__(self, config): + super().__init__() + self.backbone = load_backbone(config) + + if config.backbone_config is not None: + in_channels = config.backbone_config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"): + in_channels = self.backbone.config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"): + in_channels = self.backbone.config.hidden_size + else: + raise ValueError("Backbone config not found") + + self.grid_encoder_conv = nn.Conv2d( + in_channels, + config.hidden_size, + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + ) + + def forward(self, pixel_values): + batch_size, num_frames, num_channels, height, width = pixel_values.shape + # (batch_size * num_frames, num_channels, height, width) + pixel_values = pixel_values.view(batch_size * num_frames, num_channels, height, width) + grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0] + grid = self.grid_encoder_conv(grid_feat_outputs) + grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2) + grid = nn.functional.relu(grid, inplace=True) + new_channel, new_height, new_width = grid.shape[-3:] + # (batch_size, num_frames, num_channels, height, width) + grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width) + # (batch_size, num_frames, height, width, num_channels) + grid = grid.permute(0, 1, 3, 4, 2) + return grid + + +class TvpVisualInputEmbedding(nn.Module): + """ + Takes input of both image and video (multi-frame) + """ + + def __init__(self, config): + super().__init__() + # sequence embedding + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size) + self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(1, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings + self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings + + def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method 
allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high + resolution images (high resolution videos). + + """ + h0 = w0 = 1 + # if height dimension is to be interpolated + if height > self.max_grid_row_position_embeddings: + h0 = height / self.max_grid_row_position_embeddings + # if width dimension is to be interpolated + if width > self.max_grid_col_position_embeddings: + w0 = width / self.max_grid_col_position_embeddings + embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width) + embedding = nn.functional.interpolate( + embedding, + scale_factor=(h0, w0), + mode="bicubic", + align_corners=False, + ) + embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim) + return embedding + + def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False): + """ + Args: + grid: (batch_size, height, width, hidden_dim) + interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + Returns: + grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim) + """ + batch_size, height, width, hidden_dim = grid.shape + + # add row-wise position embeddings + # (height, ) + row_height = min(self.max_grid_row_position_embeddings, height) + row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device) + # (height, hidden_dim) + row_position_embeddings = self.row_position_embeddings(row_position_ids) + row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim) + # (batch_size, height, 1, hidden_dim) + row_position_embeddings = row_position_embeddings.view(*row_shape) + + # add column-wise position embeddings + row_width = min(self.max_grid_col_position_embeddings, width) + col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device) + # (width, hidden_dim) + col_position_embeddings = self.col_position_embeddings(col_position_ids) + col_shape = (batch_size, 1, row_width, hidden_dim) + # (batch_size, 1, width, hidden_dim) + col_position_embeddings = col_position_embeddings.view(*col_shape) + # (batch_size, height, width, hidden_dim) + positional_embeddings = row_position_embeddings + col_position_embeddings + + # This interpolation gets triggered ONLY when the input image dim is larger in any dimension than the original position embeddings + if interpolate_pos_encoding and ( + height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings + ): + grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width) + else: + grid = grid + positional_embeddings + return grid + + def forward(self, grid, interpolate_pos_encoding: bool = False): + """ + Args: + grid: Array of shape (batch_size, num_frames, height, width, num_channels). + It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note, + num_frames can be 1 + interpolate_pos_encoding: (bool, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. 
+ + Returns: + embeddings: The embedding of grid with size (batch_size, height*width, num_channels) + + """ + batch_size, num_frames, height, width, num_channels = grid.shape + # temporal mean pooling, (batch_size, height, width, hidden_size) + grid = grid.mean(1) + grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding) + # image token sequence, (batch_size, height*width, num_channels) + visual_tokens = grid.view(batch_size, -1, num_channels) + visual_tokens_shape = visual_tokens.shape[:-1] + device = visual_tokens.device + + # image token type embeddings. + token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = visual_tokens + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class TvpTextInputEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class TvpAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.pruned_heads = set() 
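For orientation, a tiny sketch of the head-size bookkeeping used by `_reshape` in the attention module below (shapes only; 768 and 12 are the documented defaults):

```python
import torch

batch_size, seq_len, hidden_size, num_heads = 2, 5, 768, 12
head_size = hidden_size // num_heads                       # 64 with the documented defaults

states = torch.randn(batch_size, seq_len, hidden_size)
# Same reshape as TvpAttention._reshape: (batch, seq, hidden) -> (batch, heads, seq, head_size)
per_head = states.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2).contiguous()
print(per_head.shape)  # torch.Size([2, 12, 5, 64])
```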
+ + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.num_attention_heads, self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.query = prune_linear_layer(self.query, index) + self.key = prune_linear_layer(self.key, index) + self.value = prune_linear_layer(self.value, index) + self.dense = prune_linear_layer(self.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.num_attention_heads = self.num_attention_heads - len(heads) + self.all_head_size = self.attention_head_size * self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def _reshape(self, tensor: torch.Tensor, sequence_length: int, batch_size: int): + return ( + tensor.view(batch_size, sequence_length, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions: Optional[bool] = None, + ): + batch_size, sequence_length = hidden_states.shape[:2] + mixed_query_layer = self.query(hidden_states) + + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self._reshape(mixed_query_layer, sequence_length, batch_size) + key_layer = self._reshape(mixed_key_layer, sequence_length, batch_size) + value_layer = self._reshape(mixed_value_layer, sequence_length, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.attn_dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + attn_output = torch.matmul(attention_probs, value_layer) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, sequence_length, self.all_head_size) + + attn_output = self.dense(attn_output) + attn_output = self.dropout(attn_output) + attn_output = self.layer_norm(attn_output + hidden_states) + # add attentions if we output them + outputs = (attn_output, attention_probs) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Tvp +class TvpIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class TvpOutputLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.layer_norm(hidden_states + input_tensor) + return hidden_states + + +class TvpEncodeLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention = TvpAttention(config) + self.intermediate = TvpIntermediate(config) + self.output = TvpOutputLayer(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions: Optional[bool] = None, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class TvpEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([TvpEncodeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + all_hidden_states = () + all_attentions = () + + for i, layer_module in enumerate(self.layer): + if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if output_hidden_states else None, + attentions=all_attentions if output_attentions else None, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Tvp +class TvpPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class TvpPreTrainedModel(PreTrainedModel): + config: TvpConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, TvpModel): + nn.init.normal_(module.text_prompt) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + if hasattr(module, "pad_up"): + nn.init.normal_(module.pad_up) + if hasattr(module, "pad_down"): + nn.init.normal_(module.pad_down) + if hasattr(module, "pad_left"): + nn.init.normal_(module.pad_left) + if hasattr(module, "pad_right"): + nn.init.normal_(module.pad_right) + + +class TvpFrameDownPadPrompter(nn.Module): + """ + Pad frames extracted from videos only at the bottom. 
+ """ + + def __init__(self, config): + if config.visual_prompter_apply not in ("add", "replace", "remove"): + raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)") + + super().__init__() + self.visual_prompt_size = config.visual_prompt_size + self.frame_num = config.frame_num + self.max_img_size = config.max_img_size + self.visual_prompter_apply = config.visual_prompter_apply + + self.pad_down = nn.Parameter( + torch.randn([1, config.frame_num, 3, config.visual_prompt_size, config.max_img_size]) + ) + + def forward(self, pixel_values): + if self.visual_prompter_apply != "add": + visual_prompt_mask = torch.ones( + [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device + ) + visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0 + pixel_values *= visual_prompt_mask + if self.visual_prompter_apply != "remove": + prompt = torch.zeros( + [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size], + device=pixel_values.device, + ) + start_point = self.max_img_size - self.visual_prompt_size + prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down + pixel_values += prompt.to(pixel_values.dtype) + return pixel_values + + +class TvpFramePadPrompter(nn.Module): + """ + Pad frames extracted from videos in the surroundings. + """ + + def __init__(self, config): + if config.visual_prompter_apply not in ("add", "replace", "remove"): + raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)") + + super().__init__() + self.num_frames = config.num_frames + self.max_img_size = config.max_img_size + self.visual_prompter_apply = config.visual_prompter_apply + self.base_size = config.max_img_size - config.visual_prompt_size * 2 + self.pad_up = nn.Parameter( + torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size]) + ) + self.pad_down = nn.Parameter( + torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size]) + ) + self.pad_left = nn.Parameter( + torch.randn( + [ + 1, + config.num_frames, + 3, + config.max_img_size - config.visual_prompt_size * 2, + config.visual_prompt_size, + ] + ) + ) + self.pad_right = nn.Parameter( + torch.randn( + [ + 1, + config.num_frames, + 3, + config.max_img_size - config.visual_prompt_size * 2, + config.visual_prompt_size, + ] + ) + ) + + def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high + resolution images (high resolution videos). 
+ + """ + + # creates scale factor from height and width of original image wrt to the config.max_img_size + h0, w0 = height / self.max_img_size, width / self.max_img_size + + batch, num_frames, channels, prompt_height, prompt_width = prompt.shape + + # reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation + prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width) + prompt = nn.functional.interpolate( + prompt, + scale_factor=(h0, w0), + mode="bicubic", + align_corners=False, + ) + # reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width + prompt = prompt.reshape(batch, num_frames, channels, height, width) + return prompt + + def forward(self, pixel_values, interpolate_pad_encoding: bool = False): + height, width = ( + (pixel_values.shape[-2], pixel_values.shape[-1]) + if interpolate_pad_encoding + else (self.max_img_size, self.max_img_size) + ) + if self.visual_prompter_apply not in ("add", "remove", "replace"): + raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}") + if self.visual_prompter_apply in ("replace", "remove"): + visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device) + pixel_values *= visual_prompt_mask + if self.visual_prompter_apply in ("replace", "add"): + base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device) + + prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4) + prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3) + prompt = torch.cat(pixel_values.size(0) * [prompt]) + if interpolate_pad_encoding: + prompt = self.interpolate_pad_encoding(prompt, height, width) + pixel_values = pixel_values + prompt.to(pixel_values.dtype) + return pixel_values + + +TVP_PROMPTER_CLASSES_MAPPING = { + "framedownpad": TvpFrameDownPadPrompter, + "framepad": TvpFramePadPrompter, +} + + +@auto_docstring( + custom_intro=""" + The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top. + """ +) +class TvpModel(TvpPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.vision_model = TvpVisionModel(config) + self.embeddings = TvpTextInputEmbeddings(config) + self.visual_embeddings = TvpVisualInputEmbedding(config) + self.encoder = TvpEncoder(config) + self.pooler = TvpPooler(config) + self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size])) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING: + raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)") + self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ): + r""" + Examples: + ```python + >>> import torch + >>> from transformers import AutoConfig, AutoTokenizer, TvpModel + + >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp") + + >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp") + + >>> pixel_values = torch.rand(1, 1, 3, 448, 448) + >>> text_inputs = tokenizer("This is an example input", return_tensors="pt") + >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask) + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features. + pixel_values = self.vision_model( + self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding) + ) + # (batch_size, sequence_length, hidden_size) + text_embedding_output = self.embeddings(input_ids=input_ids) + # (batch_size, visual_sequence_length, hidden_size) + visual_embedding_output = self.visual_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + + if attention_mask is not None: + # (batch_size, visual_sequence_length) + visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2]) + pt_mask = torch.ones(attention_mask.shape[0], 10).to( + device=attention_mask.device, dtype=attention_mask.dtype + ) + attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size()).to(input_ids.device) + text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1) + # (batch_size, sequence_length + visual_sequence_length, hidden_size) + embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + last_hidden_state = self.dropout(last_hidden_state) + pooled_output = self.dropout(pooled_output) + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class TvpVideoGroundingHead(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2) + self.layer_1 = nn.Linear(config.hidden_size * 2, 2) + self.activation_0 = nn.ReLU() + self.activation_1 = nn.Sigmoid() + + def forward(self, pooler_output): + logits = self.activation_0(self.layer_0(pooler_output)) + logits = self.activation_1(self.layer_1(logits)) + return logits + + +@auto_docstring( + custom_intro=""" + Tvp Model with a video grounding head on top computing IoU, distance, and duration loss. + """ +) +class TvpForVideoGrounding(TvpPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.model = TvpModel(config) + self.video_grounding_head = TvpVideoGroundingHead(config) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + labels: Optional[tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ): + r""" + labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*): + The labels contains duration, start time, and end time of the video corresponding to the text. 
+ + Examples: + ```python + >>> import torch + >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding + + >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp") + + >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp") + + >>> pixel_values = torch.rand(1, 1, 3, 448, 448) + >>> text_inputs = tokenizer("This is an example input", return_tensors="pt") + >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask) + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + outputs = self.model( + input_ids, + pixel_values, + attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + pooler_output = outputs[1] + logits = self.video_grounding_head(pooler_output) + + loss = None + if labels is not None: + criterion = TvpLoss(["iou", "distance", "duration"]) + criterion.to(self.device) + loss_dict = criterion(logits, labels) + loss = ( + loss_dict["iou"] + + self.config.distance_loss_weight * loss_dict["distance"] + + self.config.duration_loss_weight * loss_dict["duration"] + ) + if not return_dict: + outputs = (logits,) + outputs[2:] + if loss is not None: + outputs = (loss,) + outputs + return outputs + + return TvpVideoGroundingOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/tvp/processing_tvp.py b/venv/lib/python3.13/site-packages/transformers/models/tvp/processing_tvp.py new file mode 100644 index 0000000000000000000000000000000000000000..7cec0f14ab765b2cfe9d4bca97dd565fb904e9ac --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/tvp/processing_tvp.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for TVP. +""" + +from ...processing_utils import ProcessingKwargs, ProcessorMixin + + +class TvpProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "truncation": True, + "padding": "max_length", + "pad_to_max_length": True, + "return_token_type_ids": False, + }, + } + + +class TvpProcessor(ProcessorMixin): + r""" + Constructs a TVP processor which wraps a TVP image processor and a Bert tokenizer into a single processor. + + [`TvpProcessor`] offers all the functionalities of [`TvpImageProcessor`] and [`BertTokenizerFast`]. See the + [`~TvpProcessor.__call__`] and [`~TvpProcessor.decode`] for more information. + + Args: + image_processor ([`TvpImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`BertTokenizerFast`], *optional*): + The tokenizer is a required input.
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "TvpImageProcessor" + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + super().__init__(image_processor, tokenizer) + self.video_processor = image_processor + + def post_process_video_grounding(self, logits, video_durations): + """ + Compute the time of the video. + + Args: + logits (`torch.Tensor`): + The logits output of TvpForVideoGrounding. + video_durations (`float`): + The video's duration. + + Returns: + start (`float`): + The start time of the video. + end (`float`): + The end time of the video. + """ + start, end = ( + round(logits.tolist()[0][0] * video_durations, 1), + round(logits.tolist()[0][1] * video_durations, 1), + ) + + return start, end + + +__all__ = ["TvpProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/univnet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/univnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d02d0dbe574ebac34022e5e8727c900e16f5977 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/univnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_univnet import * + from .feature_extraction_univnet import * + from .modeling_univnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/univnet/configuration_univnet.py b/venv/lib/python3.13/site-packages/transformers/models/univnet/configuration_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1bad6da2537ecf2f0c7aabeb5799449295e81e --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/univnet/configuration_univnet.py @@ -0,0 +1,125 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""UnivNetModel model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class UnivNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UnivNetModel`]. It is used to instantiate a + UnivNet vocoder model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UnivNet + [dg845/univnet-dev](https://huggingface.co/dg845/univnet-dev) architecture, which corresponds to the 'c32' + architecture in [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/master/config/default_c32.yaml). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + model_in_channels (`int`, *optional*, defaults to 64): + The number of input channels for the UnivNet residual network. This should correspond to + `noise_sequence.shape[1]` and the value used in the [`UnivNetFeatureExtractor`] class. + model_hidden_channels (`int`, *optional*, defaults to 32): + The number of hidden channels of each residual block in the UnivNet residual network. + num_mel_bins (`int`, *optional*, defaults to 100): + The number of frequency bins in the conditioning log-mel spectrogram. This should correspond to the value + used in the [`UnivNetFeatureExtractor`] class. + resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 3, 3]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the UnivNet residual + network. The length of `resblock_kernel_sizes` defines the number of resnet blocks and should match that of + `resblock_stride_sizes` and `resblock_dilation_sizes`. + resblock_stride_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 4]`): + A tuple of integers defining the stride sizes of the 1D convolutional layers in the UnivNet residual + network. The length of `resblock_stride_sizes` should match that of `resblock_kernel_sizes` and + `resblock_dilation_sizes`. + resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 9, 27], [1, 3, 9, 27], [1, 3, 9, 27]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + UnivNet residual network. The length of `resblock_dilation_sizes` should match that of + `resblock_kernel_sizes` and `resblock_stride_sizes`. The length of each nested list in + `resblock_dilation_sizes` defines the number of convolutional layers per resnet block. + kernel_predictor_num_blocks (`int`, *optional*, defaults to 3): + The number of residual blocks in the kernel predictor network, which calculates the kernel and bias for + each location variable convolution layer in the UnivNet residual network. + kernel_predictor_hidden_channels (`int`, *optional*, defaults to 64): + The number of hidden channels for each residual block in the kernel predictor network. + kernel_predictor_conv_size (`int`, *optional*, defaults to 3): + The kernel size of each 1D convolutional layer in the kernel predictor network. + kernel_predictor_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for each residual block in the kernel predictor network. 
+ initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + leaky_relu_slope (`float`, *optional*, defaults to 0.2): + The angle of the negative slope used by the leaky ReLU activation. + + Example: + + ```python + >>> from transformers import UnivNetModel, UnivNetConfig + + >>> # Initializing a Tortoise TTS style configuration + >>> configuration = UnivNetConfig() + + >>> # Initializing a model (with random weights) from the Tortoise TTS style configuration + >>> model = UnivNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "univnet" + + def __init__( + self, + model_in_channels=64, + model_hidden_channels=32, + num_mel_bins=100, + resblock_kernel_sizes=[3, 3, 3], + resblock_stride_sizes=[8, 8, 4], + resblock_dilation_sizes=[[1, 3, 9, 27], [1, 3, 9, 27], [1, 3, 9, 27]], + kernel_predictor_num_blocks=3, + kernel_predictor_hidden_channels=64, + kernel_predictor_conv_size=3, + kernel_predictor_dropout=0.0, + initializer_range=0.01, + leaky_relu_slope=0.2, + **kwargs, + ): + if not (len(resblock_kernel_sizes) == len(resblock_stride_sizes) == len(resblock_dilation_sizes)): + raise ValueError( + "`resblock_kernel_sizes`, `resblock_stride_sizes`, and `resblock_dilation_sizes` must all have the" + " same length (which will be the number of resnet blocks in the model)." + ) + + self.model_in_channels = model_in_channels + self.model_hidden_channels = model_hidden_channels + self.num_mel_bins = num_mel_bins + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_stride_sizes = resblock_stride_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.kernel_predictor_num_blocks = kernel_predictor_num_blocks + self.kernel_predictor_hidden_channels = kernel_predictor_hidden_channels + self.kernel_predictor_conv_size = kernel_predictor_conv_size + self.kernel_predictor_dropout = kernel_predictor_dropout + self.initializer_range = initializer_range + self.leaky_relu_slope = leaky_relu_slope + super().__init__(**kwargs) + + +__all__ = ["UnivNetConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/univnet/feature_extraction_univnet.py b/venv/lib/python3.13/site-packages/transformers/models/univnet/feature_extraction_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..059226afe1d1147da2fb96626dce57b7c2f41715 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/univnet/feature_extraction_univnet.py @@ -0,0 +1,459 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
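One editorial note on the `UnivNetConfig` definition above before the feature extractor code: the three `resblock_*` lists must all have the same length, because that shared length is the number of LVC blocks in the residual network, and the product of the strides (8 · 8 · 4 = 256 with the defaults) is the overall upsampling factor from spectrogram frames to audio samples, which lines up with the feature extractor's default `hop_length` of 256. A small sketch, not part of the added file, showing both the valid default shapes and the failure mode:

```python
from transformers import UnivNetConfig

# The three resblock lists must have equal length; that length is the number of LVC blocks.
config = UnivNetConfig(
    resblock_kernel_sizes=[3, 3, 3],
    resblock_stride_sizes=[8, 8, 4],  # cumulative upsampling factor: 8 * 8 * 4 = 256
    resblock_dilation_sizes=[[1, 3, 9, 27], [1, 3, 9, 27], [1, 3, 9, 27]],
)
print(len(config.resblock_kernel_sizes))  # 3

# Mismatched lengths are rejected in UnivNetConfig.__init__:
try:
    UnivNetConfig(resblock_kernel_sizes=[3, 3])  # default stride/dilation lists still have length 3
except ValueError as err:
    print(err)
```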
+"""Feature extractor class for UnivNetModel.""" + +from typing import Any, Optional, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class UnivNetFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a UnivNet feature extractor. + + This class extracts log-mel-filter bank features from raw speech using the short time Fourier Transform (STFT). The + STFT implementation follows that of TacoTron 2 and Hifi-GAN. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value to pad with when applying the padding strategy defined by the `padding` argument to + [`UnivNetFeatureExtractor.__call__`]. Should correspond to audio silence. The `pad_end` argument to + `__call__` will also use this padding value. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve the + performance for some models. + num_mel_bins (`int`, *optional*, defaults to 100): + The number of mel-frequency bins in the extracted spectrogram features. This should match + `UnivNetModel.config.num_mel_bins`. + hop_length (`int`, *optional*, defaults to 256): + The direct number of samples between sliding windows. Otherwise referred to as "shift" in many papers. Note + that this is different from other audio feature extractors such as [`SpeechT5FeatureExtractor`] which take + the `hop_length` in ms. + win_length (`int`, *optional*, defaults to 1024): + The direct number of samples for each sliding window. Note that this is different from other audio feature + extractors such as [`SpeechT5FeatureExtractor`] which take the `win_length` in ms. + win_function (`str`, *optional*, defaults to `"hann_window"`): + Name for the window function used for windowing, must be accessible via `torch.{win_function}` + filter_length (`int`, *optional*, defaults to 1024): + The number of FFT components to use. If `None`, this is determined using + `transformers.audio_utils.optimal_fft_length`. + max_length_s (`int`, *optional*, defaults to 10): + The maximum input length of the model in seconds. This is used to pad the audio. + fmin (`float`, *optional*, defaults to 0.0): + Minimum mel frequency in Hz. + fmax (`float`, *optional*): + Maximum mel frequency in Hz. If not set, defaults to `sampling_rate / 2`. + mel_floor (`float`, *optional*, defaults to 1e-09): + Minimum value of mel frequency banks. Note that the way [`UnivNetFeatureExtractor`] uses `mel_floor` is + different than in [`transformers.audio_utils.spectrogram`]. + center (`bool`, *optional*, defaults to `False`): + Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame + `t` will start at time `t * hop_length`. 
+ compression_factor (`float`, *optional*, defaults to 1.0): + The multiplicative compression factor for dynamic range compression during spectral normalization. + compression_clip_val (`float`, *optional*, defaults to 1e-05): + The clip value applied to the waveform before applying dynamic range compression during spectral + normalization. + normalize_min (`float`, *optional*, defaults to -11.512925148010254): + The min value used for Tacotron 2-style linear normalization. The default is the original value from the + Tacotron 2 implementation. + normalize_max (`float`, *optional*, defaults to 2.3143386840820312): + The max value used for Tacotron 2-style linear normalization. The default is the original value from the + Tacotron 2 implementation. + model_in_channels (`int`, *optional*, defaults to 64): + The number of input channels to the [`UnivNetModel`] model. This should match + `UnivNetModel.config.model_in_channels`. + pad_end_length (`int`, *optional*, defaults to 10): + If padding the end of each waveform, the number of spectrogram frames worth of samples to append. The + number of appended samples will be `pad_end_length * hop_length`. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether or not [`~UnivNetFeatureExtractor.__call__`] should return `attention_mask`. + """ + + model_input_names = ["input_features", "noise_sequence", "padding_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 24000, + padding_value: float = 0.0, + do_normalize: bool = False, + num_mel_bins: int = 100, + hop_length: int = 256, + win_length: int = 1024, + win_function: str = "hann_window", + filter_length: Optional[int] = 1024, + max_length_s: int = 10, + fmin: float = 0.0, + fmax: Optional[float] = None, + mel_floor: float = 1e-9, + center: bool = False, + compression_factor: float = 1.0, + compression_clip_val: float = 1e-5, + normalize_min: float = -11.512925148010254, + normalize_max: float = 2.3143386840820312, + model_in_channels: int = 64, + pad_end_length: int = 10, + return_attention_mask=True, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.do_normalize = do_normalize + + self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.win_length = win_length + self.win_function = win_function + self.filter_length = filter_length + self.fmin = fmin + if fmax is None: + # Follows the librosa.filters.mel implementation + fmax = float(sampling_rate) / 2 + self.fmax = fmax + self.mel_floor = mel_floor + + self.max_length_s = max_length_s + self.num_max_samples = max_length_s * sampling_rate + + if self.filter_length is None: + self.n_fft = optimal_fft_length(self.win_length) + else: + self.n_fft = self.filter_length + self.n_freqs = (self.n_fft // 2) + 1 + + self.window = window_function(window_length=self.win_length, name=self.win_function, periodic=True) + + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_freqs, + num_mel_filters=self.num_mel_bins, + min_frequency=self.fmin, + max_frequency=self.fmax, + sampling_rate=self.sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + self.center = center + self.compression_factor = compression_factor + self.compression_clip_val = compression_clip_val + self.normalize_min = normalize_min + self.normalize_max = normalize_max + self.model_in_channels = model_in_channels + self.pad_end_length = pad_end_length + + def 
normalize(self, spectrogram): + return 2 * ((spectrogram - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 + + def denormalize(self, spectrogram): + return self.normalize_min + (self.normalize_max - self.normalize_min) * ((spectrogram + 1) / 2) + + def mel_spectrogram(self, waveform: np.ndarray) -> np.ndarray: + """ + Calculates log MEL spectrograms from a batch of waveforms. Note that the input waveform(s) will be padded by + `(self.n_fft - self.hop_length) // 2` on both sides using the `reflect` padding mode. + + Args: + waveform (`np.ndarray` of shape `(length,)`): + The input waveform. This must be a single real-valued, mono waveform. + + Returns: + `numpy.ndarray`: Array containing a log-mel spectrogram of shape `(num_frames, num_mel_bins)`. + """ + # Do custom padding based on the official MelGAN and Hifi-GAN implementations + # See https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/utils/stft.py#L84-L86 + waveform = np.pad( + waveform, + (int((self.n_fft - self.hop_length) / 2), int((self.n_fft - self.hop_length) / 2)), + mode="reflect", + ) + + # Get the complex spectrogram. + # Note: waveform must be unbatched currently due to the implementation of spectrogram(...). + complex_spectrogram = spectrogram( + waveform, + window=self.window, + frame_length=self.n_fft, + hop_length=self.hop_length, + fft_length=self.n_fft, + power=None, + center=self.center, + mel_filters=None, + mel_floor=None, + ) + + # Apply the MEL filter bank and MEL floor manually since UnivNet uses a slightly different implementation + amplitude_spectrogram = np.sqrt( + np.real(complex_spectrogram) ** 2 + np.imag(complex_spectrogram) ** 2 + self.mel_floor + ) + mel_spectrogram = np.matmul(self.mel_filters.T, amplitude_spectrogram) + + # Perform spectral normalization to get the log mel spectrogram. + log_mel_spectrogram = np.log( + np.clip(mel_spectrogram, a_min=self.compression_clip_val, a_max=None) * self.compression_factor + ) + + # Return spectrogram with num_mel_bins last + return log_mel_spectrogram.T + + def generate_noise( + self, + noise_length: int, + generator: Optional[np.random.Generator] = None, + ) -> np.ndarray: + """ + Generates a random noise sequence of standard Gaussian noise for use in the `noise_sequence` argument of + [`UnivNetModel.forward`]. + + Args: + noise_length (`int`): + The length (dim 0) of the generated noise. The number of features (dim 1) is taken from + `self.model_in_channels`. + generator (`numpy.random.Generator`, *optional*, defaults to `None`): + An optional `numpy.random.Generator` random number generator to control noise generation. If not set, a + new generator with fresh entropy will be created. + + Returns: + `numpy.ndarray`: Array containing random standard Gaussian noise of shape `(noise_length, + model_in_channels)`. + """ + if generator is None: + generator = np.random.default_rng() + + noise_shape = (noise_length, self.model_in_channels) + noise = generator.standard_normal(noise_shape, dtype=np.float32) + + return noise + + def batch_decode(self, waveforms, waveform_lengths=None) -> list[np.ndarray]: + r""" + Removes padding from generated audio after running [`UnivNetModel.forward`].
This returns a ragged list of 1D + audio waveform arrays and not a single tensor/array because in general the waveforms will have different + lengths after removing padding. + + Args: + waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + The batched output waveforms from the [`UnivNetModel`]. + waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): + The batched lengths of each waveform before padding. + + Returns: + `list[np.ndarray]`: A ragged list of 1D waveform arrays with padding removed. + """ + # Collapse the batched waveform tensor to a list of 1D audio waveforms + waveforms = [waveform.detach().to(device="cpu", copy=True).numpy() for waveform in waveforms] + + if waveform_lengths is not None: + waveforms = [waveform[: waveform_lengths[i]] for i, waveform in enumerate(waveforms)] + + return waveforms + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + sampling_rate: Optional[int] = None, + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + truncation: bool = True, + pad_to_multiple_of: Optional[int] = None, + return_noise: bool = True, + generator: Optional[np.random.Generator] = None, + pad_end: bool = False, + pad_length: Optional[int] = None, + do_normalize: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition + pipeline. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the input `raw_speech` waveforms (according to the model's padding side and + padding index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + + If `pad_end = True`, that padding will occur before the `padding` strategy is applied. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*, defaults to `True`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. 
+ return_noise (`bool`, *optional*, defaults to `True`): + Whether to generate and return a noise waveform for use in [`UnivNetModel.forward`]. + generator (`numpy.random.Generator`, *optional*, defaults to `None`): + An optional `numpy.random.Generator` random number generator to use when generating noise. + pad_end (`bool`, *optional*, defaults to `False`): + Whether to pad the end of each waveform with silence. This can help reduce artifacts at the end of the + generated audio sample; see https://github.com/seungwonpark/melgan/issues/8 for more details. This + padding will be done before the padding strategy specified in `padding` is performed. + pad_length (`int`, *optional*, defaults to `None`): + If padding the end of each waveform, the length of the padding in spectrogram frames. If not set, this + will default to `self.config.pad_end_length`. + do_normalize (`bool`, *optional*): + Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve + the performance for some models. If not set, this will default to `self.config.do_normalize`. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug."
+ ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [np.asarray(raw_speech, dtype=np.float32)] + + # Pad end to reduce artifacts + if pad_end: + pad_length = pad_length if pad_length is not None else self.pad_end_length + raw_speech = [ + np.pad(waveform, (0, pad_length * self.hop_length), constant_values=self.padding_value) + for waveform in raw_speech + ] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length if max_length is not None else self.num_max_samples, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + # make sure list is in array format + # input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + input_features = padded_inputs.get("input_features") + + mel_spectrograms = [self.mel_spectrogram(waveform) for waveform in input_features] + + if isinstance(input_features[0], list): + batched_speech["input_features"] = [np.asarray(mel, dtype=np.float32) for mel in mel_spectrograms] + else: + batched_speech["input_features"] = [mel.astype(np.float32) for mel in mel_spectrograms] + + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + batched_speech["padding_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + if return_noise: + noise = [ + self.generate_noise(spectrogram.shape[0], generator) + for spectrogram in batched_speech["input_features"] + ] + batched_speech["noise_sequence"] = noise + + if do_normalize: + batched_speech["input_features"] = [ + self.normalize(spectrogram) for spectrogram in batched_speech["input_features"] + ] + + if return_tensors is not None: + batched_speech = batched_speech.convert_to_tensors(return_tensors) + + return batched_speech + + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + + # Don't serialize these as they are derived from the other properties. + names = ["window", "mel_filters", "n_fft", "n_freqs", "num_max_samples"] + for name in names: + if name in output: + del output[name] + + return output + + +__all__ = ["UnivNetFeatureExtractor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/univnet/modeling_univnet.py b/venv/lib/python3.13/site-packages/transformers/models/univnet/modeling_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e7595ff38f8aee81151eb9166ca8bc527ae6b0be --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/univnet/modeling_univnet.py @@ -0,0 +1,617 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UnivNetModel model.""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from .configuration_univnet import UnivNetConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output class for the [`UnivNetModel`], which includes the generated audio waveforms and the original unpadded + lengths of those waveforms (so that the padding can be removed by [`UnivNetModel.batch_decode`]). + """ +) +class UnivNetModelOutput(ModelOutput): + r""" + waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Batched 1D (mono-channel) output audio waveforms. + waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`): + The batched length in samples of each unpadded waveform in `waveforms`. + """ + + waveforms: Optional[torch.FloatTensor] = None + waveform_lengths: Optional[torch.FloatTensor] = None + + +class UnivNetKernelPredictorResidualBlock(nn.Module): + """ + Implementation of the residual block for the kernel predictor network inside each location variable convolution + block (LVCBlock). + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + """ + + def __init__( + self, + config: UnivNetConfig, + ): + super().__init__() + self.channels = config.model_in_channels + self.kernel_size = config.kernel_predictor_conv_size + self.dropout_prob = config.kernel_predictor_dropout + self.leaky_relu_slope = config.leaky_relu_slope + + padding = (self.kernel_size - 1) // 2 + + self.dropout = nn.Dropout(self.dropout_prob) + self.conv1 = nn.Conv1d(self.channels, self.channels, self.kernel_size, padding=padding, bias=True) + self.conv2 = nn.Conv1d(self.channels, self.channels, self.kernel_size, padding=padding, bias=True) + + def forward(self, hidden_states: torch.FloatTensor): + # hidden_states should have shape (batch_size, channels, seq_length) + residual = hidden_states + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv2(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + return hidden_states + residual + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv1) + weight_norm(self.conv2) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv1) + nn.utils.remove_weight_norm(self.conv2) + + +class UnivNetKernelPredictor(nn.Module): + """ + Implementation of the kernel predictor network which supplies the kernel and bias for the location variable + convolutional layers (LVCs) in each UnivNet LVCBlock. 
+ + Based on the KernelPredictor implementation in + [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L7). + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + conv_kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the location variable convolutional layer kernels (convolutional weight tensor). + conv_layers (`int`, *optional*, defaults to 4): + The number of location variable convolutional layers to output kernels and biases for. + """ + + def __init__( + self, + config: UnivNetConfig, + conv_kernel_size: int = 3, + conv_layers: int = 4, + ): + super().__init__() + + self.conv_in_channels = config.model_hidden_channels + self.conv_out_channels = 2 * config.model_hidden_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + self.kernel_channels = ( + self.conv_in_channels * self.conv_out_channels * self.conv_kernel_size * self.conv_layers + ) + self.bias_channels = self.conv_out_channels * self.conv_layers + + self.resnet_in_channels = config.num_mel_bins + self.resnet_hidden_channels = config.kernel_predictor_hidden_channels + self.resnet_kernel_size = config.kernel_predictor_conv_size + self.num_blocks = config.kernel_predictor_num_blocks + + self.leaky_relu_slope = config.leaky_relu_slope + + padding = (self.resnet_kernel_size - 1) // 2 + + self.input_conv = nn.Conv1d(self.resnet_in_channels, self.resnet_hidden_channels, 5, padding=2, bias=True) + + self.resblocks = nn.ModuleList([UnivNetKernelPredictorResidualBlock(config) for _ in range(self.num_blocks)]) + + self.kernel_conv = nn.Conv1d( + self.resnet_hidden_channels, self.kernel_channels, self.resnet_kernel_size, padding=padding, bias=True + ) + self.bias_conv = nn.Conv1d( + self.resnet_hidden_channels, self.bias_channels, self.resnet_kernel_size, padding=padding, bias=True + ) + + def forward(self, spectrogram: torch.FloatTensor): + """ + Maps a conditioning log-mel spectrogram to a tensor of convolutional kernels and biases, for use in location + variable convolutional layers. Note that the input spectrogram should have shape (batch_size, input_channels, + seq_length). + + Args: + spectrogram (`torch.FloatTensor` of shape `(batch_size, input_channels, seq_length)`): + Tensor containing the log-mel spectrograms. + + Returns: + tuple[`torch.FloatTensor, `torch.FloatTensor`]: tuple of tensors where the first element is the tensor of + location variable convolution kernels of shape `(batch_size, self.conv_layers, self.conv_in_channels, + self.conv_out_channels, self.conv_kernel_size, seq_length)` and the second element is the tensor of + location variable convolution biases of shape `(batch_size, self.conv_layers. self.conv_out_channels, + seq_length)`. 
+ """ + batch_size, _, seq_length = spectrogram.shape + + hidden_states = self.input_conv(spectrogram) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + + for resblock in self.resblocks: + hidden_states = resblock(hidden_states) + + kernel_hidden_states = self.kernel_conv(hidden_states) + bias_hidden_states = self.bias_conv(hidden_states) + + # Reshape kernels and biases to appropriate shape + kernels = kernel_hidden_states.view( + batch_size, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + seq_length, + ).contiguous() + biases = bias_hidden_states.view( + batch_size, + self.conv_layers, + self.conv_out_channels, + seq_length, + ).contiguous() + + return kernels, biases + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.input_conv) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.kernel_conv) + weight_norm(self.bias_conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.kernel_conv) + nn.utils.remove_weight_norm(self.bias_conv) + + +class UnivNetLvcResidualBlock(nn.Module): + """ + Implementation of the location variable convolution (LVC) residual block for the UnivNet residual network. + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + kernel_size (`int`): + The kernel size for the dilated 1D convolutional layer. + dilation (`int`): + The dilation for the dilated 1D convolutional layer. + """ + + def __init__( + self, + config: UnivNetConfig, + kernel_size: int, + dilation: int, + ): + super().__init__() + self.hidden_channels = config.model_hidden_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.leaky_relu_slope = config.leaky_relu_slope + + padding = self.dilation * (self.kernel_size - 1) // 2 + + self.conv = nn.Conv1d( + self.hidden_channels, + self.hidden_channels, + self.kernel_size, + padding=padding, + dilation=self.dilation, + ) + + def forward(self, hidden_states, kernel, bias, hop_size=256): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.location_variable_convolution(hidden_states, kernel, bias, hop_size=hop_size) + # Gated activation unit + hidden_states = torch.sigmoid(hidden_states[:, : self.hidden_channels, :]) * torch.tanh( + hidden_states[:, self.hidden_channels :, :] + ) + # Skip connection + hidden_states = residual + hidden_states + + return hidden_states + + # Based on https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L171 + def location_variable_convolution( + self, + hidden_states: torch.FloatTensor, + kernel: torch.FloatTensor, + bias: torch.FloatTensor, + dilation: int = 1, + hop_size: int = 256, + ): + """ + Performs location-variable convolution operation on the input sequence (hidden_states) using the local + convolution kernel. This was introduced in [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform + Generation](https://huggingface.co/papers/2102.10815) by Zhen Zheng, Jianzong Wang, Ning Cheng, and Jing Xiao. 
+ + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, in_channels, in_length)`): + The input sequence of shape (batch, in_channels, in_length). + kernel (`torch.FloatTensor` of shape `(batch_size, in_channels, out_channels, kernel_size, kernel_length)`): + The local convolution kernel of shape (batch, in_channels, out_channels, kernel_size, kernel_length). + bias (`torch.FloatTensor` of shape `(batch_size, out_channels, kernel_length)`): + The bias for the local convolution of shape (batch, out_channels, kernel_length). + dilation (`int`, *optional*, defaults to 1): + The dilation of convolution. + hop_size (`int`, *optional*, defaults to 256): + The hop_size of the conditioning sequence. + Returns: + `torch.FloatTensor`: the output sequence after performing local convolution with shape (batch_size, + out_channels, in_length). + """ + batch, _, in_length = hidden_states.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + if in_length != (kernel_length * hop_size): + raise ValueError( + f"Dim 2 of `hidden_states` should be {kernel_length * hop_size}) but got {in_length}. Please check" + " `hidden_states` or `kernel` and `hop_size` to make sure they are correct." + ) + + padding = dilation * int((kernel_size - 1) / 2) + + # (batch, in_channels, in_length + 2*padding) + hidden_states = nn.functional.pad(hidden_states, (padding, padding), "constant", 0) + # (batch, in_channels, kernel_length, hop_size + 2*padding) + hidden_states = hidden_states.unfold(2, hop_size + 2 * padding, hop_size) + + if hop_size < dilation: + hidden_states = nn.functional.pad(hidden_states, (0, dilation), "constant", 0) + # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + hidden_states = hidden_states.unfold(3, dilation, dilation) + hidden_states = hidden_states[:, :, :, :, :hop_size] + # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + hidden_states = hidden_states.transpose(3, 4) + # (batch, in_channels, kernel_length, dilation, _, kernel_size) + hidden_states = hidden_states.unfold(4, kernel_size, 1) + + # Apply local convolution kernel to hidden_states. + output_hidden_states = torch.einsum("bildsk,biokl->bolsd", hidden_states, kernel) + + output_hidden_states = output_hidden_states.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + output_hidden_states = output_hidden_states + bias + output_hidden_states = output_hidden_states.contiguous().view(batch, out_channels, -1) + + return output_hidden_states + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + +class UnivNetLvcBlock(nn.Module): + """ + Implementation of the location variable convolution (LVC) residual block of the UnivNet residual block. Includes a + `UnivNetKernelPredictor` inside to predict the kernels and biases of the LVC layers. + + Based on LVCBlock in + [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L98) + + Parameters: + config (`UnivNetConfig`): + Config for the `UnivNetModel` model. 
+ layer_id (`int`): + An integer corresponding to the index of the current LVC resnet block layer. This should be between 0 and + `len(config.resblock_stride_sizes) - 1)` inclusive. + lvc_hop_size (`int`, *optional*, defaults to 256): + The hop size for the location variable convolutional layers. + """ + + def __init__( + self, + config: UnivNetConfig, + layer_id: int, + lvc_hop_size: int = 256, + ): + super().__init__() + self.hidden_channels = config.model_hidden_channels + self.kernel_size = config.resblock_kernel_sizes[layer_id] + self.stride = config.resblock_stride_sizes[layer_id] + self.dilations = config.resblock_dilation_sizes[layer_id] + self.cond_hop_length = lvc_hop_size + self.leaky_relu_slope = config.leaky_relu_slope + self.num_blocks = len(self.dilations) + + self.convt_pre = nn.ConvTranspose1d( + self.hidden_channels, + self.hidden_channels, + 2 * self.stride, + stride=self.stride, + padding=self.stride // 2 + self.stride % 2, + output_padding=self.stride % 2, + ) + + self.kernel_predictor = UnivNetKernelPredictor(config, self.kernel_size, self.num_blocks) + + self.resblocks = nn.ModuleList( + [UnivNetLvcResidualBlock(config, self.kernel_size, self.dilations[i]) for i in range(self.num_blocks)] + ) + + def forward(self, hidden_states: torch.FloatTensor, spectrogram: torch.FloatTensor): + # hidden_states: (batch_size, hidden_channels, seq_length) + # spectrogram: (batch_size, cond_channels, cond_length) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.convt_pre(hidden_states) + + kernels, biases = self.kernel_predictor(spectrogram) + + for i, resblock in enumerate(self.resblocks): + kernel = kernels[:, i, :, :, :, :] + bias = biases[:, i, :, :] + hidden_states = resblock(hidden_states, kernel, bias, hop_size=self.cond_hop_length) + + return hidden_states + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.convt_pre) + self.kernel_predictor.apply_weight_norm() + for layer in self.resblocks: + layer.apply_weight_norm() + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.convt_pre) + self.kernel_predictor.remove_weight_norm() + for layer in self.resblocks: + layer.remove_weight_norm() + + +@auto_docstring +class UnivNetModel(PreTrainedModel): + config: UnivNetConfig + main_input_name = "input_features" + + def __init__(self, config: UnivNetConfig): + super().__init__(config) + + self.num_kernels = len(config.resblock_kernel_sizes) + self.leaky_relu_slope = config.leaky_relu_slope + + self.conv_pre = nn.Conv1d( + config.model_in_channels, + config.model_hidden_channels, + kernel_size=7, + stride=1, + padding=3, + padding_mode="reflect", + ) + + # Initialize location-variable convolution ResNet Blocks. 
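An editorial aside before the stride bookkeeping that follows in `UnivNetModel.__init__`: each LVC block upsamples its input by its stride, so the `lvc_hop_size` passed to block `i` is the running product of `resblock_stride_sizes[: i + 1]`. With the default strides `[8, 8, 4]` this gives per-block hop sizes of 8, 64, and 256, the last of which matches the feature extractor's default `hop_length` of 256. A minimal stand-alone sketch of the same computation (not part of the file):

```python
# Per-block LVC hop sizes are the running product of the resblock strides,
# exactly what the loop in UnivNetModel.__init__ computes below.
strides = [8, 8, 4]  # default config.resblock_stride_sizes
hop_lengths = []
hop_length = 1
for stride in strides:
    hop_length *= stride
    hop_lengths.append(hop_length)
print(hop_lengths)  # [8, 64, 256]
```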
+ num_layers = len(config.resblock_stride_sizes) + hop_length = 1 + hop_lengths = [] + for stride in config.resblock_stride_sizes: + hop_length = hop_length * stride + hop_lengths.append(hop_length) + + self.resblocks = nn.ModuleList( + [ + UnivNetLvcBlock( + config, + layer_id=i, + lvc_hop_size=hop_lengths[i], + ) + for i in range(num_layers) + ] + ) + + self.conv_post = nn.Conv1d(config.model_hidden_channels, 1, 7, padding=3, padding_mode="reflect") + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_features: torch.FloatTensor, + noise_sequence: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.FloatTensor], UnivNetModelOutput]: + r""" + noise_sequence (`torch.FloatTensor`, *optional*): + Tensor containing a noise sequence of standard Gaussian noise. Can be batched and of shape `(batch_size, + sequence_length, config.model_in_channels)`, or un-batched and of shape (sequence_length, + config.model_in_channels)`. If not supplied, will be randomly generated. + padding_mask (`torch.BoolTensor`, *optional*): + Mask indicating which parts of each sequence are padded. Mask values are selected in `[0, 1]`: + + - 1 for tokens that are **not masked** + - 0 for tokens that are **masked** + + The mask can be batched and of shape `(batch_size, sequence_length)` or un-batched and of shape + `(sequence_length,)`. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + return_dict: + Whether to return a [`~utils.ModelOutput`] subclass instead of a plain tuple. + + Example: + + ```python + >>> from transformers import UnivNetFeatureExtractor, UnivNetModel + >>> from datasets import load_dataset, Audio + + >>> model = UnivNetModel.from_pretrained("dg845/univnet-dev") + >>> feature_extractor = UnivNetFeatureExtractor.from_pretrained("dg845/univnet-dev") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> # Resample the audio to the feature extractor's sampling rate. + >>> ds = ds.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) + >>> inputs = feature_extractor( + ... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" + ... 
) + >>> audio = model(**inputs).waveforms + >>> list(audio.shape) + [1, 140288] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Resolve batch sizes for noise_sequence and spectrogram + spectrogram_batched = input_features.dim() == 3 + if not spectrogram_batched: + input_features = input_features.unsqueeze(0) + spectrogram_batch_size, spectrogram_length, _ = input_features.shape + + if noise_sequence is not None: + noise_sequence_batched = noise_sequence.dim() == 3 + if not noise_sequence_batched: + noise_sequence = noise_sequence.unsqueeze(0) + else: + # Randomly generate noise_sequence + noise_sequence_shape = (spectrogram_batch_size, spectrogram_length, self.config.model_in_channels) + noise_sequence = torch.randn( + noise_sequence_shape, generator=generator, dtype=input_features.dtype, device=input_features.device + ) + noise_sequence_batch_size = noise_sequence.shape[0] + + if spectrogram_batch_size > 1 and noise_sequence_batch_size == 1: + # Repeat noise_sequence spectrogram_batch_size times + noise_sequence = noise_sequence.repeat(spectrogram_batch_size, 1, 1) + elif noise_sequence_batch_size > 1 and spectrogram_batch_size == 1: + # Repeat spectrogram noise_sequence_batch_size times + input_features = input_features.repeat(noise_sequence_batch_size, 1, 1) + + if noise_sequence_batch_size != spectrogram_batch_size: + raise ValueError( + f"The batch size of `noise_sequence` is {noise_sequence_batch_size} and the batch size of" + f" `input_features` is {spectrogram_batch_size}, but the two are expected to be equal." + ) + + if padding_mask is not None: + if padding_mask.dim() == 1: + padding_mask = padding_mask.unsqueeze(0) + padding_mask_batch_size = padding_mask.shape[0] + if padding_mask_batch_size != spectrogram_batch_size: + raise ValueError( + f"The batch size of `padding_mask` is {padding_mask_batch_size} and the batch size of" + f" `input_features` is {spectrogram_batch_size}, but the two are expected to be equal." + ) + + # Change shapes to have channels before sequence lengths + hidden_states = noise_sequence.transpose(2, 1) + input_features = input_features.transpose(2, 1) + + hidden_states = self.conv_pre(hidden_states) + + for resblock in self.resblocks: + hidden_states = resblock(hidden_states, input_features) + + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + # Remove sequence length dimension since this collapses to 1 + # NOTE: keep waveforms batched even if there's only one + waveform = hidden_states.squeeze(1) + + # Get sequence lengths for UnivNetFeatureExtractor.batch_decode. 
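+ # Note: when `padding_mask` is provided, summing it per example below counts the non-padded positions,
+ # so e.g. a mask row containing 120 ones (illustrative) gives a waveform length entry of 120.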
+ waveform_lengths = None + if padding_mask is not None: + # Padding is always contiguous and added on the right + waveform_lengths = torch.sum(padding_mask, dim=1) + + if not return_dict: + outputs = (waveform, waveform_lengths) + return outputs + + return UnivNetModelOutput( + waveforms=waveform, + waveform_lengths=waveform_lengths, + ) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.conv_post) + + +__all__ = ["UnivNetModel"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/video_llava/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/video_llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5f5971a167997ffc89de76d5cc5ba3d8216b4b --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/video_llava/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_video_llava import * + from .image_processing_video_llava import * + from .modeling_video_llava import * + from .processing_video_llava import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/venv/lib/python3.13/site-packages/transformers/models/video_llava/configuration_video_llava.py b/venv/lib/python3.13/site-packages/transformers/models/video_llava/configuration_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..561d2e3a9254932ec795ca31b9d51e9fc6ee0d1d --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/video_llava/configuration_video_llava.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VideoLlava model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class VideoLlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate an + VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the like LanguageBind/Video-LLaVA-7B-hf. + + e.g. [LanguageBind/Video-LLaVA-7B-hf](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`VideoLlavaVisionConfig`, *optional*): + Custom vision config or dict. Defaults to `CLIPVisionConfig` if not indicated. + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + Defaults to `LlamaConfig` if not indicated. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 32001): + The video token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the CLIP backbone. + Can be either "full" to select all features or "default" to select features without `CLS`. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 256): + Sequence length of one image embedding. + video_seq_length (`int`, *optional*, defaults to 2056): + Sequence length of one video embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. 
+ + Example: + + ```python + >>> from transformers import VideoLlavaForConditionalGeneration, VideoLlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a VideoLlava video_llava-1.5-7b style configuration + >>> configuration = VideoLlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the video_llava-1.5-7b style configuration + >>> model = VideoLlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "video_llava" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + video_token_index=32001, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=256, + video_seq_length=2056, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.image_seq_length = image_seq_length + self.video_seq_length = video_seq_length + self.multimodal_projector_bias = multimodal_projector_bias + + self.vision_config = vision_config + + if isinstance(self.vision_config, dict): + if "model_type" not in vision_config: + vision_config["model_type"] = "clip_vision_model" + logger.warning("Key=`model_type` not found in vision config, setting it to `clip_vision_model`") + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=224, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + if isinstance(text_config, dict): + if "model_type" not in text_config: + text_config["model_type"] = "llama" + logger.warning("Key=`model_type` not found in text config, setting it to `llama`") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + super().__init__(**kwargs) + + +__all__ = ["VideoLlavaConfig"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/video_llava/image_processing_video_llava.py b/venv/lib/python3.13/site-packages/transformers/models/video_llava/image_processing_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed8f911af8e8bbf2f7921f8d560db4ece9f0e76 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/video_llava/image_processing_video_llava.py @@ -0,0 +1,391 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Video-LLaVA.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...video_utils import VideoInput, make_batched_videos + + +logger = logging.get_logger(__name__) + + +class VideoLlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
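+ For example (illustrative), with `size={"shortest_edge": 224}` a 480x640 image is resized so that its shorter
+ side becomes 224 while the aspect ratio is preserved, giving an output of roughly 224x298.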
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: Optional[list[ImageInput]] = None, + videos: Optional[list[VideoInput]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`, *optional*): + List of images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`, *optional*): + List of videos to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. 
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + data = {} + if videos is not None: + logger.warning( + "`VideoLlavaImageProcessor` works only with image inputs and doesn't process videos anymore. " + "This is a deprecated behavior and will be removed in v5.0. " + "Your videos should be forwarded to `VideoLlavaVideoProcessor`. 
" + ) + videos = make_batched_videos(videos) + pixel_values_videos = [ + [ + self._preprocess_image( + image=frame, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for frame in video + ] + for video in videos + ] + data["pixel_values_videos"] = pixel_values_videos + + if images is not None: + pixel_values_images = [ + self._preprocess_image( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for image in images + ] + data["pixel_values_images"] = pixel_values_images + + encoded_outputs = BatchFeature(data, tensor_type=return_tensors) + + return encoded_outputs + + def _preprocess_image( + self, + image: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images/video frames. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + +__all__ = ["VideoLlavaImageProcessor"] diff --git a/venv/lib/python3.13/site-packages/transformers/models/video_llava/modeling_video_llava.py b/venv/lib/python3.13/site-packages/transformers/models/video_llava/modeling_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..2db424455087ea69e9abaed8177e80c617889f35 --- /dev/null +++ b/venv/lib/python3.13/site-packages/transformers/models/video_llava/modeling_video_llava.py @@ -0,0 +1,718 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch VideoLlava model.""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ..auto import AutoModel +from .configuration_video_llava import VideoLlavaConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VideoLlava base model outputs. + """ +) +class VideoLlavaModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VideoLlava causal language model (or autoregressive) outputs. + """ +) +class VideoLlavaCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava +class VideoLlavaMultiModalProjector(nn.Module): + def __init__(self, config: VideoLlavaConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class VideoLlavaPreTrainedModel(PreTrainedModel): + config: VideoLlavaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["VideoLlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@auto_docstring( + custom_intro=""" + The VideoLlava model which consists of a vision backbone and a language model without language modeling head. 
+ """, +) +class VideoLlavaModel(VideoLlavaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.video_tower = AutoModel.from_config(config.vision_config) + self.image_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = VideoLlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModel.from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values_images: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + image_outputs = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + image_outputs = image_outputs[:, 1:] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + # For default; crop CLS from each hidden state in the hidden state pool + if vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + image_outputs = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(image_outputs) + + return image_features + + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + ): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input videos. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + Returns: + video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`). + frames (`int`): Number of frames the videos have. + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape + + pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) + video_outputs = self.video_tower(pixel_values, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + video_features = video_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [video_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + video_features = torch.cat(hs_pool, dim=-1) + + video_features = self.multi_modal_projector(video_features) + + return video_features, num_frames + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
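+ The returned masks are expanded to the hidden dimension of `inputs_embeds`, so they can be used directly
+ with `masked_scatter` to splice the projected image/video features into the text embeddings.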
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0] * video_features.shape[1]}" + ) + + return special_image_mask, special_video_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values_images: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, VideoLlavaModelOutputWithPast]: + r""" + pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`VideoLlavaImageProcessor`] for processing images). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values_images is not None: + image_features = self.get_image_features( + pixel_values_images, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None: + video_features, num_frames = self.get_video_features( + pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return VideoLlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values_images is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The VideoLlava model which consists of a vision backbone and a language model. 
+ """ +) +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^image_tower": "model.image_tower", + "^video_tower": "model.video_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.model = VideoLlavaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values_images: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + return self.model.get_image_features( + pixel_values_images=pixel_values_images, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def video_tower(self): + return self.model.video_tower + + @property + def image_tower(self): + return self.model.image_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values_images: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, VideoLlavaCausalLMOutputWithPast]: + r""" + pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`VideoLlavaImageProcessor`] for processing images). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> import numpy as np + >>> import av + >>> from huggingface_hub import hf_hub_download + >>> from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`list[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") + >>> processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") + + >>> prompt = "USER: